Skip to content

Commit d8e2887

Browse files
Implemented tpu_vm_create_topology sample, created test
1 parent ec13f4d commit d8e2887

File tree

2 files changed

+197
-0
lines changed

2 files changed

+197
-0
lines changed
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
//[START tpu_vm_create_topology]
20+
import com.google.api.gax.longrunning.OperationTimedPollAlgorithm;
21+
import com.google.api.gax.retrying.RetrySettings;
22+
import com.google.cloud.tpu.v2.AcceleratorConfig;
23+
import com.google.cloud.tpu.v2.AcceleratorConfig.Type;
24+
import com.google.cloud.tpu.v2.CreateNodeRequest;
25+
import com.google.cloud.tpu.v2.Node;
26+
import com.google.cloud.tpu.v2.TpuClient;
27+
import com.google.cloud.tpu.v2.TpuSettings;
28+
import java.io.IOException;
29+
import java.util.concurrent.ExecutionException;
30+
import org.threeten.bp.Duration;
31+
32+
public class CreateTpuWithTopologyFlag {
33+
34+
public static void main(String[] args)
35+
throws IOException, ExecutionException, InterruptedException {
36+
// TODO(developer): Replace these variables before running the sample.
37+
// Project ID or project number of the Google Cloud project you want to create a node.
38+
String projectId = "YOUR_PROJECT_ID";
39+
// The zone in which to create the TPU.
40+
// For more information about supported TPU types for specific zones,
41+
// see https://cloud.google.com/tpu/docs/regions-zones
42+
String zone = "europe-west4-a";
43+
// The name for your TPU.
44+
String nodeName = "YOUR_TPY_NAME";
45+
// The version of the Cloud TPU you want to create.
46+
// Available options: TYPE_UNSPECIFIED = 0, V2 = 2, V3 = 4, V4 = 7
47+
Type tpuVersion = AcceleratorConfig.Type.V2;
48+
// Software version that specifies the version of the TPU runtime to install.
49+
// For more information, see https://cloud.google.com/tpu/docs/runtimes
50+
String tpuSoftwareVersion = "tpu-vm-tf-2.17.0-pod-pjrt";
51+
// The physical topology of your TPU slice.
52+
// For more information about topology for each TPU version,
53+
// see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions.
54+
String topology = "2x2";
55+
56+
createTpuWithTopologyFlag(projectId, zone, nodeName, tpuVersion, tpuSoftwareVersion, topology);
57+
}
58+
59+
// Creates a TPU VM with the specified name, zone, version and topology.
60+
public static Node createTpuWithTopologyFlag(String projectId, String zone, String nodeName,
61+
Type tpuVersion, String tpuSoftwareVersion, String topology)
62+
throws IOException, ExecutionException, InterruptedException {
63+
// With these settings the client library handles the Operation's polling mechanism
64+
// and prevent CancellationException error
65+
TpuSettings.Builder clientSettings =
66+
TpuSettings.newBuilder();
67+
clientSettings
68+
.createNodeOperationSettings()
69+
.setPollingAlgorithm(
70+
OperationTimedPollAlgorithm.create(
71+
RetrySettings.newBuilder()
72+
.setInitialRetryDelay(Duration.ofMillis(5000L))
73+
.setRetryDelayMultiplier(1.5)
74+
.setMaxRetryDelay(Duration.ofMillis(45000L))
75+
.setInitialRpcTimeout(Duration.ZERO)
76+
.setRpcTimeoutMultiplier(1.0)
77+
.setMaxRpcTimeout(Duration.ZERO)
78+
.setTotalTimeout(Duration.ofHours(24L))
79+
.build()));
80+
81+
// Initialize client that will be used to send requests. This client only needs to be created
82+
// once, and can be reused for multiple requests.
83+
try (TpuClient tpuClient = TpuClient.create(clientSettings.build())) {
84+
String parent = String.format("projects/%s/locations/%s", projectId, zone);
85+
86+
Node tpuVm =
87+
Node.newBuilder()
88+
.setName(nodeName)
89+
.setAcceleratorConfig(Node.newBuilder()
90+
.getAcceleratorConfigBuilder()
91+
.setType(tpuVersion)
92+
.setTopology(topology)
93+
.build())
94+
.setRuntimeVersion(tpuSoftwareVersion)
95+
.build();
96+
97+
CreateNodeRequest request =
98+
CreateNodeRequest.newBuilder()
99+
.setParent(parent)
100+
.setNodeId(nodeName)
101+
.setNode(tpuVm)
102+
.build();
103+
104+
Node response = tpuClient.createNodeAsync(request).get();
105+
System.out.printf("TPU VM created: %s\n", response.getName());
106+
return response;
107+
}
108+
}
109+
}
110+
//[END tpu_vm_create_topology]
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static com.google.common.truth.Truth.assertWithMessage;
21+
import static org.junit.Assert.assertNotNull;
22+
23+
import com.google.api.gax.rpc.NotFoundException;
24+
import com.google.cloud.tpu.v2.AcceleratorConfig;
25+
import com.google.cloud.tpu.v2.AcceleratorConfig.Type;
26+
import com.google.cloud.tpu.v2.Node;
27+
import java.io.IOException;
28+
import java.util.UUID;
29+
import java.util.concurrent.ExecutionException;
30+
import java.util.concurrent.TimeUnit;
31+
import org.junit.jupiter.api.AfterAll;
32+
import org.junit.jupiter.api.Assertions;
33+
import org.junit.jupiter.api.BeforeAll;
34+
import org.junit.jupiter.api.Test;
35+
import org.junit.jupiter.api.Timeout;
36+
import org.junit.runner.RunWith;
37+
import org.junit.runners.JUnit4;
38+
39+
@RunWith(JUnit4.class)
40+
@Timeout(value = 6, unit = TimeUnit.MINUTES)
41+
public class CreateTpuWithTopologyFlagIT {
42+
private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT");
43+
private static final String ZONE = "europe-west4-a";
44+
static String javaVersion = System.getProperty("java.version").substring(0, 2);
45+
private static final String NODE_NAME = "test-tpu-topology-" + javaVersion + "-"
46+
+ UUID.randomUUID().toString().substring(0, 8);
47+
private static final Type ACCELERATOR_TYPE = AcceleratorConfig.Type.V2;
48+
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
49+
private static final String TOPOLOGY = "2x2";
50+
51+
public static void requireEnvVar(String envVarName) {
52+
assertWithMessage(String.format("Missing environment variable '%s' ", envVarName))
53+
.that(System.getenv(envVarName)).isNotEmpty();
54+
}
55+
56+
@BeforeAll
57+
public static void setUp()
58+
throws IOException, ExecutionException, InterruptedException {
59+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
60+
requireEnvVar("GOOGLE_CLOUD_PROJECT");
61+
62+
// Cleanup existing stale resources.
63+
Util.cleanUpExistingTpu("test-tpu-topology-" + javaVersion, PROJECT_ID, ZONE);
64+
}
65+
66+
@AfterAll
67+
public static void cleanup() throws Exception {
68+
DeleteTpuVm.deleteTpuVm(PROJECT_ID, ZONE, NODE_NAME);
69+
70+
// Test that TPUs is deleted
71+
Assertions.assertThrows(
72+
NotFoundException.class,
73+
() -> GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME));
74+
}
75+
76+
@Test
77+
public void testCreateTpuVmWithTopologyFlag()
78+
throws IOException, ExecutionException, InterruptedException {
79+
Node node = CreateTpuWithTopologyFlag.createTpuWithTopologyFlag(
80+
PROJECT_ID, ZONE, NODE_NAME, ACCELERATOR_TYPE, TPU_SOFTWARE_VERSION, TOPOLOGY);
81+
82+
assertNotNull(node);
83+
assertThat(node.getName().equals(NODE_NAME));
84+
assertThat(node.getAcceleratorConfig().getTopology().equals(TOPOLOGY));
85+
assertThat(node.getAcceleratorConfig().getType().equals(ACCELERATOR_TYPE));
86+
}
87+
}

0 commit comments

Comments
 (0)