Skip to content

Commit 9de8f3a

Browse files
Implemented tpu_vm_stop and tpu_vm_start samples, created tests
1 parent ec13f4d commit 9de8f3a

File tree

3 files changed

+140
-0
lines changed

3 files changed

+140
-0
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
//[START tpu_vm_start]
20+
import com.google.cloud.tpu.v2.Node;
21+
import com.google.cloud.tpu.v2.NodeName;
22+
import com.google.cloud.tpu.v2.StartNodeRequest;
23+
import com.google.cloud.tpu.v2.TpuClient;
24+
import java.io.IOException;
25+
import java.util.concurrent.ExecutionException;
26+
27+
public class StartTpuVm {
28+
29+
public static void main(String[] args)
30+
throws IOException, ExecutionException, InterruptedException {
31+
// TODO(developer): Replace these variables before running the sample.
32+
// Project ID or project number of the Google Cloud project you want to create a node.
33+
String projectId = "YOUR_PROJECT_ID";
34+
// The zone in which to create the TPU.
35+
// For more information about supported TPU types for specific zones,
36+
// see https://cloud.google.com/tpu/docs/regions-zones
37+
String zone = "europe-west4-a";
38+
// The name for your TPU.
39+
String nodeName = "YOUR_TPY_NAME";
40+
41+
startTpuVm(projectId, zone, nodeName);
42+
}
43+
44+
// Starts a TPU VM with the specified name in the given project and zone.
45+
public static void startTpuVm(String projectId, String zone, String nodeName)
46+
throws IOException, ExecutionException, InterruptedException {
47+
// Initialize client that will be used to send requests. This client only needs to be created
48+
// once, and can be reused for multiple requests.
49+
try (TpuClient tpuClient = TpuClient.create()) {
50+
String name = NodeName.of(projectId, zone, nodeName).toString();
51+
52+
StartNodeRequest request = StartNodeRequest.newBuilder().setName(name).build();
53+
Node response = tpuClient.startNodeAsync(request).get();
54+
55+
System.out.printf("TPU VM started: %s\n", response.getName());
56+
}
57+
}
58+
}
59+
//[END tpu_vm_start]
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
//[START tpu_vm_stop]
20+
import com.google.cloud.tpu.v2.Node;
21+
import com.google.cloud.tpu.v2.NodeName;
22+
import com.google.cloud.tpu.v2.StopNodeRequest;
23+
import com.google.cloud.tpu.v2.TpuClient;
24+
import java.io.IOException;
25+
import java.util.concurrent.ExecutionException;
26+
27+
public class StopTpuVm {
28+
29+
public static void main(String[] args)
30+
throws IOException, ExecutionException, InterruptedException {
31+
// TODO(developer): Replace these variables before running the sample.
32+
// Project ID or project number of the Google Cloud project you want to create a node.
33+
String projectId = "YOUR_PROJECT_ID";
34+
// The zone in which to create the TPU.
35+
// For more information about supported TPU types for specific zones,
36+
// see https://cloud.google.com/tpu/docs/regions-zones
37+
String zone = "europe-west4-a";
38+
// The name for your TPU.
39+
String nodeName = "YOUR_TPY_NAME";
40+
41+
stopTpuVm(projectId, zone, nodeName);
42+
}
43+
44+
// Stops a TPU VM with the specified name in the given project and zone.
45+
public static void stopTpuVm(String projectId, String zone, String nodeName)
46+
throws IOException, ExecutionException, InterruptedException {
47+
// Initialize client that will be used to send requests. This client only needs to be created
48+
// once, and can be reused for multiple requests.
49+
try (TpuClient tpuClient = TpuClient.create()) {
50+
String name = NodeName.of(projectId, zone, nodeName).toString();
51+
52+
StopNodeRequest request = StopNodeRequest.newBuilder().setName(name).build();
53+
Node response = tpuClient.stopNodeAsync(request).get();
54+
55+
System.out.printf("TPU VM stopped: %s\n", response.getName());
56+
}
57+
}
58+
}
59+
//[END tpu_vm_stop]
60+

tpu/src/test/java/tpu/TpuVmIT.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package tpu;
1818

19+
import static com.google.cloud.tpu.v2.Node.State.READY;
20+
import static com.google.cloud.tpu.v2.Node.State.STOPPED;
1921
import static com.google.common.truth.Truth.assertThat;
2022
import static com.google.common.truth.Truth.assertWithMessage;
2123
import static org.junit.Assert.assertNotNull;
@@ -96,4 +98,23 @@ public void testGetTpuVm() throws IOException {
9698
assertNotNull(node);
9799
assertThat(node.getName()).isEqualTo(NODE_PATH_NAME);
98100
}
101+
102+
103+
@Test
104+
@Order(2)
105+
public void testStopTpuVm() throws IOException, ExecutionException, InterruptedException {
106+
StopTpuVm.stopTpuVm(PROJECT_ID, ZONE, NODE_NAME);
107+
Node node = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
108+
109+
assertThat(node.getState()).isEqualTo(STOPPED);
110+
}
111+
112+
@Test
113+
@Order(3)
114+
public void testStartTpuVm() throws IOException, ExecutionException, InterruptedException {
115+
StartTpuVm.startTpuVm(PROJECT_ID, ZONE, NODE_NAME);
116+
Node node = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
117+
118+
assertThat(node.getState()).isEqualTo(READY);
119+
}
99120
}

0 commit comments

Comments
 (0)