Skip to content

Commit b6b69e7

Browse files
Implemented tpu_vm_create_startup_script sample, created test
1 parent b804cc8 commit b6b69e7

File tree

2 files changed

+194
-0
lines changed

2 files changed

+194
-0
lines changed
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
//[START tpu_vm_create_startup_script]
20+
import com.google.api.gax.longrunning.OperationTimedPollAlgorithm;
21+
import com.google.api.gax.retrying.RetrySettings;
22+
import com.google.cloud.tpu.v2.CreateNodeRequest;
23+
import com.google.cloud.tpu.v2.Node;
24+
import com.google.cloud.tpu.v2.TpuClient;
25+
import com.google.cloud.tpu.v2.TpuSettings;
26+
import java.io.IOException;
27+
import java.util.HashMap;
28+
import java.util.Map;
29+
import java.util.concurrent.ExecutionException;
30+
import org.threeten.bp.Duration;
31+
32+
public class CreateTpuVmWithStartupScript {
33+
public static void main(String[] args)
34+
throws IOException, ExecutionException, InterruptedException {
35+
// TODO(developer): Replace these variables before running the sample.
36+
// Project ID or project number of the Google Cloud project you want to create a node.
37+
String projectId = "YOUR_PROJECT_ID";
38+
// The zone in which to create the TPU.
39+
// For more information about supported TPU types for specific zones,
40+
// see https://cloud.google.com/tpu/docs/regions-zones
41+
String zone = "europe-west4-a";
42+
// The name for your TPU.
43+
String nodeName = "YOUR_TPY_NAME";
44+
// The accelerator type that specifies the version and size of the Cloud TPU you want to create.
45+
// For more information about supported accelerator types for each TPU version,
46+
// see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions.
47+
String acceleratorType = "v2-8";
48+
// Software version that specifies the version of the TPU runtime to install.
49+
// For more information, see https://cloud.google.com/tpu/docs/runtimes
50+
String tpuSoftwareVersion = "tpu-vm-tf-2.14.1";
51+
52+
createTpuVmWithStartupScript(projectId, zone, nodeName, acceleratorType, tpuSoftwareVersion);
53+
}
54+
55+
// Create a TPU VM with a startup script.
56+
public static Node createTpuVmWithStartupScript(String projectId, String zone,
57+
String nodeName, String acceleratorType, String tpuSoftwareVersion)
58+
throws IOException, ExecutionException, InterruptedException {
59+
// With these settings the client library handles the Operation's polling mechanism
60+
// and prevent CancellationException error
61+
TpuSettings.Builder clientSettings =
62+
TpuSettings.newBuilder();
63+
clientSettings
64+
.createNodeOperationSettings()
65+
.setPollingAlgorithm(
66+
OperationTimedPollAlgorithm.create(
67+
RetrySettings.newBuilder()
68+
.setInitialRetryDelay(Duration.ofMillis(5000L))
69+
.setRetryDelayMultiplier(1.5)
70+
.setMaxRetryDelay(Duration.ofMillis(45000L))
71+
.setInitialRpcTimeout(Duration.ZERO)
72+
.setRpcTimeoutMultiplier(1.0)
73+
.setMaxRpcTimeout(Duration.ZERO)
74+
.setTotalTimeout(Duration.ofHours(24L))
75+
.build()));
76+
77+
// Initialize client that will be used to send requests. This client only needs to be created
78+
// once, and can be reused for multiple requests.
79+
try (TpuClient tpuClient = TpuClient.create(clientSettings.build())) {
80+
String parent = String.format("projects/%s/locations/%s", projectId, zone);
81+
82+
// Create metadata map
83+
Map<String, String> metadata = new HashMap<>();
84+
metadata.put("startup-script", "your-script-here"); // Script content here
85+
86+
Node tpuVm =
87+
Node.newBuilder()
88+
.setName(nodeName)
89+
.setAcceleratorType(acceleratorType)
90+
.setRuntimeVersion(tpuSoftwareVersion)
91+
.putAllLabels(metadata)
92+
.build();
93+
94+
CreateNodeRequest request =
95+
CreateNodeRequest.newBuilder()
96+
.setParent(parent)
97+
.setNodeId(nodeName)
98+
.setNode(tpuVm)
99+
.build();
100+
101+
Node response = tpuClient.createNodeAsync(request).get();
102+
103+
System.out.printf("TPU VM created: %s\n", response.getName());
104+
return response;
105+
}
106+
}
107+
}
108+
//[END tpu_vm_create_startup_script]
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static com.google.common.truth.Truth.assertWithMessage;
21+
import static org.junit.Assert.assertNotNull;
22+
23+
import com.google.api.gax.rpc.NotFoundException;
24+
import com.google.cloud.tpu.v2.Node;
25+
import java.io.IOException;
26+
import java.util.UUID;
27+
import java.util.concurrent.ExecutionException;
28+
import java.util.concurrent.TimeUnit;
29+
import org.junit.Assert;
30+
import org.junit.jupiter.api.AfterAll;
31+
import org.junit.jupiter.api.Assertions;
32+
import org.junit.jupiter.api.BeforeAll;
33+
import org.junit.jupiter.api.Test;
34+
import org.junit.jupiter.api.Timeout;
35+
import org.junit.runner.RunWith;
36+
import org.junit.runners.JUnit4;
37+
38+
@RunWith(JUnit4.class)
39+
@Timeout(value = 25, unit = TimeUnit.MINUTES)
40+
public class CreateTpuVmWithStartupScriptIT {
41+
private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT");
42+
private static final String ZONE = "europe-west4-a";
43+
static String javaVersion = System.getProperty("java.version").substring(0, 2);
44+
private static final String NODE_NAME = "test-tpu-with-script-" + javaVersion + "-"
45+
+ UUID.randomUUID().toString().substring(0, 8);
46+
private static final String TPU_TYPE = "v2-8";
47+
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
48+
49+
public static void requireEnvVar(String envVarName) {
50+
assertWithMessage(String.format("Missing environment variable '%s' ", envVarName))
51+
.that(System.getenv(envVarName)).isNotEmpty();
52+
}
53+
54+
@BeforeAll
55+
public static void setUp()
56+
throws IOException, ExecutionException, InterruptedException {
57+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
58+
requireEnvVar("GOOGLE_CLOUD_PROJECT");
59+
60+
// Cleanup existing stale resources.
61+
Util.cleanUpExistingTpu("test-tpu-with-script-" + javaVersion, PROJECT_ID, ZONE);
62+
}
63+
64+
@AfterAll
65+
public static void cleanup() throws Exception {
66+
DeleteTpuVm.deleteTpuVm(PROJECT_ID, ZONE, NODE_NAME);
67+
68+
// Test that TPUs is deleted
69+
Assertions.assertThrows(
70+
NotFoundException.class,
71+
() -> GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME));
72+
}
73+
74+
@Test
75+
public void testCreateTpuVmWithStartupScript()
76+
throws IOException, ExecutionException, InterruptedException {
77+
Node node = CreateTpuVmWithStartupScript.createTpuVmWithStartupScript(
78+
PROJECT_ID, ZONE, NODE_NAME, TPU_TYPE, TPU_SOFTWARE_VERSION);
79+
80+
assertNotNull(node);
81+
assertThat(node.getName().equals(NODE_NAME));
82+
Assert.assertTrue(node.containsLabels("startup-script"));
83+
Assert.assertTrue(node.getLabelsMap().containsValue("your-script-here"));
84+
85+
}
86+
}

0 commit comments

Comments
 (0)