Skip to content

Commit 428a2ef

Browse files
Merged changes from main
2 parents 1e1b840 + 372f0b5 commit 428a2ef

File tree

3 files changed

+124
-29
lines changed

3 files changed

+124
-29
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
import static org.junit.Assert.assertEquals;
20+
import static org.mockito.Mockito.any;
21+
import static org.mockito.Mockito.mock;
22+
import static org.mockito.Mockito.mockStatic;
23+
import static org.mockito.Mockito.times;
24+
import static org.mockito.Mockito.verify;
25+
import static org.mockito.Mockito.when;
26+
27+
import com.google.api.gax.longrunning.OperationFuture;
28+
import com.google.cloud.tpu.v2.CreateNodeRequest;
29+
import com.google.cloud.tpu.v2.Node;
30+
import com.google.cloud.tpu.v2.TpuClient;
31+
import com.google.cloud.tpu.v2.TpuSettings;
32+
import org.junit.Test;
33+
import org.junit.jupiter.api.Timeout;
34+
import org.junit.runner.RunWith;
35+
import org.junit.runners.JUnit4;
36+
import org.mockito.MockedStatic;
37+
38+
@RunWith(JUnit4.class)
39+
@Timeout(value = 3)
40+
public class CreateTpuIT {
41+
private static final String PROJECT_ID = "project-id";
42+
private static final String ZONE = "asia-east1-c";
43+
private static final String NODE_NAME = "test-tpu";
44+
private static final String TPU_TYPE = "v2-8";
45+
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1";
46+
47+
@Test
48+
public void testCreateTpuVm() throws Exception {
49+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
50+
Node mockNode = mock(Node.class);
51+
TpuClient mockTpuClient = mock(TpuClient.class);
52+
OperationFuture mockFuture = mock(OperationFuture.class);
53+
54+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
55+
.thenReturn(mockTpuClient);
56+
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
57+
.thenReturn(mockFuture);
58+
when(mockFuture.get()).thenReturn(mockNode);
59+
60+
Node returnedNode = CreateTpuVm.createTpuVm(
61+
PROJECT_ID, ZONE, NODE_NAME,
62+
TPU_TYPE, TPU_SOFTWARE_VERSION);
63+
64+
verify(mockTpuClient, times(1))
65+
.createNodeAsync(any(CreateNodeRequest.class));
66+
verify(mockFuture, times(1)).get();
67+
assertEquals(returnedNode, mockNode);
68+
}
69+
}
70+
}

tpu/src/test/java/tpu/QueuedResourceIT.java

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package tpu;
1818

1919
import static com.google.common.truth.Truth.assertThat;
20+
import static org.junit.Assert.assertEquals;
2021
import static org.mockito.Mockito.any;
2122
import static org.mockito.Mockito.mock;
2223
import static org.mockito.Mockito.mockStatic;
@@ -27,8 +28,8 @@
2728
import com.google.api.gax.longrunning.OperationFuture;
2829
import com.google.cloud.tpu.v2alpha1.CreateQueuedResourceRequest;
2930
import com.google.cloud.tpu.v2alpha1.DeleteQueuedResourceRequest;
31+
import com.google.cloud.tpu.v2alpha1.GetQueuedResourceRequest;
3032
import com.google.cloud.tpu.v2alpha1.QueuedResource;
31-
import com.google.cloud.tpu.v2alpha1.QueuedResourceName;
3233
import com.google.cloud.tpu.v2alpha1.TpuClient;
3334
import com.google.cloud.tpu.v2alpha1.TpuSettings;
3435
import java.io.ByteArrayOutputStream;
@@ -62,48 +63,59 @@ public void setUp() {
6263
@Test
6364
public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {
6465
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
66+
QueuedResource mockQueuedResource = mock(QueuedResource.class);
6567
TpuClient mockTpuClient = mock(TpuClient.class);
68+
OperationFuture mockFuture = mock(OperationFuture.class);
69+
6670
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
6771
.thenReturn(mockTpuClient);
68-
69-
OperationFuture mockFuture = mock(OperationFuture.class);
7072
when(mockTpuClient.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class)))
7173
.thenReturn(mockFuture);
72-
CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork(
74+
when(mockFuture.get()).thenReturn(mockQueuedResource);
75+
76+
QueuedResource returnedQueuedResource =
77+
CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork(
7378
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
7479
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);
7580

7681
verify(mockTpuClient, times(1))
7782
.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class));
83+
verify(mockFuture, times(1)).get();
84+
assertEquals(returnedQueuedResource, mockQueuedResource);
7885
}
7986
}
8087

8188
@Test
8289
public void testGetQueuedResource() throws IOException {
8390
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
84-
QueuedResource mockQueuedResource = mock(QueuedResource.class);
85-
mockedTpuClient.when(TpuClient::create).thenReturn(mock(TpuClient.class));
86-
when(mock(TpuClient.class)
87-
.getQueuedResource(any(QueuedResourceName.class))).thenReturn(mockQueuedResource);
91+
TpuClient mockClient = mock(TpuClient.class);
8892
GetQueuedResource mockGetQueuedResource = mock(GetQueuedResource.class);
93+
QueuedResource mockQueuedResource = mock(QueuedResource.class);
94+
95+
mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
96+
when(mockClient.getQueuedResource(any(GetQueuedResourceRequest.class)))
97+
.thenReturn(mockQueuedResource);
8998

90-
GetQueuedResource.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
99+
QueuedResource returnedQueuedResource =
100+
GetQueuedResource.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
91101

92102
verify(mockGetQueuedResource, times(1))
93103
.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
104+
assertEquals(returnedQueuedResource, mockQueuedResource);
94105
}
95106
}
96107

97108
@Test
98109
public void testDeleteTpuVm() {
99110
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
100111
TpuClient mockTpuClient = mock(TpuClient.class);
112+
OperationFuture mockFuture = mock(OperationFuture.class);
113+
101114
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
102115
.thenReturn(mockTpuClient);
103-
104-
OperationFuture mockFuture = mock(OperationFuture.class);
105116
when(mockTpuClient.deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class)))
106117
.thenReturn(mockFuture);
118+
107119
DeleteForceQueuedResource.deleteForceQueuedResource(PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME);
108120
String output = bout.toString();
109121

tpu/src/test/java/tpu/TpuVmIT.java

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package tpu;
1818

19+
import static com.google.common.truth.Truth.assertThat;
1920
import static org.mockito.Mockito.any;
2021
import static org.mockito.Mockito.mock;
2122
import static org.mockito.Mockito.mockStatic;
@@ -24,14 +25,18 @@
2425
import static org.mockito.Mockito.when;
2526

2627
import com.google.api.gax.longrunning.OperationFuture;
27-
import com.google.cloud.tpu.v2.CreateNodeRequest;
2828
import com.google.cloud.tpu.v2.DeleteNodeRequest;
29+
import com.google.cloud.tpu.v2.GetNodeRequest;
2930
import com.google.cloud.tpu.v2.LocationName;
31+
import com.google.cloud.tpu.v2.Node;
3032
import com.google.cloud.tpu.v2.TpuClient;
3133
import com.google.cloud.tpu.v2.TpuSettings;
34+
import java.io.ByteArrayOutputStream;
3235
import java.io.IOException;
36+
import java.io.PrintStream;
3337
import java.util.concurrent.ExecutionException;
34-
import org.junit.Test;
38+
import org.junit.jupiter.api.BeforeAll;
39+
import org.junit.jupiter.api.Test;
3540
import org.junit.jupiter.api.Timeout;
3641
import org.junit.runner.RunWith;
3742
import org.junit.runners.JUnit4;
@@ -43,39 +48,47 @@ public class TpuVmIT {
4348
private static final String PROJECT_ID = "project-id";
4449
private static final String ZONE = "asia-east1-c";
4550
private static final String NODE_NAME = "test-tpu";
46-
private static final String TPU_TYPE = "v2-8";
47-
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1";
51+
private static ByteArrayOutputStream bout;
52+
53+
@BeforeAll
54+
public static void setUp() {
55+
bout = new ByteArrayOutputStream();
56+
System.setOut(new PrintStream(bout));
57+
}
4858

4959
@Test
50-
public void testCreateTpuVm() throws Exception {
60+
public void testGetTpuVm() throws IOException {
5161
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
52-
TpuClient mockTpuClient = mock(TpuClient.class);
53-
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
54-
.thenReturn(mockTpuClient);
62+
Node mockNode = mock(Node.class);
63+
TpuClient mockClient = mock(TpuClient.class);
64+
GetTpuVm mockGetTpuVm = mock(GetTpuVm.class);
5565

56-
OperationFuture mockFuture = mock(OperationFuture.class);
57-
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
58-
.thenReturn(mockFuture);
59-
CreateTpuVm.createTpuVm(
60-
PROJECT_ID, ZONE, NODE_NAME,
61-
TPU_TYPE, TPU_SOFTWARE_VERSION);
66+
mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
67+
when(mockClient.getNode(any(GetNodeRequest.class))).thenReturn(mockNode);
6268

63-
verify(mockTpuClient, times(1)).createNodeAsync(any(CreateNodeRequest.class));
69+
Node returnedNode = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
70+
71+
verify(mockGetTpuVm, times(1))
72+
.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
73+
assertThat(returnedNode).isEqualTo(mockNode);
6474
}
6575
}
6676

6777
@Test
6878
public void testDeleteTpuVm() throws IOException, ExecutionException, InterruptedException {
6979
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
7080
TpuClient mockTpuClient = mock(TpuClient.class);
81+
OperationFuture mockFuture = mock(OperationFuture.class);
82+
7183
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
7284
.thenReturn(mockTpuClient);
73-
74-
OperationFuture mockFuture = mock(OperationFuture.class);
7585
when(mockTpuClient.deleteNodeAsync(any(DeleteNodeRequest.class)))
7686
.thenReturn(mockFuture);
87+
7788
DeleteTpuVm.deleteTpuVm(PROJECT_ID, ZONE, NODE_NAME);
89+
String output = bout.toString();
7890

91+
assertThat(output).contains("TPU VM deleted");
7992
verify(mockTpuClient, times(1)).deleteNodeAsync(any(DeleteNodeRequest.class));
8093
}
8194
}
@@ -93,4 +106,4 @@ public void testListTpuVm() throws IOException {
93106
verify(mockListTpuVms, times(1)).listTpuVms(PROJECT_ID, ZONE);
94107
}
95108
}
96-
}
109+
}

0 commit comments

Comments
 (0)