Skip to content

Commit af6e29e

Browse files
Implemented tpu_vm_create_spot sample, created test
1 parent ec13f4d commit af6e29e

File tree

2 files changed

+188
-0
lines changed

2 files changed

+188
-0
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
//[START tpu_vm_create_spot]
20+
import com.google.api.gax.longrunning.OperationTimedPollAlgorithm;
21+
import com.google.api.gax.retrying.RetrySettings;
22+
import com.google.cloud.tpu.v2.CreateNodeRequest;
23+
import com.google.cloud.tpu.v2.Node;
24+
import com.google.cloud.tpu.v2.SchedulingConfig;
25+
import com.google.cloud.tpu.v2.TpuClient;
26+
import com.google.cloud.tpu.v2.TpuSettings;
27+
import java.io.IOException;
28+
import java.util.concurrent.ExecutionException;
29+
import org.threeten.bp.Duration;
30+
31+
public class CreateSpotTpuVm {
32+
public static void main(String[] args)
33+
throws IOException, ExecutionException, InterruptedException {
34+
// TODO(developer): Replace these variables before running the sample.
35+
// Project ID or project number of the Google Cloud project you want to create a node.
36+
String projectId = "YOUR_PROJECT_ID";
37+
// The zone in which to create the TPU.
38+
// For more information about supported TPU types for specific zones,
39+
// see https://cloud.google.com/tpu/docs/regions-zones
40+
String zone = "europe-west4-a";
41+
// The name for your TPU.
42+
String nodeName = "YOUR_TPY_NAME";
43+
// The accelerator type that specifies the version and size of the Cloud TPU you want to create.
44+
// For more information about supported accelerator types for each TPU version,
45+
// see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions.
46+
String tpuType = "v2-8";
47+
// Software version that specifies the version of the TPU runtime to install.
48+
// For more information see https://cloud.google.com/tpu/docs/runtimes
49+
String tpuSoftwareVersion = "tpu-vm-tf-2.14.1";
50+
51+
createSpotTpuVm(projectId, zone, nodeName, tpuType, tpuSoftwareVersion);
52+
}
53+
54+
// Creates a preemptible TPU VM with the specified name, zone, accelerator type, and version.
55+
public static Node createSpotTpuVm(
56+
String projectId, String zone, String nodeName, String tpuType, String tpuSoftwareVersion)
57+
throws IOException, ExecutionException, InterruptedException {
58+
// With these settings the client library handles the Operation's polling mechanism
59+
// and prevent CancellationException error
60+
TpuSettings.Builder clientSettings =
61+
TpuSettings.newBuilder();
62+
clientSettings
63+
.createNodeOperationSettings()
64+
.setPollingAlgorithm(
65+
OperationTimedPollAlgorithm.create(
66+
RetrySettings.newBuilder()
67+
.setInitialRetryDelay(Duration.ofMillis(5000L))
68+
.setRetryDelayMultiplier(1.5)
69+
.setMaxRetryDelay(Duration.ofMillis(45000L))
70+
.setInitialRpcTimeout(Duration.ZERO)
71+
.setRpcTimeoutMultiplier(1.0)
72+
.setMaxRpcTimeout(Duration.ZERO)
73+
.setTotalTimeout(Duration.ofHours(24L))
74+
.build()));
75+
76+
// Initialize client that will be used to send requests. This client only needs to be created
77+
// once, and can be reused for multiple requests.
78+
try (TpuClient tpuClient = TpuClient.create(clientSettings.build())) {
79+
String parent = String.format("projects/%s/locations/%s", projectId, zone);
80+
81+
SchedulingConfig schedulingConfig = SchedulingConfig.newBuilder()
82+
.setPreemptible(true)
83+
.build();
84+
85+
Node tpuVm = Node.newBuilder()
86+
.setName(nodeName)
87+
.setAcceleratorType(tpuType)
88+
.setRuntimeVersion(tpuSoftwareVersion)
89+
.setSchedulingConfig(schedulingConfig)
90+
.build();
91+
92+
CreateNodeRequest request = CreateNodeRequest.newBuilder()
93+
.setParent(parent)
94+
.setNodeId(nodeName)
95+
.setNode(tpuVm)
96+
.build();
97+
98+
Node response = tpuClient.createNodeAsync(request).get();
99+
System.out.printf("Spot TPU VM created: %s\n", response.getName());
100+
return response;
101+
}
102+
}
103+
}
104+
//[END tpu_vm_create_spot]
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static com.google.common.truth.Truth.assertWithMessage;
21+
import static org.junit.Assert.assertTrue;
22+
23+
import com.google.api.gax.rpc.NotFoundException;
24+
import com.google.cloud.tpu.v2.Node;
25+
import java.io.IOException;
26+
import java.util.UUID;
27+
import java.util.concurrent.ExecutionException;
28+
import java.util.concurrent.TimeUnit;
29+
import org.junit.jupiter.api.AfterAll;
30+
import org.junit.jupiter.api.Assertions;
31+
import org.junit.jupiter.api.BeforeAll;
32+
import org.junit.jupiter.api.Test;
33+
import org.junit.jupiter.api.Timeout;
34+
import org.junit.runner.RunWith;
35+
import org.junit.runners.JUnit4;
36+
37+
@RunWith(JUnit4.class)
38+
@Timeout(value = 6, unit = TimeUnit.MINUTES)
39+
public class CreateSpotTpuVmIT {
40+
private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT");
41+
private static final String ZONE = "us-central1-a";
42+
static String javaVersion = System.getProperty("java.version").substring(0, 2);
43+
private static final String NODE_NAME = "test-spot-tpu-" + javaVersion + "-"
44+
+ UUID.randomUUID().toString().substring(0, 8);
45+
private static final String TPU_TYPE = "v3-8";
46+
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
47+
private static final String NODE_PATH_NAME =
48+
String.format("projects/%s/locations/%s/nodes/%s", PROJECT_ID, ZONE, NODE_NAME);
49+
50+
public static void requireEnvVar(String envVarName) {
51+
assertWithMessage(String.format("Missing environment variable '%s' ", envVarName))
52+
.that(System.getenv(envVarName)).isNotEmpty();
53+
}
54+
55+
@BeforeAll
56+
public static void setUp()
57+
throws IOException, ExecutionException, InterruptedException {
58+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
59+
requireEnvVar("GOOGLE_CLOUD_PROJECT");
60+
61+
// Cleanup existing stale resources.
62+
Util.cleanUpExistingTpu("test-spot-tpu-" + javaVersion, PROJECT_ID, ZONE);
63+
}
64+
65+
@AfterAll
66+
public static void cleanup() throws Exception {
67+
DeleteTpuVm.deleteTpuVm(PROJECT_ID, ZONE, NODE_NAME);
68+
69+
// Test that TPUs is deleted
70+
Assertions.assertThrows(
71+
NotFoundException.class,
72+
() -> GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME));
73+
}
74+
75+
@Test
76+
public void testCreateSpotTpuVm() throws IOException, ExecutionException, InterruptedException {
77+
Node node = CreateSpotTpuVm.createSpotTpuVm(
78+
PROJECT_ID, ZONE, NODE_NAME, TPU_TYPE, TPU_SOFTWARE_VERSION);
79+
80+
assertThat(node.getName()).isEqualTo(NODE_PATH_NAME);
81+
assertTrue(node.getSchedulingConfig().getPreemptible());
82+
83+
}
84+
}

0 commit comments

Comments
 (0)