Skip to content

Commit 372f0b5

Browse files
fix(tpu): improving tpu tests (#9679)
* Changed tests * Created separated test for CreateTpuVm * Renamed file * Fixed imports * Fixed tests as requested in comments * Deleted redundant dependency * Fixed indentation * Fixed tests
1 parent 99e983e commit 372f0b5

File tree

3 files changed

+201
-75
lines changed

3 files changed

+201
-75
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
import static org.junit.Assert.assertEquals;
20+
import static org.mockito.Mockito.any;
21+
import static org.mockito.Mockito.mock;
22+
import static org.mockito.Mockito.mockStatic;
23+
import static org.mockito.Mockito.times;
24+
import static org.mockito.Mockito.verify;
25+
import static org.mockito.Mockito.when;
26+
27+
import com.google.api.gax.longrunning.OperationFuture;
28+
import com.google.cloud.tpu.v2.CreateNodeRequest;
29+
import com.google.cloud.tpu.v2.Node;
30+
import com.google.cloud.tpu.v2.TpuClient;
31+
import com.google.cloud.tpu.v2.TpuSettings;
32+
import org.junit.Test;
33+
import org.junit.jupiter.api.Timeout;
34+
import org.junit.runner.RunWith;
35+
import org.junit.runners.JUnit4;
36+
import org.mockito.MockedStatic;
37+
38+
@RunWith(JUnit4.class)
39+
@Timeout(value = 3)
40+
public class CreateTpuIT {
41+
private static final String PROJECT_ID = "project-id";
42+
private static final String ZONE = "asia-east1-c";
43+
private static final String NODE_NAME = "test-tpu";
44+
private static final String TPU_TYPE = "v2-8";
45+
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1";
46+
47+
@Test
48+
public void testCreateTpuVm() throws Exception {
49+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
50+
Node mockNode = mock(Node.class);
51+
TpuClient mockTpuClient = mock(TpuClient.class);
52+
OperationFuture mockFuture = mock(OperationFuture.class);
53+
54+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
55+
.thenReturn(mockTpuClient);
56+
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
57+
.thenReturn(mockFuture);
58+
when(mockFuture.get()).thenReturn(mockNode);
59+
60+
Node returnedNode = CreateTpuVm.createTpuVm(
61+
PROJECT_ID, ZONE, NODE_NAME,
62+
TPU_TYPE, TPU_SOFTWARE_VERSION);
63+
64+
verify(mockTpuClient, times(1))
65+
.createNodeAsync(any(CreateNodeRequest.class));
66+
verify(mockFuture, times(1)).get();
67+
assertEquals(returnedNode, mockNode);
68+
}
69+
}
70+
}

tpu/src/test/java/tpu/QueuedResourceIT.java

Lines changed: 83 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,56 +17,111 @@
1717
package tpu;
1818

1919
import static com.google.common.truth.Truth.assertThat;
20-
import static com.google.common.truth.Truth.assertWithMessage;
20+
import static org.junit.Assert.assertEquals;
21+
import static org.mockito.Mockito.any;
22+
import static org.mockito.Mockito.mock;
23+
import static org.mockito.Mockito.mockStatic;
24+
import static org.mockito.Mockito.times;
25+
import static org.mockito.Mockito.verify;
26+
import static org.mockito.Mockito.when;
2127

28+
import com.google.api.gax.longrunning.OperationFuture;
29+
import com.google.cloud.tpu.v2alpha1.CreateQueuedResourceRequest;
30+
import com.google.cloud.tpu.v2alpha1.DeleteQueuedResourceRequest;
31+
import com.google.cloud.tpu.v2alpha1.GetQueuedResourceRequest;
2232
import com.google.cloud.tpu.v2alpha1.QueuedResource;
23-
import java.util.UUID;
24-
import java.util.concurrent.TimeUnit;
33+
import com.google.cloud.tpu.v2alpha1.TpuClient;
34+
import com.google.cloud.tpu.v2alpha1.TpuSettings;
35+
import java.io.ByteArrayOutputStream;
36+
import java.io.IOException;
37+
import java.io.PrintStream;
38+
import org.junit.Before;
2539
import org.junit.Test;
26-
import org.junit.jupiter.api.AfterAll;
27-
import org.junit.jupiter.api.BeforeAll;
2840
import org.junit.jupiter.api.Timeout;
2941
import org.junit.runner.RunWith;
3042
import org.junit.runners.JUnit4;
43+
import org.mockito.MockedStatic;
3144

3245
@RunWith(JUnit4.class)
33-
@Timeout(value = 6, unit = TimeUnit.MINUTES)
46+
@Timeout(value = 3)
3447
public class QueuedResourceIT {
35-
private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT");
48+
private static final String PROJECT_ID = "project-id";
3649
private static final String ZONE = "europe-west4-a";
37-
private static final String NODE_NAME = "test-tpu-queued-resource-network-" + UUID.randomUUID();
50+
private static final String NODE_NAME = "test-tpu";
3851
private static final String TPU_TYPE = "v2-8";
3952
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
40-
private static final String QUEUED_RESOURCE_NAME = "queued-resource-network-" + UUID.randomUUID();
53+
private static final String QUEUED_RESOURCE_NAME = "queued-resource";
4154
private static final String NETWORK_NAME = "default";
55+
private ByteArrayOutputStream bout;
4256

43-
public static void requireEnvVar(String envVarName) {
44-
assertWithMessage(String.format("Missing environment variable '%s' ", envVarName))
45-
.that(System.getenv(envVarName)).isNotEmpty();
57+
@Before
58+
public void setUp() {
59+
bout = new ByteArrayOutputStream();
60+
System.setOut(new PrintStream(bout));
4661
}
4762

48-
@BeforeAll
49-
public static void setUp() {
50-
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
51-
requireEnvVar("GOOGLE_CLOUD_PROJECT");
63+
@Test
64+
public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {
65+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
66+
QueuedResource mockQueuedResource = mock(QueuedResource.class);
67+
TpuClient mockTpuClient = mock(TpuClient.class);
68+
OperationFuture mockFuture = mock(OperationFuture.class);
69+
70+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
71+
.thenReturn(mockTpuClient);
72+
when(mockTpuClient.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class)))
73+
.thenReturn(mockFuture);
74+
when(mockFuture.get()).thenReturn(mockQueuedResource);
75+
76+
QueuedResource returnedQueuedResource =
77+
CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork(
78+
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
79+
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);
80+
81+
verify(mockTpuClient, times(1))
82+
.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class));
83+
verify(mockFuture, times(1)).get();
84+
assertEquals(returnedQueuedResource, mockQueuedResource);
85+
}
5286
}
5387

54-
@AfterAll
55-
public static void cleanup() {
56-
DeleteForceQueuedResource.deleteForceQueuedResource(PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME);
88+
@Test
89+
public void testGetQueuedResource() throws IOException {
90+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
91+
TpuClient mockClient = mock(TpuClient.class);
92+
GetQueuedResource mockGetQueuedResource = mock(GetQueuedResource.class);
93+
QueuedResource mockQueuedResource = mock(QueuedResource.class);
94+
95+
mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
96+
when(mockClient.getQueuedResource(any(GetQueuedResourceRequest.class)))
97+
.thenReturn(mockQueuedResource);
98+
99+
QueuedResource returnedQueuedResource =
100+
GetQueuedResource.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
101+
102+
verify(mockGetQueuedResource, times(1))
103+
.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
104+
assertEquals(returnedQueuedResource, mockQueuedResource);
105+
}
57106
}
58107

59108
@Test
60-
public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {
109+
public void testDeleteTpuVm() {
110+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
111+
TpuClient mockTpuClient = mock(TpuClient.class);
112+
OperationFuture mockFuture = mock(OperationFuture.class);
113+
114+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
115+
.thenReturn(mockTpuClient);
116+
when(mockTpuClient.deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class)))
117+
.thenReturn(mockFuture);
61118

62-
QueuedResource queuedResource = CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork(
63-
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
64-
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);
119+
DeleteForceQueuedResource.deleteForceQueuedResource(PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME);
120+
String output = bout.toString();
65121

66-
assertThat(queuedResource.getTpu().getNodeSpec(0).getNode().getName()).isEqualTo(NODE_NAME);
67-
assertThat(queuedResource.getTpu().getNodeSpec(0).getNode().getNetworkConfig().getNetwork()
68-
.contains(NETWORK_NAME));
69-
assertThat(queuedResource.getTpu().getNodeSpec(0).getNode().getNetworkConfig().getSubnetwork()
70-
.contains(NETWORK_NAME));
122+
assertThat(output).contains("Deleted Queued Resource:");
123+
verify(mockTpuClient, times(1))
124+
.deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class));
125+
}
71126
}
72127
}

tpu/src/test/java/tpu/TpuVmIT.java

Lines changed: 48 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -17,77 +17,78 @@
1717
package tpu;
1818

1919
import static com.google.common.truth.Truth.assertThat;
20-
import static com.google.common.truth.Truth.assertWithMessage;
21-
import static org.junit.Assert.assertNotNull;
20+
import static org.mockito.Mockito.any;
21+
import static org.mockito.Mockito.mock;
22+
import static org.mockito.Mockito.mockStatic;
23+
import static org.mockito.Mockito.times;
24+
import static org.mockito.Mockito.verify;
25+
import static org.mockito.Mockito.when;
2226

23-
import com.google.api.gax.rpc.NotFoundException;
27+
import com.google.api.gax.longrunning.OperationFuture;
28+
import com.google.cloud.tpu.v2.DeleteNodeRequest;
29+
import com.google.cloud.tpu.v2.GetNodeRequest;
2430
import com.google.cloud.tpu.v2.Node;
31+
import com.google.cloud.tpu.v2.TpuClient;
32+
import com.google.cloud.tpu.v2.TpuSettings;
33+
import java.io.ByteArrayOutputStream;
2534
import java.io.IOException;
26-
import java.util.UUID;
35+
import java.io.PrintStream;
2736
import java.util.concurrent.ExecutionException;
28-
import java.util.concurrent.TimeUnit;
29-
import org.junit.jupiter.api.AfterAll;
30-
import org.junit.jupiter.api.Assertions;
3137
import org.junit.jupiter.api.BeforeAll;
32-
import org.junit.jupiter.api.MethodOrderer;
33-
import org.junit.jupiter.api.Order;
3438
import org.junit.jupiter.api.Test;
35-
import org.junit.jupiter.api.TestMethodOrder;
3639
import org.junit.jupiter.api.Timeout;
3740
import org.junit.runner.RunWith;
3841
import org.junit.runners.JUnit4;
42+
import org.mockito.MockedStatic;
3943

4044
@RunWith(JUnit4.class)
41-
@Timeout(value = 15, unit = TimeUnit.MINUTES)
42-
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
45+
@Timeout(value = 3)
4346
public class TpuVmIT {
44-
private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT");
47+
private static final String PROJECT_ID = "project-id";
4548
private static final String ZONE = "asia-east1-c";
46-
private static final String NODE_NAME = "test-tpu-" + UUID.randomUUID();
47-
private static final String TPU_TYPE = "v2-8";
48-
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1";
49-
private static final String NODE_PATH_NAME =
50-
String.format("projects/%s/locations/%s/nodes/%s", PROJECT_ID, ZONE, NODE_NAME);
51-
52-
public static void requireEnvVar(String envVarName) {
53-
assertWithMessage(String.format("Missing environment variable '%s' ", envVarName))
54-
.that(System.getenv(envVarName)).isNotEmpty();
55-
}
49+
private static final String NODE_NAME = "test-tpu";
50+
private static ByteArrayOutputStream bout;
5651

5752
@BeforeAll
5853
public static void setUp() {
59-
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
60-
requireEnvVar("GOOGLE_CLOUD_PROJECT");
61-
}
62-
63-
@AfterAll
64-
public static void cleanup() throws Exception {
65-
DeleteTpuVm.deleteTpuVm(PROJECT_ID, ZONE, NODE_NAME);
66-
67-
// Test that TPUs is deleted
68-
Assertions.assertThrows(
69-
NotFoundException.class,
70-
() -> GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME));
54+
bout = new ByteArrayOutputStream();
55+
System.setOut(new PrintStream(bout));
7156
}
7257

7358
@Test
74-
@Order(1)
75-
public void testCreateTpuVm() throws IOException, ExecutionException, InterruptedException {
59+
public void testGetTpuVm() throws IOException {
60+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
61+
Node mockNode = mock(Node.class);
62+
TpuClient mockClient = mock(TpuClient.class);
63+
GetTpuVm mockGetTpuVm = mock(GetTpuVm.class);
64+
65+
mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
66+
when(mockClient.getNode(any(GetNodeRequest.class))).thenReturn(mockNode);
7667

77-
Node node = CreateTpuVm.createTpuVm(
78-
PROJECT_ID, ZONE, NODE_NAME, TPU_TYPE, TPU_SOFTWARE_VERSION);
68+
Node returnedNode = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
7969

80-
assertNotNull(node);
81-
assertThat(node.getName().equals(NODE_NAME));
82-
assertThat(node.getAcceleratorType().equals(TPU_TYPE));
70+
verify(mockGetTpuVm, times(1))
71+
.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
72+
assertThat(returnedNode).isEqualTo(mockNode);
73+
}
8374
}
8475

8576
@Test
86-
@Order(2)
87-
public void testGetTpuVm() throws IOException {
88-
Node node = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
77+
public void testDeleteTpuVm() throws IOException, ExecutionException, InterruptedException {
78+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
79+
TpuClient mockTpuClient = mock(TpuClient.class);
80+
OperationFuture mockFuture = mock(OperationFuture.class);
81+
82+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
83+
.thenReturn(mockTpuClient);
84+
when(mockTpuClient.deleteNodeAsync(any(DeleteNodeRequest.class)))
85+
.thenReturn(mockFuture);
86+
87+
DeleteTpuVm.deleteTpuVm(PROJECT_ID, ZONE, NODE_NAME);
88+
String output = bout.toString();
8989

90-
assertNotNull(node);
91-
assertThat(node.getName()).isEqualTo(NODE_PATH_NAME);
90+
assertThat(output).contains("TPU VM deleted");
91+
verify(mockTpuClient, times(1)).deleteNodeAsync(any(DeleteNodeRequest.class));
92+
}
9293
}
9394
}

0 commit comments

Comments
 (0)