Skip to content

Commit 9a45e2e

Browse files
Changed tests
1 parent f89f6fa commit 9a45e2e

File tree

3 files changed

+129
-73
lines changed

3 files changed

+129
-73
lines changed

tpu/pom.xml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,19 @@
8484
<version>5.13.0</version>
8585
<scope>test</scope>
8686
</dependency>
87+
<dependency>
88+
<groupId>org.powermock</groupId>
89+
<artifactId>powermock-module-junit4</artifactId>
90+
<version>2.0.9</version>
91+
<scope>test</scope>
92+
</dependency>
93+
<dependency>
94+
<groupId>org.mockito</groupId>
95+
<artifactId>mockito-inline</artifactId>
96+
<version>5.2.0</version>
97+
<scope>test</scope>
98+
</dependency>
99+
87100
</dependencies>
88101

89102
<dependencyManagement>

tpu/src/test/java/tpu/QueuedResourceIT.java

Lines changed: 62 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,57 +16,90 @@
1616

1717
package tpu;
1818

19-
import static com.google.common.truth.Truth.assertThat;
20-
import static com.google.common.truth.Truth.assertWithMessage;
2119

20+
import static org.mockito.Mockito.any;
21+
import static org.mockito.Mockito.mock;
22+
import static org.mockito.Mockito.times;
23+
import static org.mockito.Mockito.verify;
24+
import static org.mockito.Mockito.when;
25+
26+
import com.google.api.gax.longrunning.OperationFuture;
27+
import com.google.cloud.tpu.v2alpha1.CreateQueuedResourceRequest;
28+
import com.google.cloud.tpu.v2alpha1.DeleteQueuedResourceRequest;
2229
import com.google.cloud.tpu.v2alpha1.QueuedResource;
23-
import java.util.UUID;
30+
import com.google.cloud.tpu.v2alpha1.QueuedResourceName;
31+
import com.google.cloud.tpu.v2alpha1.TpuClient;
32+
import com.google.cloud.tpu.v2alpha1.TpuSettings;
33+
import java.io.IOException;
2434
import java.util.concurrent.TimeUnit;
2535
import org.junit.Test;
26-
import org.junit.jupiter.api.AfterAll;
27-
import org.junit.jupiter.api.BeforeAll;
2836
import org.junit.jupiter.api.Timeout;
2937
import org.junit.runner.RunWith;
3038
import org.junit.runners.JUnit4;
39+
import org.mockito.MockedStatic;
40+
import org.mockito.Mockito;
3141

3242
@RunWith(JUnit4.class)
33-
@Timeout(value = 6, unit = TimeUnit.MINUTES)
43+
@Timeout(value = 3, unit = TimeUnit.MINUTES)
3444
public class QueuedResourceIT {
35-
private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT");
45+
private static final String PROJECT_ID = "project-id";
3646
private static final String ZONE = "europe-west4-a";
37-
private static final String NODE_NAME = "test-tpu-queued-resource-network-" + UUID.randomUUID();
47+
private static final String NODE_NAME = "test-tpu";
3848
private static final String TPU_TYPE = "v2-8";
3949
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
40-
private static final String QUEUED_RESOURCE_NAME = "queued-resource-network-" + UUID.randomUUID();
50+
private static final String QUEUED_RESOURCE_NAME = "queued-resource";
4151
private static final String NETWORK_NAME = "default";
4252

43-
public static void requireEnvVar(String envVarName) {
44-
assertWithMessage(String.format("Missing environment variable '%s' ", envVarName))
45-
.that(System.getenv(envVarName)).isNotEmpty();
46-
}
53+
@Test
54+
public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {
55+
TpuClient mockTpuClient = mock(TpuClient.class);
56+
try (MockedStatic<TpuClient> mockedTpuClient = Mockito.mockStatic(TpuClient.class)) {
57+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
58+
.thenReturn(mockTpuClient);
59+
60+
OperationFuture mockFuture = mock(OperationFuture.class);
61+
when(mockTpuClient.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class)))
62+
.thenReturn(mockFuture);
63+
CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork(
64+
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
65+
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);
4766

48-
@BeforeAll
49-
public static void setUp() {
50-
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
51-
requireEnvVar("GOOGLE_CLOUD_PROJECT");
67+
verify(mockTpuClient, times(1))
68+
.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class));
69+
}
5270
}
5371

54-
@AfterAll
55-
public static void cleanup() {
56-
DeleteForceQueuedResource.deleteForceQueuedResource(PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME);
72+
@Test
73+
public void testGetQueuedResource() throws IOException {
74+
GetQueuedResource mockGetQueuedResource = mock(GetQueuedResource.class);
75+
QueuedResource mockQueuedResource = mock(QueuedResource.class);
76+
try (MockedStatic<TpuClient> mockedTpuClient = Mockito.mockStatic(TpuClient.class)) {
77+
mockedTpuClient.when(TpuClient::create).thenReturn(mock(TpuClient.class));
78+
when(mock(TpuClient.class)
79+
.getQueuedResource(any(QueuedResourceName.class))).thenReturn(mockQueuedResource);
80+
81+
GetQueuedResource.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
82+
83+
// Assertions
84+
verify(mockGetQueuedResource, times(1))
85+
.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
86+
}
5787
}
5888

5989
@Test
60-
public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {
90+
public void testDeleteTpuVm() {
91+
TpuClient mockTpuClient = mock(TpuClient.class);
92+
try (MockedStatic<TpuClient> mockedTpuClient = Mockito.mockStatic(TpuClient.class)) {
93+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
94+
.thenReturn(mockTpuClient);
6195

62-
QueuedResource queuedResource = CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork(
63-
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
64-
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);
96+
OperationFuture mockFuture = mock(OperationFuture.class);
97+
when(mockTpuClient.deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class)))
98+
.thenReturn(mockFuture);
99+
DeleteForceQueuedResource.deleteForceQueuedResource(PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME);
65100

66-
assertThat(queuedResource.getTpu().getNodeSpec(0).getNode().getName()).isEqualTo(NODE_NAME);
67-
assertThat(queuedResource.getTpu().getNodeSpec(0).getNode().getNetworkConfig().getNetwork()
68-
.contains(NETWORK_NAME));
69-
assertThat(queuedResource.getTpu().getNodeSpec(0).getNode().getNetworkConfig().getSubnetwork()
70-
.contains(NETWORK_NAME));
101+
verify(mockTpuClient, times(1))
102+
.deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class));
103+
}
71104
}
72105
}

tpu/src/test/java/tpu/TpuVmIT.java

Lines changed: 54 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,78 +16,88 @@
1616

1717
package tpu;
1818

19-
import static com.google.common.truth.Truth.assertThat;
20-
import static com.google.common.truth.Truth.assertWithMessage;
21-
import static org.junit.Assert.assertNotNull;
19+
import static org.mockito.Mockito.any;
20+
import static org.mockito.Mockito.mock;
21+
import static org.mockito.Mockito.times;
22+
import static org.mockito.Mockito.verify;
23+
import static org.mockito.Mockito.when;
2224

23-
import com.google.api.gax.rpc.NotFoundException;
25+
import com.google.api.gax.longrunning.OperationFuture;
26+
import com.google.cloud.tpu.v2.CreateNodeRequest;
27+
import com.google.cloud.tpu.v2.DeleteNodeRequest;
2428
import com.google.cloud.tpu.v2.Node;
29+
import com.google.cloud.tpu.v2.NodeName;
30+
import com.google.cloud.tpu.v2.TpuClient;
31+
import com.google.cloud.tpu.v2.TpuSettings;
2532
import java.io.IOException;
26-
import java.util.UUID;
2733
import java.util.concurrent.ExecutionException;
2834
import java.util.concurrent.TimeUnit;
29-
import org.junit.jupiter.api.AfterAll;
30-
import org.junit.jupiter.api.Assertions;
31-
import org.junit.jupiter.api.BeforeAll;
3235
import org.junit.jupiter.api.MethodOrderer;
3336
import org.junit.jupiter.api.Order;
3437
import org.junit.jupiter.api.Test;
3538
import org.junit.jupiter.api.TestMethodOrder;
3639
import org.junit.jupiter.api.Timeout;
3740
import org.junit.runner.RunWith;
38-
import org.junit.runners.JUnit4;
41+
import org.mockito.MockedStatic;
42+
import org.mockito.Mockito;
43+
import org.powermock.modules.junit4.PowerMockRunner;
3944

40-
@RunWith(JUnit4.class)
41-
@Timeout(value = 15, unit = TimeUnit.MINUTES)
45+
@RunWith(PowerMockRunner.class)
46+
@Timeout(value = 3, unit = TimeUnit.MINUTES)
4247
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
4348
public class TpuVmIT {
44-
private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT");
49+
private static final String PROJECT_ID = "project-id";
4550
private static final String ZONE = "asia-east1-c";
46-
private static final String NODE_NAME = "test-tpu-" + UUID.randomUUID();
51+
private static final String NODE_NAME = "test-tpu";
4752
private static final String TPU_TYPE = "v2-8";
4853
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1";
49-
private static final String NODE_PATH_NAME =
50-
String.format("projects/%s/locations/%s/nodes/%s", PROJECT_ID, ZONE, NODE_NAME);
5154

52-
public static void requireEnvVar(String envVarName) {
53-
assertWithMessage(String.format("Missing environment variable '%s' ", envVarName))
54-
.that(System.getenv(envVarName)).isNotEmpty();
55-
}
56-
57-
@BeforeAll
58-
public static void setUp() {
59-
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
60-
requireEnvVar("GOOGLE_CLOUD_PROJECT");
61-
}
55+
@Test
56+
@Order(1)
57+
public void testCreateTpuVm() throws IOException, ExecutionException, InterruptedException {
58+
TpuClient mockTpuClient = mock(TpuClient.class);
59+
try (MockedStatic<TpuClient> mockedTpuClient = Mockito.mockStatic(TpuClient.class)) {
60+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
61+
.thenReturn(mockTpuClient);
6262

63-
@AfterAll
64-
public static void cleanup() throws Exception {
65-
DeleteTpuVm.deleteTpuVm(PROJECT_ID, ZONE, NODE_NAME);
63+
OperationFuture mockFuture = mock(OperationFuture.class);
64+
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
65+
.thenReturn(mockFuture);
66+
CreateTpuVm.createTpuVm(PROJECT_ID, ZONE, NODE_NAME, TPU_TYPE, TPU_SOFTWARE_VERSION);
6667

67-
// Test that TPUs is deleted
68-
Assertions.assertThrows(
69-
NotFoundException.class,
70-
() -> GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME));
68+
verify(mockTpuClient, times(1)).createNodeAsync(any(CreateNodeRequest.class));
69+
}
7170
}
7271

7372
@Test
74-
@Order(1)
75-
public void testCreateTpuVm() throws IOException, ExecutionException, InterruptedException {
73+
public void testGetTpuVm() throws IOException {
74+
GetTpuVm mockGetTpuVm = mock(GetTpuVm.class);
75+
Node mockNode = mock(Node.class);
76+
try (MockedStatic<TpuClient> mockedTpuClient = Mockito.mockStatic(TpuClient.class)) {
77+
mockedTpuClient.when(TpuClient::create).thenReturn(mock(TpuClient.class));
78+
when(mock(TpuClient.class).getNode(any(NodeName.class))).thenReturn(mockNode);
7679

77-
Node node = CreateTpuVm.createTpuVm(
78-
PROJECT_ID, ZONE, NODE_NAME, TPU_TYPE, TPU_SOFTWARE_VERSION);
80+
GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
7981

80-
assertNotNull(node);
81-
assertThat(node.getName().equals(NODE_NAME));
82-
assertThat(node.getAcceleratorType().equals(TPU_TYPE));
82+
// Assertions
83+
verify(mockGetTpuVm, times(1))
84+
.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
85+
}
8386
}
8487

8588
@Test
86-
@Order(2)
87-
public void testGetTpuVm() throws IOException {
88-
Node node = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
89+
public void testDeleteTpuVm() throws IOException, ExecutionException, InterruptedException {
90+
TpuClient mockTpuClient = mock(TpuClient.class);
91+
try (MockedStatic<TpuClient> mockedTpuClient = Mockito.mockStatic(TpuClient.class)) {
92+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
93+
.thenReturn(mockTpuClient);
94+
95+
OperationFuture mockFuture = mock(OperationFuture.class);
96+
when(mockTpuClient.deleteNodeAsync(any(DeleteNodeRequest.class)))
97+
.thenReturn(mockFuture);
98+
DeleteTpuVm.deleteTpuVm(PROJECT_ID, ZONE, NODE_NAME);
8999

90-
assertNotNull(node);
91-
assertThat(node.getName()).isEqualTo(NODE_PATH_NAME);
100+
verify(mockTpuClient, times(1)).deleteNodeAsync(any(DeleteNodeRequest.class));
101+
}
92102
}
93103
}

0 commit comments

Comments
 (0)