Skip to content

Commit bcecf8f

Browse files
feat(tpu): add tpu vm create topology sample. (#9611)
* Changed package, added information to CODEOWNERS * Added information to CODEOWNERS * Added timeout * Fixed parameters for test * Fixed DeleteTpuVm and naming * Added comment, created Util class * Fixed naming * Fixed whitespace * Split PR into smaller, deleted redundant code * Implemented tpu_vm_create_topology sample, created test * Changed zone * Fixed empty lines and tests, deleted cleanup method * Fixed tests * Fixed test * Fixed imports * Increased timeout to 10 sec * Fixed tests * Fixed tests * Deleted settings * Made ByteArrayOutputStream bout as local variable * Changed timeout to 10 sec
1 parent 7f84b30 commit bcecf8f

File tree

5 files changed

+155
-93
lines changed

5 files changed

+155
-93
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
//[START tpu_vm_create_topology]
20+
import com.google.cloud.tpu.v2.AcceleratorConfig;
21+
import com.google.cloud.tpu.v2.AcceleratorConfig.Type;
22+
import com.google.cloud.tpu.v2.CreateNodeRequest;
23+
import com.google.cloud.tpu.v2.Node;
24+
import com.google.cloud.tpu.v2.TpuClient;
25+
import java.io.IOException;
26+
import java.util.concurrent.ExecutionException;
27+
28+
public class CreateTpuWithTopologyFlag {
29+
30+
public static void main(String[] args)
31+
throws IOException, ExecutionException, InterruptedException {
32+
// TODO(developer): Replace these variables before running the sample.
33+
// Project ID or project number of the Google Cloud project you want to create a node.
34+
String projectId = "YOUR_PROJECT_ID";
35+
// The zone in which to create the TPU.
36+
// For more information about supported TPU types for specific zones,
37+
// see https://cloud.google.com/tpu/docs/regions-zones
38+
String zone = "europe-west4-a";
39+
// The name for your TPU.
40+
String nodeName = "YOUR_TPU_NAME";
41+
// The version of the Cloud TPU you want to create.
42+
// Available options: TYPE_UNSPECIFIED = 0, V2 = 2, V3 = 4, V4 = 7
43+
Type tpuVersion = AcceleratorConfig.Type.V2;
44+
// Software version that specifies the version of the TPU runtime to install.
45+
// For more information, see https://cloud.google.com/tpu/docs/runtimes
46+
String tpuSoftwareVersion = "tpu-vm-tf-2.17.0-pod-pjrt";
47+
// The physical topology of your TPU slice.
48+
// For more information about topology for each TPU version,
49+
// see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions.
50+
String topology = "2x2";
51+
52+
createTpuWithTopologyFlag(projectId, zone, nodeName, tpuVersion, tpuSoftwareVersion, topology);
53+
}
54+
55+
// Creates a TPU VM with the specified name, zone, version and topology.
56+
public static Node createTpuWithTopologyFlag(String projectId, String zone, String nodeName,
57+
Type tpuVersion, String tpuSoftwareVersion, String topology)
58+
throws IOException, ExecutionException, InterruptedException {
59+
// Initialize client that will be used to send requests. This client only needs to be created
60+
// once, and can be reused for multiple requests.
61+
try (TpuClient tpuClient = TpuClient.create()) {
62+
String parent = String.format("projects/%s/locations/%s", projectId, zone);
63+
Node tpuVm =
64+
Node.newBuilder()
65+
.setName(nodeName)
66+
.setAcceleratorConfig(Node.newBuilder()
67+
.getAcceleratorConfigBuilder()
68+
.setType(tpuVersion)
69+
.setTopology(topology)
70+
.build())
71+
.setRuntimeVersion(tpuSoftwareVersion)
72+
.build();
73+
74+
CreateNodeRequest request =
75+
CreateNodeRequest.newBuilder()
76+
.setParent(parent)
77+
.setNodeId(nodeName)
78+
.setNode(tpuVm)
79+
.build();
80+
81+
return tpuClient.createNodeAsync(request).get();
82+
}
83+
}
84+
}
85+
//[END tpu_vm_create_topology]

tpu/src/main/java/tpu/GetQueuedResource.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package tpu;
1818

1919
//[START tpu_queued_resources_get]
20-
2120
import com.google.cloud.tpu.v2alpha1.GetQueuedResourceRequest;
2221
import com.google.cloud.tpu.v2alpha1.QueuedResource;
2322
import com.google.cloud.tpu.v2alpha1.TpuClient;

tpu/src/test/java/tpu/CreateTpuIT.java

Lines changed: 0 additions & 70 deletions
This file was deleted.

tpu/src/test/java/tpu/QueuedResourceIT.java

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@
3535
import java.io.ByteArrayOutputStream;
3636
import java.io.IOException;
3737
import java.io.PrintStream;
38-
import org.junit.Before;
39-
import org.junit.Test;
38+
import org.junit.jupiter.api.BeforeAll;
39+
import org.junit.jupiter.api.Test;
4040
import org.junit.jupiter.api.Timeout;
4141
import org.junit.runner.RunWith;
4242
import org.junit.runners.JUnit4;
4343
import org.mockito.MockedStatic;
4444

4545
@RunWith(JUnit4.class)
46-
@Timeout(value = 3)
46+
@Timeout(value = 10)
4747
public class QueuedResourceIT {
4848
private static final String PROJECT_ID = "project-id";
4949
private static final String ZONE = "europe-west4-a";
@@ -52,10 +52,10 @@ public class QueuedResourceIT {
5252
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
5353
private static final String QUEUED_RESOURCE_NAME = "queued-resource";
5454
private static final String NETWORK_NAME = "default";
55-
private ByteArrayOutputStream bout;
55+
private static ByteArrayOutputStream bout;
5656

57-
@Before
58-
public void setUp() {
57+
@BeforeAll
58+
public static void setUp() {
5959
bout = new ByteArrayOutputStream();
6060
System.setOut(new PrintStream(bout));
6161
}
@@ -75,8 +75,8 @@ public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {
7575

7676
QueuedResource returnedQueuedResource =
7777
CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork(
78-
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
79-
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);
78+
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
79+
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);
8080

8181
verify(mockTpuClient, times(1))
8282
.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class));
@@ -89,7 +89,6 @@ public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {
8989
public void testGetQueuedResource() throws IOException {
9090
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
9191
TpuClient mockClient = mock(TpuClient.class);
92-
GetQueuedResource mockGetQueuedResource = mock(GetQueuedResource.class);
9392
QueuedResource mockQueuedResource = mock(QueuedResource.class);
9493

9594
mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
@@ -99,14 +98,14 @@ public void testGetQueuedResource() throws IOException {
9998
QueuedResource returnedQueuedResource =
10099
GetQueuedResource.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
101100

102-
verify(mockGetQueuedResource, times(1))
103-
.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
101+
verify(mockClient, times(1))
102+
.getQueuedResource(any(GetQueuedResourceRequest.class));
104103
assertEquals(returnedQueuedResource, mockQueuedResource);
105104
}
106105
}
107106

108107
@Test
109-
public void testDeleteTpuVm() {
108+
public void testDeleteForceQueuedResource() {
110109
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
111110
TpuClient mockTpuClient = mock(TpuClient.class);
112111
OperationFuture mockFuture = mock(OperationFuture.class);

tpu/src/test/java/tpu/TpuVmIT.java

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package tpu;
1818

1919
import static com.google.common.truth.Truth.assertThat;
20+
import static org.junit.Assert.assertEquals;
2021
import static org.mockito.Mockito.any;
2122
import static org.mockito.Mockito.mock;
2223
import static org.mockito.Mockito.mockStatic;
@@ -25,6 +26,8 @@
2526
import static org.mockito.Mockito.when;
2627

2728
import com.google.api.gax.longrunning.OperationFuture;
29+
import com.google.cloud.tpu.v2.AcceleratorConfig;
30+
import com.google.cloud.tpu.v2.CreateNodeRequest;
2831
import com.google.cloud.tpu.v2.DeleteNodeRequest;
2932
import com.google.cloud.tpu.v2.GetNodeRequest;
3033
import com.google.cloud.tpu.v2.Node;
@@ -34,47 +37,68 @@
3437
import java.io.IOException;
3538
import java.io.PrintStream;
3639
import java.util.concurrent.ExecutionException;
37-
import org.junit.jupiter.api.BeforeAll;
3840
import org.junit.jupiter.api.Test;
3941
import org.junit.jupiter.api.Timeout;
4042
import org.junit.runner.RunWith;
4143
import org.junit.runners.JUnit4;
4244
import org.mockito.MockedStatic;
4345

4446
@RunWith(JUnit4.class)
45-
@Timeout(value = 3)
47+
@Timeout(value = 10)
4648
public class TpuVmIT {
4749
private static final String PROJECT_ID = "project-id";
4850
private static final String ZONE = "asia-east1-c";
4951
private static final String NODE_NAME = "test-tpu";
50-
private static ByteArrayOutputStream bout;
52+
private static final String TPU_TYPE = "v2-8";
53+
private static final AcceleratorConfig.Type ACCELERATOR_TYPE = AcceleratorConfig.Type.V2;
54+
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
55+
private static final String TOPOLOGY = "2x2";
5156

52-
@BeforeAll
53-
public static void setUp() {
54-
bout = new ByteArrayOutputStream();
55-
System.setOut(new PrintStream(bout));
57+
@Test
58+
public void testCreateTpuVm() throws Exception {
59+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
60+
Node mockNode = mock(Node.class);
61+
TpuClient mockTpuClient = mock(TpuClient.class);
62+
OperationFuture mockFuture = mock(OperationFuture.class);
63+
64+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
65+
.thenReturn(mockTpuClient);
66+
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
67+
.thenReturn(mockFuture);
68+
when(mockFuture.get()).thenReturn(mockNode);
69+
70+
Node returnedNode = CreateTpuVm.createTpuVm(
71+
PROJECT_ID, ZONE, NODE_NAME,
72+
TPU_TYPE, TPU_SOFTWARE_VERSION);
73+
74+
verify(mockTpuClient, times(1))
75+
.createNodeAsync(any(CreateNodeRequest.class));
76+
verify(mockFuture, times(1)).get();
77+
assertEquals(returnedNode, mockNode);
78+
}
5679
}
5780

5881
@Test
5982
public void testGetTpuVm() throws IOException {
6083
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
6184
Node mockNode = mock(Node.class);
6285
TpuClient mockClient = mock(TpuClient.class);
63-
GetTpuVm mockGetTpuVm = mock(GetTpuVm.class);
6486

6587
mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
6688
when(mockClient.getNode(any(GetNodeRequest.class))).thenReturn(mockNode);
6789

6890
Node returnedNode = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
6991

70-
verify(mockGetTpuVm, times(1))
71-
.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
92+
verify(mockClient, times(1))
93+
.getNode(any(GetNodeRequest.class));
7294
assertThat(returnedNode).isEqualTo(mockNode);
7395
}
7496
}
7597

7698
@Test
7799
public void testDeleteTpuVm() throws IOException, ExecutionException, InterruptedException {
100+
ByteArrayOutputStream bout = new ByteArrayOutputStream();
101+
System.setOut(new PrintStream(bout));
78102
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
79103
TpuClient mockTpuClient = mock(TpuClient.class);
80104
OperationFuture mockFuture = mock(OperationFuture.class);
@@ -89,6 +113,31 @@ public void testDeleteTpuVm() throws IOException, ExecutionException, Interrupte
89113

90114
assertThat(output).contains("TPU VM deleted");
91115
verify(mockTpuClient, times(1)).deleteNodeAsync(any(DeleteNodeRequest.class));
116+
117+
bout.close();
118+
}
119+
}
120+
121+
@Test
122+
public void testCreateTpuVmWithTopologyFlag()
123+
throws IOException, ExecutionException, InterruptedException {
124+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
125+
Node mockNode = mock(Node.class);
126+
TpuClient mockTpuClient = mock(TpuClient.class);
127+
OperationFuture mockFuture = mock(OperationFuture.class);
128+
129+
mockedTpuClient.when(TpuClient::create).thenReturn(mockTpuClient);
130+
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
131+
.thenReturn(mockFuture);
132+
when(mockFuture.get()).thenReturn(mockNode);
133+
Node returnedNode = CreateTpuWithTopologyFlag.createTpuWithTopologyFlag(
134+
PROJECT_ID, ZONE, NODE_NAME, ACCELERATOR_TYPE,
135+
TPU_SOFTWARE_VERSION, TOPOLOGY);
136+
137+
verify(mockTpuClient, times(1))
138+
.createNodeAsync(any(CreateNodeRequest.class));
139+
verify(mockFuture, times(1)).get();
140+
assertEquals(returnedNode, mockNode);
92141
}
93142
}
94143
}

0 commit comments

Comments
 (0)