Skip to content

Commit 071ad04

Browse files
Fixed tests
1 parent 9e1f106 commit 071ad04

File tree

3 files changed

+177
-89
lines changed

3 files changed

+177
-89
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
import static org.mockito.Mockito.any;
20+
import static org.mockito.Mockito.mock;
21+
import static org.mockito.Mockito.mockStatic;
22+
import static org.mockito.Mockito.times;
23+
import static org.mockito.Mockito.verify;
24+
import static org.mockito.Mockito.when;
25+
26+
import com.google.cloud.tpu.v2.Node;
27+
import com.google.cloud.tpu.v2.NodeName;
28+
import com.google.cloud.tpu.v2.TpuClient;
29+
import java.io.IOException;
30+
import org.junit.jupiter.api.Test;
31+
import org.junit.jupiter.api.Timeout;
32+
import org.junit.runner.RunWith;
33+
import org.junit.runners.JUnit4;
34+
import org.mockito.MockedStatic;
35+
36+
@RunWith(JUnit4.class)
37+
@Timeout(value = 3)
38+
public class GetTpuVmIT {
39+
private static final String PROJECT_ID = "project-id";
40+
private static final String ZONE = "asia-east1-c";
41+
private static final String NODE_NAME = "test-tpu";
42+
43+
@Test
44+
public void testGetTpuVm() throws IOException {
45+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
46+
Node mockNode = mock(Node.class);
47+
mockedTpuClient.when(TpuClient::create).thenReturn(mock(TpuClient.class));
48+
when(mock(TpuClient.class).getNode(any(NodeName.class))).thenReturn(mockNode);
49+
GetTpuVm mockGetTpuVm = mock(GetTpuVm.class);
50+
51+
GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
52+
53+
verify(mockGetTpuVm, times(1)).getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
54+
}
55+
}
56+
}

tpu/src/test/java/tpu/QueuedResourceIT.java

Lines changed: 71 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,57 +17,99 @@
1717
package tpu;
1818

1919
import static com.google.common.truth.Truth.assertThat;
20-
import static com.google.common.truth.Truth.assertWithMessage;
20+
import static org.mockito.Mockito.any;
21+
import static org.mockito.Mockito.mock;
22+
import static org.mockito.Mockito.mockStatic;
23+
import static org.mockito.Mockito.times;
24+
import static org.mockito.Mockito.verify;
25+
import static org.mockito.Mockito.when;
2126

27+
import com.google.api.gax.longrunning.OperationFuture;
28+
import com.google.cloud.tpu.v2alpha1.CreateQueuedResourceRequest;
29+
import com.google.cloud.tpu.v2alpha1.DeleteQueuedResourceRequest;
2230
import com.google.cloud.tpu.v2alpha1.QueuedResource;
23-
import java.util.UUID;
24-
import java.util.concurrent.TimeUnit;
31+
import com.google.cloud.tpu.v2alpha1.QueuedResourceName;
32+
import com.google.cloud.tpu.v2alpha1.TpuClient;
33+
import com.google.cloud.tpu.v2alpha1.TpuSettings;
34+
import java.io.ByteArrayOutputStream;
35+
import java.io.IOException;
36+
import java.io.PrintStream;
37+
import org.junit.Before;
2538
import org.junit.Test;
26-
import org.junit.jupiter.api.AfterAll;
27-
import org.junit.jupiter.api.BeforeAll;
2839
import org.junit.jupiter.api.Timeout;
2940
import org.junit.runner.RunWith;
3041
import org.junit.runners.JUnit4;
42+
import org.mockito.MockedStatic;
3143

3244
@RunWith(JUnit4.class)
33-
@Timeout(value = 6, unit = TimeUnit.MINUTES)
45+
@Timeout(value = 3)
3446
public class QueuedResourceIT {
35-
36-
private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT");
47+
private static final String PROJECT_ID = "project-id";
3748
private static final String ZONE = "europe-west4-a";
38-
private static final String NODE_NAME = "test-tpu-queued-resource-network-" + UUID.randomUUID();
49+
private static final String NODE_NAME = "test-tpu";
3950
private static final String TPU_TYPE = "v2-8";
4051
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
41-
private static final String QUEUED_RESOURCE_NAME = "queued-resource-network-" + UUID.randomUUID();
52+
private static final String QUEUED_RESOURCE_NAME = "queued-resource";
4253
private static final String NETWORK_NAME = "default";
54+
private ByteArrayOutputStream bout;
4355

44-
public static void requireEnvVar(String envVarName) {
45-
assertWithMessage(String.format("Missing environment variable '%s' ", envVarName))
46-
.that(System.getenv(envVarName)).isNotEmpty();
56+
@Before
57+
public void setUp() {
58+
bout = new ByteArrayOutputStream();
59+
System.setOut(new PrintStream(bout));
4760
}
4861

49-
@BeforeAll
50-
public static void setUp() {
51-
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
52-
requireEnvVar("GOOGLE_CLOUD_PROJECT");
62+
@Test
63+
public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {
64+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
65+
TpuClient mockTpuClient = mock(TpuClient.class);
66+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
67+
.thenReturn(mockTpuClient);
68+
69+
OperationFuture mockFuture = mock(OperationFuture.class);
70+
when(mockTpuClient.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class)))
71+
.thenReturn(mockFuture);
72+
CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork(
73+
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
74+
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);
75+
76+
verify(mockTpuClient, times(1))
77+
.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class));
78+
}
5379
}
5480

55-
@AfterAll
56-
public static void cleanup() {
57-
DeleteForceQueuedResource.deleteForceQueuedResource(PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME);
81+
@Test
82+
public void testGetQueuedResource() throws IOException {
83+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
84+
QueuedResource mockQueuedResource = mock(QueuedResource.class);
85+
mockedTpuClient.when(TpuClient::create).thenReturn(mock(TpuClient.class));
86+
when(mock(TpuClient.class)
87+
.getQueuedResource(any(QueuedResourceName.class))).thenReturn(mockQueuedResource);
88+
GetQueuedResource mockGetQueuedResource = mock(GetQueuedResource.class);
89+
90+
GetQueuedResource.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
91+
92+
verify(mockGetQueuedResource, times(1))
93+
.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
94+
}
5895
}
5996

6097
@Test
61-
public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {
98+
public void testDeleteTpuVm() {
99+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
100+
TpuClient mockTpuClient = mock(TpuClient.class);
101+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
102+
.thenReturn(mockTpuClient);
62103

63-
QueuedResource queuedResource = CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork(
64-
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
65-
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);
104+
OperationFuture mockFuture = mock(OperationFuture.class);
105+
when(mockTpuClient.deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class)))
106+
.thenReturn(mockFuture);
107+
DeleteForceQueuedResource.deleteForceQueuedResource(PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME);
108+
String output = bout.toString();
66109

67-
assertThat(queuedResource.getTpu().getNodeSpec(0).getNode().getName()).isEqualTo(NODE_NAME);
68-
assertThat(queuedResource.getTpu().getNodeSpec(0).getNode().getNetworkConfig().getNetwork()
69-
.contains(NETWORK_NAME));
70-
assertThat(queuedResource.getTpu().getNodeSpec(0).getNode().getNetworkConfig().getSubnetwork()
71-
.contains(NETWORK_NAME));
110+
assertThat(output).contains("Deleted Queued Resource:");
111+
verify(mockTpuClient, times(1))
112+
.deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class));
113+
}
72114
}
73115
}

tpu/src/test/java/tpu/TpuVmIT.java

Lines changed: 50 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -16,91 +16,81 @@
1616

1717
package tpu;
1818

19-
import static com.google.common.truth.Truth.assertThat;
20-
import static com.google.common.truth.Truth.assertWithMessage;
21-
import static org.junit.Assert.assertNotNull;
19+
import static org.mockito.Mockito.any;
20+
import static org.mockito.Mockito.mock;
21+
import static org.mockito.Mockito.mockStatic;
22+
import static org.mockito.Mockito.times;
23+
import static org.mockito.Mockito.verify;
24+
import static org.mockito.Mockito.when;
2225

23-
import com.google.api.gax.rpc.NotFoundException;
24-
import com.google.cloud.tpu.v2.Node;
26+
import com.google.api.gax.longrunning.OperationFuture;
27+
import com.google.cloud.tpu.v2.CreateNodeRequest;
28+
import com.google.cloud.tpu.v2.DeleteNodeRequest;
29+
import com.google.cloud.tpu.v2.LocationName;
2530
import com.google.cloud.tpu.v2.TpuClient;
31+
import com.google.cloud.tpu.v2.TpuSettings;
2632
import java.io.IOException;
27-
import java.util.UUID;
2833
import java.util.concurrent.ExecutionException;
29-
import java.util.concurrent.TimeUnit;
30-
import org.junit.Assert;
31-
import org.junit.jupiter.api.AfterAll;
32-
import org.junit.jupiter.api.Assertions;
33-
import org.junit.jupiter.api.BeforeAll;
34-
import org.junit.jupiter.api.MethodOrderer;
35-
import org.junit.jupiter.api.Order;
36-
import org.junit.jupiter.api.Test;
37-
import org.junit.jupiter.api.TestMethodOrder;
34+
import org.junit.Test;
3835
import org.junit.jupiter.api.Timeout;
3936
import org.junit.runner.RunWith;
4037
import org.junit.runners.JUnit4;
38+
import org.mockito.MockedStatic;
4139

4240
@RunWith(JUnit4.class)
43-
@Timeout(value = 15, unit = TimeUnit.MINUTES)
44-
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
41+
@Timeout(value = 3)
4542
public class TpuVmIT {
46-
private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT");
43+
private static final String PROJECT_ID = "project-id";
4744
private static final String ZONE = "asia-east1-c";
48-
private static final String NODE_NAME = "test-tpu-" + UUID.randomUUID();
45+
private static final String NODE_NAME = "test-tpu";
4946
private static final String TPU_TYPE = "v2-8";
5047
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1";
51-
private static final String NODE_PATH_NAME =
52-
String.format("projects/%s/locations/%s/nodes/%s", PROJECT_ID, ZONE, NODE_NAME);
53-
54-
public static void requireEnvVar(String envVarName) {
55-
assertWithMessage(String.format("Missing environment variable '%s' ", envVarName))
56-
.that(System.getenv(envVarName)).isNotEmpty();
57-
}
58-
59-
@BeforeAll
60-
public static void setUp() {
61-
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
62-
requireEnvVar("GOOGLE_CLOUD_PROJECT");
63-
}
64-
65-
@AfterAll
66-
public static void cleanup() throws Exception {
67-
DeleteTpuVm.deleteTpuVm(PROJECT_ID, ZONE, NODE_NAME);
68-
69-
// Test that TPUs is deleted
70-
Assertions.assertThrows(
71-
NotFoundException.class,
72-
() -> GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME));
73-
}
7448

7549
@Test
76-
@Order(1)
77-
public void testCreateTpuVm() throws IOException, ExecutionException, InterruptedException {
50+
public void testCreateTpuVm() throws Exception {
51+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
52+
TpuClient mockTpuClient = mock(TpuClient.class);
53+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
54+
.thenReturn(mockTpuClient);
7855

79-
Node node = CreateTpuVm.createTpuVm(
80-
PROJECT_ID, ZONE, NODE_NAME, TPU_TYPE, TPU_SOFTWARE_VERSION);
56+
OperationFuture mockFuture = mock(OperationFuture.class);
57+
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
58+
.thenReturn(mockFuture);
59+
CreateTpuVm.createTpuVm(
60+
PROJECT_ID, ZONE, NODE_NAME,
61+
TPU_TYPE, TPU_SOFTWARE_VERSION);
8162

82-
assertNotNull(node);
83-
assertThat(node.getName().equals(NODE_NAME));
84-
assertThat(node.getAcceleratorType().equals(TPU_TYPE));
63+
verify(mockTpuClient, times(1)).createNodeAsync(any(CreateNodeRequest.class));
64+
}
8565
}
8666

8767
@Test
88-
@Order(2)
89-
public void testGetTpuVm() throws IOException {
90-
Node node = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
68+
public void testDeleteTpuVm() throws IOException, ExecutionException, InterruptedException {
69+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
70+
TpuClient mockTpuClient = mock(TpuClient.class);
71+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
72+
.thenReturn(mockTpuClient);
73+
74+
OperationFuture mockFuture = mock(OperationFuture.class);
75+
when(mockTpuClient.deleteNodeAsync(any(DeleteNodeRequest.class)))
76+
.thenReturn(mockFuture);
77+
DeleteTpuVm.deleteTpuVm(PROJECT_ID, ZONE, NODE_NAME);
9178

92-
assertNotNull(node);
93-
assertThat(node.getName()).isEqualTo(NODE_PATH_NAME);
79+
verify(mockTpuClient, times(1)).deleteNodeAsync(any(DeleteNodeRequest.class));
80+
}
9481
}
9582

9683
@Test
97-
@Order(2)
9884
public void testListTpuVm() throws IOException {
99-
TpuClient.ListNodesPagedResponse nodesList = ListTpuVms.listTpuVms(PROJECT_ID, ZONE);
85+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
86+
TpuClient.ListNodesPagedResponse mockListNodes = mock(TpuClient.ListNodesPagedResponse.class);
87+
mockedTpuClient.when(TpuClient::create).thenReturn(mock(TpuClient.class));
88+
when(mock(TpuClient.class).listNodes(any(LocationName.class))).thenReturn(mockListNodes);
89+
ListTpuVms mockListTpuVms = mock(ListTpuVms.class);
90+
91+
ListTpuVms.listTpuVms(PROJECT_ID, ZONE);
10092

101-
assertNotNull(nodesList);
102-
for (Node node : nodesList.iterateAll()) {
103-
Assert.assertTrue(node.getName().contains("test-tpu"));
93+
verify(mockListTpuVms, times(1)).listTpuVms(PROJECT_ID, ZONE);
10494
}
10595
}
106-
}
96+
}

0 commit comments

Comments
 (0)