Skip to content

Commit d81e48a

Browse files
Fixed test
1 parent 41ce693 commit d81e48a

File tree

3 files changed

+184
-146
lines changed

3 files changed

+184
-146
lines changed

tpu/src/test/java/tpu/CreateTpuWithTopologyFlagIT.java

Lines changed: 0 additions & 74 deletions
This file was deleted.

tpu/src/test/java/tpu/QueuedResourceIT.java

Lines changed: 83 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,56 +17,111 @@
1717
package tpu;
1818

1919
import static com.google.common.truth.Truth.assertThat;
20-
import static com.google.common.truth.Truth.assertWithMessage;
20+
import static org.junit.Assert.assertEquals;
21+
import static org.mockito.Mockito.any;
22+
import static org.mockito.Mockito.mock;
23+
import static org.mockito.Mockito.mockStatic;
24+
import static org.mockito.Mockito.times;
25+
import static org.mockito.Mockito.verify;
26+
import static org.mockito.Mockito.when;
2127

28+
import com.google.api.gax.longrunning.OperationFuture;
29+
import com.google.cloud.tpu.v2alpha1.CreateQueuedResourceRequest;
30+
import com.google.cloud.tpu.v2alpha1.DeleteQueuedResourceRequest;
31+
import com.google.cloud.tpu.v2alpha1.GetQueuedResourceRequest;
2232
import com.google.cloud.tpu.v2alpha1.QueuedResource;
23-
import java.util.UUID;
24-
import java.util.concurrent.TimeUnit;
25-
import org.junit.Test;
26-
import org.junit.jupiter.api.AfterAll;
33+
import com.google.cloud.tpu.v2alpha1.TpuClient;
34+
import com.google.cloud.tpu.v2alpha1.TpuSettings;
35+
import java.io.ByteArrayOutputStream;
36+
import java.io.IOException;
37+
import java.io.PrintStream;
2738
import org.junit.jupiter.api.BeforeAll;
39+
import org.junit.jupiter.api.Test;
2840
import org.junit.jupiter.api.Timeout;
2941
import org.junit.runner.RunWith;
3042
import org.junit.runners.JUnit4;
43+
import org.mockito.MockedStatic;
3144

3245
@RunWith(JUnit4.class)
33-
@Timeout(value = 6, unit = TimeUnit.MINUTES)
46+
@Timeout(value = 3)
3447
public class QueuedResourceIT {
35-
private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT");
48+
private static final String PROJECT_ID = "project-id";
3649
private static final String ZONE = "europe-west4-a";
37-
private static final String NODE_NAME = "test-tpu-queued-resource-network-" + UUID.randomUUID();
50+
private static final String NODE_NAME = "test-tpu";
3851
private static final String TPU_TYPE = "v2-8";
3952
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
40-
private static final String QUEUED_RESOURCE_NAME = "queued-resource-network-" + UUID.randomUUID();
53+
private static final String QUEUED_RESOURCE_NAME = "queued-resource";
4154
private static final String NETWORK_NAME = "default";
42-
43-
public static void requireEnvVar(String envVarName) {
44-
assertWithMessage(String.format("Missing environment variable '%s' ", envVarName))
45-
.that(System.getenv(envVarName)).isNotEmpty();
46-
}
55+
private static ByteArrayOutputStream bout;
4756

4857
@BeforeAll
4958
public static void setUp() {
50-
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
51-
requireEnvVar("GOOGLE_CLOUD_PROJECT");
59+
bout = new ByteArrayOutputStream();
60+
System.setOut(new PrintStream(bout));
5261
}
5362

54-
@AfterAll
55-
public static void cleanup() {
56-
DeleteForceQueuedResource.deleteForceQueuedResource(PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME);
63+
@Test
64+
public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {
65+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
66+
QueuedResource mockQueuedResource = mock(QueuedResource.class);
67+
TpuClient mockTpuClient = mock(TpuClient.class);
68+
OperationFuture mockFuture = mock(OperationFuture.class);
69+
70+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
71+
.thenReturn(mockTpuClient);
72+
when(mockTpuClient.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class)))
73+
.thenReturn(mockFuture);
74+
when(mockFuture.get()).thenReturn(mockQueuedResource);
75+
76+
QueuedResource returnedQueuedResource =
77+
CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork(
78+
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
79+
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);
80+
81+
verify(mockTpuClient, times(1))
82+
.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class));
83+
verify(mockFuture, times(1)).get();
84+
assertEquals(returnedQueuedResource, mockQueuedResource);
85+
}
5786
}
5887

5988
@Test
60-
public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {
89+
public void testGetQueuedResource() throws IOException {
90+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
91+
TpuClient mockClient = mock(TpuClient.class);
92+
GetQueuedResource mockGetQueuedResource = mock(GetQueuedResource.class);
93+
QueuedResource mockQueuedResource = mock(QueuedResource.class);
94+
95+
mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
96+
when(mockClient.getQueuedResource(any(GetQueuedResourceRequest.class)))
97+
.thenReturn(mockQueuedResource);
98+
99+
QueuedResource returnedQueuedResource =
100+
GetQueuedResource.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
101+
102+
verify(mockGetQueuedResource, times(1))
103+
.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
104+
assertEquals(returnedQueuedResource, mockQueuedResource);
105+
}
106+
}
107+
108+
@Test
109+
public void testDeleteForceQueuedResource() {
110+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
111+
TpuClient mockTpuClient = mock(TpuClient.class);
112+
OperationFuture mockFuture = mock(OperationFuture.class);
113+
114+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
115+
.thenReturn(mockTpuClient);
116+
when(mockTpuClient.deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class)))
117+
.thenReturn(mockFuture);
61118

62-
QueuedResource queuedResource = CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork(
63-
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
64-
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);
119+
DeleteForceQueuedResource.deleteForceQueuedResource(PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME);
120+
String output = bout.toString();
65121

66-
assertThat(queuedResource.getTpu().getNodeSpec(0).getNode().getName()).isEqualTo(NODE_NAME);
67-
assertThat(queuedResource.getTpu().getNodeSpec(0).getNode().getNetworkConfig().getNetwork()
68-
.contains(NETWORK_NAME));
69-
assertThat(queuedResource.getTpu().getNodeSpec(0).getNode().getNetworkConfig().getSubnetwork()
70-
.contains(NETWORK_NAME));
122+
assertThat(output).contains("Deleted Queued Resource:");
123+
verify(mockTpuClient, times(1))
124+
.deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class));
125+
}
71126
}
72127
}

tpu/src/test/java/tpu/TpuVmIT.java

Lines changed: 101 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -17,76 +17,133 @@
1717
package tpu;
1818

1919
import static com.google.common.truth.Truth.assertThat;
20-
import static com.google.common.truth.Truth.assertWithMessage;
21-
import static org.junit.Assert.assertNotNull;
20+
import static org.junit.Assert.assertEquals;
21+
import static org.mockito.Mockito.any;
22+
import static org.mockito.Mockito.mock;
23+
import static org.mockito.Mockito.mockStatic;
24+
import static org.mockito.Mockito.times;
25+
import static org.mockito.Mockito.verify;
26+
import static org.mockito.Mockito.when;
2227

23-
import com.google.api.gax.rpc.NotFoundException;
28+
import com.google.api.gax.longrunning.OperationFuture;
29+
import com.google.cloud.tpu.v2.AcceleratorConfig;
30+
import com.google.cloud.tpu.v2.CreateNodeRequest;
31+
import com.google.cloud.tpu.v2.DeleteNodeRequest;
32+
import com.google.cloud.tpu.v2.GetNodeRequest;
2433
import com.google.cloud.tpu.v2.Node;
34+
import com.google.cloud.tpu.v2.TpuClient;
35+
import com.google.cloud.tpu.v2.TpuSettings;
36+
import java.io.ByteArrayOutputStream;
2537
import java.io.IOException;
26-
import java.util.UUID;
38+
import java.io.PrintStream;
2739
import java.util.concurrent.ExecutionException;
28-
import java.util.concurrent.TimeUnit;
29-
import org.junit.jupiter.api.AfterAll;
30-
import org.junit.jupiter.api.Assertions;
3140
import org.junit.jupiter.api.BeforeAll;
32-
import org.junit.jupiter.api.MethodOrderer;
33-
import org.junit.jupiter.api.Order;
3441
import org.junit.jupiter.api.Test;
35-
import org.junit.jupiter.api.TestMethodOrder;
3642
import org.junit.jupiter.api.Timeout;
3743
import org.junit.runner.RunWith;
3844
import org.junit.runners.JUnit4;
45+
import org.mockito.MockedStatic;
3946

4047
@RunWith(JUnit4.class)
41-
@Timeout(value = 15, unit = TimeUnit.MINUTES)
42-
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
48+
@Timeout(value = 3)
4349
public class TpuVmIT {
44-
private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT");
50+
private static final String PROJECT_ID = "project-id";
4551
private static final String ZONE = "asia-east1-c";
46-
private static final String NODE_NAME = "test-tpu-" + UUID.randomUUID();
52+
private static final String NODE_NAME = "test-tpu";
4753
private static final String TPU_TYPE = "v2-8";
48-
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1";
49-
private static final String NODE_PATH_NAME =
50-
String.format("projects/%s/locations/%s/nodes/%s", PROJECT_ID, ZONE, NODE_NAME);
51-
52-
public static void requireEnvVar(String envVarName) {
53-
assertWithMessage(String.format("Missing environment variable '%s' ", envVarName))
54-
.that(System.getenv(envVarName)).isNotEmpty();
55-
}
54+
private static final AcceleratorConfig.Type ACCELERATOR_TYPE = AcceleratorConfig.Type.V2;
55+
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
56+
private static final String TOPOLOGY = "2x2";
57+
private static ByteArrayOutputStream bout;
5658

5759
@BeforeAll
5860
public static void setUp() {
59-
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
60-
requireEnvVar("GOOGLE_CLOUD_PROJECT");
61+
bout = new ByteArrayOutputStream();
62+
System.setOut(new PrintStream(bout));
6163
}
6264

63-
@AfterAll
64-
public static void cleanup() throws Exception {
65-
DeleteTpuVm.deleteTpuVm(PROJECT_ID, ZONE, NODE_NAME);
65+
@Test
66+
public void testCreateTpuVm() throws Exception {
67+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
68+
Node mockNode = mock(Node.class);
69+
TpuClient mockTpuClient = mock(TpuClient.class);
70+
OperationFuture mockFuture = mock(OperationFuture.class);
71+
72+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
73+
.thenReturn(mockTpuClient);
74+
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
75+
.thenReturn(mockFuture);
76+
when(mockFuture.get()).thenReturn(mockNode);
77+
78+
Node returnedNode = CreateTpuVm.createTpuVm(
79+
PROJECT_ID, ZONE, NODE_NAME,
80+
TPU_TYPE, TPU_SOFTWARE_VERSION);
6681

67-
// Test that TPUs is deleted
68-
Assertions.assertThrows(
69-
NotFoundException.class,
70-
() -> GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME));
82+
verify(mockTpuClient, times(1))
83+
.createNodeAsync(any(CreateNodeRequest.class));
84+
verify(mockFuture, times(1)).get();
85+
assertEquals(returnedNode, mockNode);
86+
}
7187
}
7288

7389
@Test
74-
@Order(1)
75-
public void testCreateTpuVm() throws IOException, ExecutionException, InterruptedException {
76-
Node node = CreateTpuVm.createTpuVm(
77-
PROJECT_ID, ZONE, NODE_NAME, TPU_TYPE, TPU_SOFTWARE_VERSION);
78-
79-
assertNotNull(node);
80-
assertThat(node.getName().equals(NODE_NAME));
81-
assertThat(node.getAcceleratorType().equals(TPU_TYPE));
90+
public void testGetTpuVm() throws IOException {
91+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
92+
Node mockNode = mock(Node.class);
93+
TpuClient mockClient = mock(TpuClient.class);
94+
95+
mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
96+
when(mockClient.getNode(any(GetNodeRequest.class))).thenReturn(mockNode);
97+
98+
Node returnedNode = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
99+
100+
verify(mockClient, times(1))
101+
.getNode(any(GetNodeRequest.class));
102+
assertThat(returnedNode).isEqualTo(mockNode);
103+
verify(mockClient, times(1)).close();
104+
}
82105
}
83106

84107
@Test
85-
@Order(2)
86-
public void testGetTpuVm() throws IOException {
87-
Node node = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
108+
public void testDeleteTpuVm() throws IOException, ExecutionException, InterruptedException {
109+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
110+
TpuClient mockTpuClient = mock(TpuClient.class);
111+
OperationFuture mockFuture = mock(OperationFuture.class);
112+
113+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
114+
.thenReturn(mockTpuClient);
115+
when(mockTpuClient.deleteNodeAsync(any(DeleteNodeRequest.class)))
116+
.thenReturn(mockFuture);
117+
118+
DeleteTpuVm.deleteTpuVm(PROJECT_ID, ZONE, NODE_NAME);
119+
String output = bout.toString();
120+
121+
assertThat(output).contains("TPU VM deleted");
122+
verify(mockTpuClient, times(1)).deleteNodeAsync(any(DeleteNodeRequest.class));
123+
}
124+
}
125+
126+
@Test
127+
public void testCreateTpuVmWithTopologyFlag()
128+
throws IOException, ExecutionException, InterruptedException {
129+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
130+
Node mockNode = mock(Node.class);
131+
TpuClient mockTpuClient = mock(TpuClient.class);
132+
OperationFuture mockFuture = mock(OperationFuture.class);
133+
134+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
135+
.thenReturn(mockTpuClient);
136+
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
137+
.thenReturn(mockFuture);
138+
when(mockFuture.get()).thenReturn(mockNode);
139+
Node returnedNode = CreateTpuWithTopologyFlag.createTpuWithTopologyFlag(
140+
PROJECT_ID, ZONE, NODE_NAME, ACCELERATOR_TYPE,
141+
TPU_SOFTWARE_VERSION, TOPOLOGY);
88142

89-
assertNotNull(node);
90-
assertThat(node.getName()).isEqualTo(NODE_PATH_NAME);
143+
verify(mockTpuClient, times(1))
144+
.createNodeAsync(any(CreateNodeRequest.class));
145+
verify(mockFuture, times(1)).get();
146+
assertEquals(returnedNode, mockNode);
147+
}
91148
}
92149
}

0 commit comments

Comments
 (0)