Skip to content

Commit db9254f

Browse files
Fixed tests
1 parent 49a309e commit db9254f

File tree

3 files changed

+56
-20
lines changed

3 files changed

+56
-20
lines changed

tpu/src/test/java/tpu/CreateTpuIT.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package tpu;
1818

19+
import static org.junit.Assert.assertEquals;
1920
import static org.mockito.Mockito.any;
2021
import static org.mockito.Mockito.mock;
2122
import static org.mockito.Mockito.mockStatic;
@@ -25,6 +26,7 @@
2526

2627
import com.google.api.gax.longrunning.OperationFuture;
2728
import com.google.cloud.tpu.v2.CreateNodeRequest;
29+
import com.google.cloud.tpu.v2.Node;
2830
import com.google.cloud.tpu.v2.TpuClient;
2931
import com.google.cloud.tpu.v2.TpuSettings;
3032
import org.junit.Test;
@@ -45,19 +47,24 @@ public class CreateTpuIT {
4547
@Test
4648
public void testCreateTpuVm() throws Exception {
4749
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
50+
Node mockNode = mock(Node.class);
4851
TpuClient mockTpuClient = mock(TpuClient.class);
52+
OperationFuture mockFuture = mock(OperationFuture.class);
53+
4954
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
5055
.thenReturn(mockTpuClient);
51-
52-
OperationFuture mockFuture = mock(OperationFuture.class);
5356
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
5457
.thenReturn(mockFuture);
55-
CreateTpuVm.createTpuVm(
58+
when(mockFuture.get()).thenReturn(mockNode);
59+
60+
Node returnedNode = CreateTpuVm.createTpuVm(
5661
PROJECT_ID, ZONE, NODE_NAME,
5762
TPU_TYPE, TPU_SOFTWARE_VERSION);
5863

5964
verify(mockTpuClient, times(1))
6065
.createNodeAsync(any(CreateNodeRequest.class));
66+
verify(mockFuture, times(1)).get();
67+
assertEquals(returnedNode, mockNode);
6168
}
6269
}
6370
}

tpu/src/test/java/tpu/QueuedResourceIT.java

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package tpu;
1818

1919
import static com.google.common.truth.Truth.assertThat;
20+
import static org.junit.Assert.assertEquals;
2021
import static org.mockito.Mockito.any;
2122
import static org.mockito.Mockito.mock;
2223
import static org.mockito.Mockito.mockStatic;
@@ -27,8 +28,8 @@
2728
import com.google.api.gax.longrunning.OperationFuture;
2829
import com.google.cloud.tpu.v2alpha1.CreateQueuedResourceRequest;
2930
import com.google.cloud.tpu.v2alpha1.DeleteQueuedResourceRequest;
31+
import com.google.cloud.tpu.v2alpha1.GetQueuedResourceRequest;
3032
import com.google.cloud.tpu.v2alpha1.QueuedResource;
31-
import com.google.cloud.tpu.v2alpha1.QueuedResourceName;
3233
import com.google.cloud.tpu.v2alpha1.TpuClient;
3334
import com.google.cloud.tpu.v2alpha1.TpuSettings;
3435
import java.io.ByteArrayOutputStream;
@@ -62,48 +63,59 @@ public void setUp() {
6263
@Test
6364
public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {
6465
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
66+
QueuedResource mockQueuedResource = mock(QueuedResource.class);
6567
TpuClient mockTpuClient = mock(TpuClient.class);
68+
OperationFuture mockFuture = mock(OperationFuture.class);
69+
6670
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
6771
.thenReturn(mockTpuClient);
68-
69-
OperationFuture mockFuture = mock(OperationFuture.class);
7072
when(mockTpuClient.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class)))
7173
.thenReturn(mockFuture);
72-
CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork(
74+
when(mockFuture.get()).thenReturn(mockQueuedResource);
75+
76+
QueuedResource returnedQueuedResource =
77+
CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork(
7378
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
7479
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);
7580

7681
verify(mockTpuClient, times(1))
7782
.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class));
83+
verify(mockFuture, times(1)).get();
84+
assertEquals(returnedQueuedResource, mockQueuedResource);
7885
}
7986
}
8087

8188
@Test
8289
public void testGetQueuedResource() throws IOException {
8390
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
84-
QueuedResource mockQueuedResource = mock(QueuedResource.class);
85-
mockedTpuClient.when(TpuClient::create).thenReturn(mock(TpuClient.class));
86-
when(mock(TpuClient.class)
87-
.getQueuedResource(any(QueuedResourceName.class))).thenReturn(mockQueuedResource);
91+
TpuClient mockClient = mock(TpuClient.class);
8892
GetQueuedResource mockGetQueuedResource = mock(GetQueuedResource.class);
93+
QueuedResource mockQueuedResource = mock(QueuedResource.class);
94+
95+
mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
96+
when(mockClient.getQueuedResource(any(GetQueuedResourceRequest.class)))
97+
.thenReturn(mockQueuedResource);
8998

90-
GetQueuedResource.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
99+
QueuedResource returnedQueuedResource =
100+
GetQueuedResource.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
91101

92102
verify(mockGetQueuedResource, times(1))
93103
.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
104+
assertEquals(returnedQueuedResource, mockQueuedResource);
94105
}
95106
}
96107

97108
@Test
98109
public void testDeleteTpuVm() {
99110
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
100111
TpuClient mockTpuClient = mock(TpuClient.class);
112+
OperationFuture mockFuture = mock(OperationFuture.class);
113+
101114
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
102115
.thenReturn(mockTpuClient);
103-
104-
OperationFuture mockFuture = mock(OperationFuture.class);
105116
when(mockTpuClient.deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class)))
106117
.thenReturn(mockFuture);
118+
107119
DeleteForceQueuedResource.deleteForceQueuedResource(PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME);
108120
String output = bout.toString();
109121

tpu/src/test/java/tpu/TpuVmIT.java

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package tpu;
1818

19+
import static com.google.common.truth.Truth.assertThat;
1920
import static org.mockito.Mockito.any;
2021
import static org.mockito.Mockito.mock;
2122
import static org.mockito.Mockito.mockStatic;
@@ -25,12 +26,15 @@
2526

2627
import com.google.api.gax.longrunning.OperationFuture;
2728
import com.google.cloud.tpu.v2.DeleteNodeRequest;
29+
import com.google.cloud.tpu.v2.GetNodeRequest;
2830
import com.google.cloud.tpu.v2.Node;
29-
import com.google.cloud.tpu.v2.NodeName;
3031
import com.google.cloud.tpu.v2.TpuClient;
3132
import com.google.cloud.tpu.v2.TpuSettings;
33+
import java.io.ByteArrayOutputStream;
3234
import java.io.IOException;
35+
import java.io.PrintStream;
3336
import java.util.concurrent.ExecutionException;
37+
import org.junit.jupiter.api.BeforeAll;
3438
import org.junit.jupiter.api.Test;
3539
import org.junit.jupiter.api.Timeout;
3640
import org.junit.runner.RunWith;
@@ -43,34 +47,47 @@ public class TpuVmIT {
4347
private static final String PROJECT_ID = "project-id";
4448
private static final String ZONE = "asia-east1-c";
4549
private static final String NODE_NAME = "test-tpu";
50+
private static ByteArrayOutputStream bout;
51+
52+
@BeforeAll
53+
public static void setUp() {
54+
bout = new ByteArrayOutputStream();
55+
System.setOut(new PrintStream(bout));
56+
}
4657

4758
@Test
4859
public void testGetTpuVm() throws IOException {
4960
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
5061
Node mockNode = mock(Node.class);
51-
mockedTpuClient.when(TpuClient::create).thenReturn(mock(TpuClient.class));
52-
when(mock(TpuClient.class).getNode(any(NodeName.class))).thenReturn(mockNode);
62+
TpuClient mockClient = mock(TpuClient.class);
5363
GetTpuVm mockGetTpuVm = mock(GetTpuVm.class);
5464

55-
GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
65+
mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
66+
when(mockClient.getNode(any(GetNodeRequest.class))).thenReturn(mockNode);
67+
68+
Node returnedNode = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
5669

5770
verify(mockGetTpuVm, times(1))
5871
.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
72+
assertThat(returnedNode).isEqualTo(mockNode);
5973
}
6074
}
6175

6276
@Test
6377
public void testDeleteTpuVm() throws IOException, ExecutionException, InterruptedException {
6478
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
6579
TpuClient mockTpuClient = mock(TpuClient.class);
80+
OperationFuture mockFuture = mock(OperationFuture.class);
81+
6682
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
6783
.thenReturn(mockTpuClient);
68-
69-
OperationFuture mockFuture = mock(OperationFuture.class);
7084
when(mockTpuClient.deleteNodeAsync(any(DeleteNodeRequest.class)))
7185
.thenReturn(mockFuture);
86+
7287
DeleteTpuVm.deleteTpuVm(PROJECT_ID, ZONE, NODE_NAME);
88+
String output = bout.toString();
7389

90+
assertThat(output).contains("TPU VM deleted");
7491
verify(mockTpuClient, times(1)).deleteNodeAsync(any(DeleteNodeRequest.class));
7592
}
7693
}

0 commit comments

Comments
 (0)