Skip to content

Commit ec13f4d

Browse files
Split PR into smaller, deleted redundant code
1 parent f6b76cc commit ec13f4d

File tree

7 files changed

+34
-223
lines changed

7 files changed

+34
-223
lines changed

tpu/src/main/java/tpu/CreateTpuVm.java

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717
package tpu;
1818

1919
//[START tpu_vm_create]
20-
20+
import com.google.api.gax.longrunning.OperationTimedPollAlgorithm;
21+
import com.google.api.gax.retrying.RetrySettings;
2122
import com.google.cloud.tpu.v2.CreateNodeRequest;
2223
import com.google.cloud.tpu.v2.Node;
2324
import com.google.cloud.tpu.v2.TpuClient;
25+
import com.google.cloud.tpu.v2.TpuSettings;
2426
import java.io.IOException;
2527
import java.util.concurrent.ExecutionException;
28+
import org.threeten.bp.Duration;
2629

2730
public class CreateTpuVm {
2831

@@ -49,12 +52,30 @@ public static void main(String[] args)
4952
}
5053

5154
// Creates a TPU VM with the specified name, zone, accelerator type, and version.
52-
public static void createTpuVm(
55+
public static Node createTpuVm(
5356
String projectId, String zone, String nodeName, String tpuType, String tpuSoftwareVersion)
5457
throws IOException, ExecutionException, InterruptedException {
58+
// With these settings the client library handles the Operation's polling mechanism
59+
// and prevent CancellationException error
60+
TpuSettings.Builder clientSettings =
61+
TpuSettings.newBuilder();
62+
clientSettings
63+
.createNodeOperationSettings()
64+
.setPollingAlgorithm(
65+
OperationTimedPollAlgorithm.create(
66+
RetrySettings.newBuilder()
67+
.setInitialRetryDelay(Duration.ofMillis(5000L))
68+
.setRetryDelayMultiplier(1.5)
69+
.setMaxRetryDelay(Duration.ofMillis(45000L))
70+
.setInitialRpcTimeout(Duration.ZERO)
71+
.setRpcTimeoutMultiplier(1.0)
72+
.setMaxRpcTimeout(Duration.ZERO)
73+
.setTotalTimeout(Duration.ofHours(24L))
74+
.build()));
75+
5576
// Initialize client that will be used to send requests. This client only needs to be created
5677
// once, and can be reused for multiple requests.
57-
try (TpuClient tpuClient = TpuClient.create()) {
78+
try (TpuClient tpuClient = TpuClient.create(clientSettings.build())) {
5879
String parent = String.format("projects/%s/locations/%s", projectId, zone);
5980

6081
Node tpuVm = Node.newBuilder()
@@ -71,6 +92,7 @@ public static void createTpuVm(
7192

7293
Node response = tpuClient.createNodeAsync(request).get();
7394
System.out.printf("TPU VM created: %s\n", response.getName());
95+
return response;
7496
}
7597
}
7698
}

tpu/src/main/java/tpu/DeleteTpuVm.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package tpu;
1818

1919
//[START tpu_vm_delete]
20-
2120
import com.google.api.gax.longrunning.OperationTimedPollAlgorithm;
2221
import com.google.api.gax.retrying.RetrySettings;
2322
import com.google.cloud.tpu.v2.DeleteNodeRequest;

tpu/src/main/java/tpu/GetTpuVm.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package tpu;
1818

1919
//[START tpu_vm_get]
20-
2120
import com.google.cloud.tpu.v2.GetNodeRequest;
2221
import com.google.cloud.tpu.v2.Node;
2322
import com.google.cloud.tpu.v2.NodeName;

tpu/src/main/java/tpu/ListTpuVms.java

Lines changed: 0 additions & 53 deletions
This file was deleted.

tpu/src/main/java/tpu/StartTpuVm.java

Lines changed: 0 additions & 60 deletions
This file was deleted.

tpu/src/main/java/tpu/StopTpuVm.java

Lines changed: 0 additions & 60 deletions
This file was deleted.

tpu/src/test/java/tpu/TpuVmIT.java

Lines changed: 9 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,16 @@
1616

1717
package tpu;
1818

19-
import static com.google.cloud.tpu.v2.Node.State.READY;
20-
import static com.google.cloud.tpu.v2.Node.State.STOPPED;
2119
import static com.google.common.truth.Truth.assertThat;
2220
import static com.google.common.truth.Truth.assertWithMessage;
2321
import static org.junit.Assert.assertNotNull;
2422

2523
import com.google.api.gax.rpc.NotFoundException;
2624
import com.google.cloud.tpu.v2.Node;
27-
import com.google.cloud.tpu.v2.TpuClient;
28-
import java.io.ByteArrayOutputStream;
2925
import java.io.IOException;
30-
import java.io.PrintStream;
3126
import java.util.UUID;
3227
import java.util.concurrent.ExecutionException;
3328
import java.util.concurrent.TimeUnit;
34-
import org.junit.Assert;
3529
import org.junit.jupiter.api.AfterAll;
3630
import org.junit.jupiter.api.Assertions;
3731
import org.junit.jupiter.api.BeforeAll;
@@ -44,15 +38,15 @@
4438
import org.junit.runners.JUnit4;
4539

4640
@RunWith(JUnit4.class)
47-
@Timeout(value = 25, unit = TimeUnit.MINUTES)
41+
@Timeout(value = 15, unit = TimeUnit.MINUTES)
4842
@TestMethodOrder(MethodOrderer. OrderAnnotation. class)
4943
public class TpuVmIT {
5044
private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT");
51-
private static final String ZONE = "europe-west4-a";
45+
private static final String ZONE = "us-central1-a";
5246
static String javaVersion = System.getProperty("java.version").substring(0, 2);
5347
private static final String NODE_NAME = "test-tpu-" + javaVersion + "-"
5448
+ UUID.randomUUID().toString().substring(0, 8);
55-
private static final String TPU_TYPE = "v2-8";
49+
private static final String TPU_TYPE = "v3-8";
5650
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
5751
private static final String NODE_PATH_NAME =
5852
String.format("projects/%s/locations/%s/nodes/%s", PROJECT_ID, ZONE, NODE_NAME);
@@ -85,14 +79,13 @@ public static void cleanup() throws Exception {
8579
@Test
8680
@Order(1)
8781
public void testCreateTpuVm() throws IOException, ExecutionException, InterruptedException {
88-
final PrintStream out = System.out;
89-
ByteArrayOutputStream stdOut = new ByteArrayOutputStream();
90-
System.setOut(new PrintStream(stdOut));
91-
CreateTpuVm.createTpuVm(PROJECT_ID, ZONE, NODE_NAME, TPU_TYPE, TPU_SOFTWARE_VERSION);
9282

93-
assertThat(stdOut.toString()).contains("TPU VM created: " + NODE_PATH_NAME);
94-
stdOut.close();
95-
System.setOut(out);
83+
Node node = CreateTpuVm.createTpuVm(
84+
PROJECT_ID, ZONE, NODE_NAME, TPU_TYPE, TPU_SOFTWARE_VERSION);
85+
86+
assertNotNull(node);
87+
assertThat(node.getName().equals(NODE_NAME));
88+
assertThat(node.getAcceleratorType().equals(TPU_TYPE));
9689
}
9790

9891
@Test
@@ -103,33 +96,4 @@ public void testGetTpuVm() throws IOException {
10396
assertNotNull(node);
10497
assertThat(node.getName()).isEqualTo(NODE_PATH_NAME);
10598
}
106-
107-
@Test
108-
@Order(2)
109-
public void testListTpuVm() throws IOException {
110-
TpuClient.ListNodesPagedResponse nodesList = ListTpuVms.listTpuVms(PROJECT_ID, ZONE);
111-
112-
assertNotNull(nodesList);
113-
for (Node node : nodesList.iterateAll()) {
114-
Assert.assertTrue(node.getName().contains("test-tpu"));
115-
}
116-
}
117-
118-
@Test
119-
@Order(2)
120-
public void testStopTpuVm() throws IOException, ExecutionException, InterruptedException {
121-
StopTpuVm.stopTpuVm(PROJECT_ID, ZONE, NODE_NAME);
122-
Node node = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
123-
124-
assertThat(node.getState()).isEqualTo(STOPPED);
125-
}
126-
127-
@Test
128-
@Order(3)
129-
public void testStartTpuVm() throws IOException, ExecutionException, InterruptedException {
130-
StartTpuVm.startTpuVm(PROJECT_ID, ZONE, NODE_NAME);
131-
Node node = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
132-
133-
assertThat(node.getState()).isEqualTo(READY);
134-
}
13599
}

0 commit comments

Comments
 (0)