Skip to content

Commit 08aee80

Browse files
feat(tpu): add tpu queued resources startup script sample (#9604)
* Added tpu_queued_resources_network sample * Fixed samples and tests * Fixed tests * Changed CODEOWNERS * Split samples, fixed startup script path * Fixed style * Added tag * Added header * Implemented tpu_queued_resources_startup_script sample, created test * Fixed test, deleted cleanup method * Fixed test, deleted cleanup method * Fixed test * Fixed naming * Changed zone * Fixed tests * Fixed tests * Increased timeout * Fixed code as requested in comments * Deleted settings * Fixed test
1 parent 3607955 commit 08aee80

File tree

3 files changed

+141
-22
lines changed

3 files changed

+141
-22
lines changed
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
//[START tpu_queued_resources_startup_script]
20+
import com.google.cloud.tpu.v2alpha1.CreateQueuedResourceRequest;
21+
import com.google.cloud.tpu.v2alpha1.Node;
22+
import com.google.cloud.tpu.v2alpha1.QueuedResource;
23+
import com.google.cloud.tpu.v2alpha1.TpuClient;
24+
import java.io.IOException;
25+
import java.util.HashMap;
26+
import java.util.Map;
27+
import java.util.concurrent.ExecutionException;
28+
29+
public class CreateQueuedResourceWithStartupScript {
30+
public static void main(String[] args)
31+
throws IOException, ExecutionException, InterruptedException {
32+
// TODO(developer): Replace these variables before running the sample.
33+
// Project ID or project number of the Google Cloud project you want to create a node.
34+
String projectId = "YOUR_PROJECT_ID";
35+
// The zone in which to create the TPU.
36+
// For more information about supported TPU types for specific zones,
37+
// see https://cloud.google.com/tpu/docs/regions-zones
38+
String zone = "us-central1-a";
39+
// The name for your TPU.
40+
String nodeName = "YOUR_TPU_NAME";
41+
// The accelerator type that specifies the version and size of the Cloud TPU you want to create.
42+
// For more information about supported accelerator types for each TPU version,
43+
// see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions.
44+
String tpuType = "v2-8";
45+
// Software version that specifies the version of the TPU runtime to install.
46+
// For more information see https://cloud.google.com/tpu/docs/runtimes
47+
String tpuSoftwareVersion = "tpu-vm-tf-2.14.1";
48+
// The name for your Queued Resource.
49+
String queuedResourceId = "QUEUED_RESOURCE_ID";
50+
51+
createQueuedResource(projectId, zone, queuedResourceId, nodeName,
52+
tpuType, tpuSoftwareVersion);
53+
}
54+
55+
// Creates a Queued Resource with startup script.
56+
public static QueuedResource createQueuedResource(
57+
String projectId, String zone, String queuedResourceId,
58+
String nodeName, String tpuType, String tpuSoftwareVersion)
59+
throws IOException, ExecutionException, InterruptedException {
60+
String parent = String.format("projects/%s/locations/%s", projectId, zone);
61+
String startupScriptContent = "#!/bin/bash\necho \"Hello from the startup script!\"";
62+
// Add startup script to metadata
63+
Map<String, String> metadata = new HashMap<>();
64+
metadata.put("startup-script", startupScriptContent);
65+
String queuedResourceForTpu = String.format("projects/%s/locations/%s/queuedResources/%s",
66+
projectId, zone, queuedResourceId);
67+
// Initialize client that will be used to send requests. This client only needs to be created
68+
// once, and can be reused for multiple requests.
69+
try (TpuClient tpuClient = TpuClient.create()) {
70+
Node node =
71+
Node.newBuilder()
72+
.setName(nodeName)
73+
.setAcceleratorType(tpuType)
74+
.setRuntimeVersion(tpuSoftwareVersion)
75+
.setQueuedResource(queuedResourceForTpu)
76+
.putAllMetadata(metadata)
77+
.build();
78+
79+
QueuedResource queuedResource =
80+
QueuedResource.newBuilder()
81+
.setName(queuedResourceId)
82+
.setTpu(
83+
QueuedResource.Tpu.newBuilder()
84+
.addNodeSpec(
85+
QueuedResource.Tpu.NodeSpec.newBuilder()
86+
.setParent(parent)
87+
.setNode(node)
88+
.setNodeId(nodeName)
89+
.build())
90+
.build())
91+
.build();
92+
93+
CreateQueuedResourceRequest request =
94+
CreateQueuedResourceRequest.newBuilder()
95+
.setParent(parent)
96+
.setQueuedResourceId(queuedResourceId)
97+
.setQueuedResource(queuedResource)
98+
.build();
99+
// You can wait until TPU Node is READY,
100+
// and check its status using getTpuVm() from "tpu_vm_get" sample.
101+
102+
return tpuClient.createQueuedResourceAsync(request).get();
103+
}
104+
}
105+
}
106+
// [END tpu_queued_resources_startup_script]

tpu/src/main/java/tpu/DeleteForceQueuedResource.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
//[START tpu_queued_resources_delete_force]
2020
import com.google.api.gax.retrying.RetrySettings;
21-
import com.google.api.gax.rpc.UnknownException;
2221
import com.google.cloud.tpu.v2alpha1.DeleteQueuedResourceRequest;
2322
import com.google.cloud.tpu.v2alpha1.TpuClient;
2423
import com.google.cloud.tpu.v2alpha1.TpuSettings;
@@ -27,12 +26,13 @@
2726
import org.threeten.bp.Duration;
2827

2928
public class DeleteForceQueuedResource {
30-
public static void main(String[] args) {
29+
public static void main(String[] args)
30+
throws IOException, ExecutionException, InterruptedException {
3131
// TODO(developer): Replace these variables before running the sample.
3232
// Project ID or project number of the Google Cloud project.
3333
String projectId = "YOUR_PROJECT_ID";
3434
// The zone in which the TPU was created.
35-
String zone = "europe-west4-a";
35+
String zone = "us-central1-f";
3636
// The name for your Queued Resource.
3737
String queuedResourceId = "QUEUED_RESOURCE_ID";
3838

@@ -41,7 +41,8 @@ public static void main(String[] args) {
4141

4242
// Deletes a Queued Resource asynchronously with --force flag.
4343
public static void deleteForceQueuedResource(
44-
String projectId, String zone, String queuedResourceId) {
44+
String projectId, String zone, String queuedResourceId)
45+
throws ExecutionException, InterruptedException, IOException {
4546
String name = String.format("projects/%s/locations/%s/queuedResources/%s",
4647
projectId, zone, queuedResourceId);
4748
// With these settings the client library handles the Operation's polling mechanism
@@ -65,13 +66,12 @@ public static void deleteForceQueuedResource(
6566
try (TpuClient tpuClient = TpuClient.create(clientSettings.build())) {
6667
DeleteQueuedResourceRequest request =
6768
DeleteQueuedResourceRequest.newBuilder().setName(name).setForce(true).build();
68-
69+
// Waiting for updates in the library. Until then, the operation will complete successfully,
70+
// but the user will receive an error message with UnknownException and IllegalStateException.
6971
tpuClient.deleteQueuedResourceAsync(request).get();
7072

71-
} catch (UnknownException | InterruptedException | ExecutionException | IOException e) {
72-
System.out.println(e.getMessage());
73+
System.out.printf("Deleted Queued Resource: %s\n", name);
7374
}
74-
System.out.printf("Deleted Queued Resource: %s\n", name);
7575
}
7676
}
7777
//[END tpu_queued_resources_delete_force]

tpu/src/test/java/tpu/QueuedResourceIT.java

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
package tpu;
1818

19-
import static com.google.common.truth.Truth.assertThat;
2019
import static org.junit.Assert.assertEquals;
2120
import static org.mockito.Mockito.any;
2221
import static org.mockito.Mockito.mock;
@@ -32,10 +31,8 @@
3231
import com.google.cloud.tpu.v2alpha1.QueuedResource;
3332
import com.google.cloud.tpu.v2alpha1.TpuClient;
3433
import com.google.cloud.tpu.v2alpha1.TpuSettings;
35-
import java.io.ByteArrayOutputStream;
3634
import java.io.IOException;
37-
import java.io.PrintStream;
38-
import org.junit.jupiter.api.BeforeAll;
35+
import java.util.concurrent.ExecutionException;
3936
import org.junit.jupiter.api.Test;
4037
import org.junit.jupiter.api.Timeout;
4138
import org.junit.runner.RunWith;
@@ -52,13 +49,6 @@ public class QueuedResourceIT {
5249
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
5350
private static final String QUEUED_RESOURCE_NAME = "queued-resource";
5451
private static final String NETWORK_NAME = "default";
55-
private static ByteArrayOutputStream bout;
56-
57-
@BeforeAll
58-
public static void setUp() {
59-
bout = new ByteArrayOutputStream();
60-
System.setOut(new PrintStream(bout));
61-
}
6252

6353
@Test
6454
public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {
@@ -105,7 +95,8 @@ public void testGetQueuedResource() throws IOException {
10595
}
10696

10797
@Test
108-
public void testDeleteForceQueuedResource() {
98+
public void testDeleteForceQueuedResource()
99+
throws IOException, InterruptedException, ExecutionException {
109100
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
110101
TpuClient mockTpuClient = mock(TpuClient.class);
111102
OperationFuture mockFuture = mock(OperationFuture.class);
@@ -116,11 +107,33 @@ public void testDeleteForceQueuedResource() {
116107
.thenReturn(mockFuture);
117108

118109
DeleteForceQueuedResource.deleteForceQueuedResource(PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME);
119-
String output = bout.toString();
120110

121-
assertThat(output).contains("Deleted Queued Resource:");
122111
verify(mockTpuClient, times(1))
123112
.deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class));
124113
}
125114
}
115+
116+
@Test
117+
public void testCreateQueuedResourceWithStartupScript() throws Exception {
118+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
119+
QueuedResource mockQueuedResource = mock(QueuedResource.class);
120+
TpuClient mockTpuClient = mock(TpuClient.class);
121+
OperationFuture mockFuture = mock(OperationFuture.class);
122+
123+
mockedTpuClient.when(TpuClient::create).thenReturn(mockTpuClient);
124+
when(mockTpuClient.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class)))
125+
.thenReturn(mockFuture);
126+
when(mockFuture.get()).thenReturn(mockQueuedResource);
127+
128+
QueuedResource returnedQueuedResource =
129+
CreateQueuedResourceWithStartupScript.createQueuedResource(
130+
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
131+
TPU_TYPE, TPU_SOFTWARE_VERSION);
132+
133+
verify(mockTpuClient, times(1))
134+
.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class));
135+
verify(mockFuture, times(1)).get();
136+
assertEquals(returnedQueuedResource, mockQueuedResource);
137+
}
138+
}
126139
}

0 commit comments

Comments
 (0)