Skip to content

Commit 32b4501

Browse files
Fixed test
1 parent 16f9356 commit 32b4501

File tree

4 files changed

+63
-147
lines changed

4 files changed

+63
-147
lines changed

tpu/src/test/java/tpu/CreateTpuIT.java

Lines changed: 0 additions & 70 deletions
This file was deleted.

tpu/src/test/java/tpu/ListTpuVmsIT.java

Lines changed: 0 additions & 68 deletions
This file was deleted.

tpu/src/test/java/tpu/QueuedResourceIT.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
import java.io.ByteArrayOutputStream;
3636
import java.io.IOException;
3737
import java.io.PrintStream;
38-
import org.junit.Before;
39-
import org.junit.Test;
38+
import org.junit.jupiter.api.BeforeAll;
39+
import org.junit.jupiter.api.Test;
4040
import org.junit.jupiter.api.Timeout;
4141
import org.junit.runner.RunWith;
4242
import org.junit.runners.JUnit4;
@@ -52,10 +52,10 @@ public class QueuedResourceIT {
5252
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
5353
private static final String QUEUED_RESOURCE_NAME = "queued-resource";
5454
private static final String NETWORK_NAME = "default";
55-
private ByteArrayOutputStream bout;
55+
private static ByteArrayOutputStream bout;
5656

57-
@Before
58-
public void setUp() {
57+
@BeforeAll
58+
public static void setUp() {
5959
bout = new ByteArrayOutputStream();
6060
System.setOut(new PrintStream(bout));
6161
}
@@ -106,7 +106,7 @@ public void testGetQueuedResource() throws IOException {
106106
}
107107

108108
@Test
109-
public void testDeleteTpuVm() {
109+
public void testDeleteForceQueuedResource() {
110110
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
111111
TpuClient mockTpuClient = mock(TpuClient.class);
112112
OperationFuture mockFuture = mock(OperationFuture.class);

tpu/src/test/java/tpu/TpuVmIT.java

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package tpu;
1818

1919
import static com.google.common.truth.Truth.assertThat;
20+
import static org.junit.Assert.assertEquals;
2021
import static org.mockito.Mockito.any;
2122
import static org.mockito.Mockito.mock;
2223
import static org.mockito.Mockito.mockStatic;
@@ -25,14 +26,18 @@
2526
import static org.mockito.Mockito.when;
2627

2728
import com.google.api.gax.longrunning.OperationFuture;
29+
import com.google.cloud.tpu.v2.CreateNodeRequest;
2830
import com.google.cloud.tpu.v2.DeleteNodeRequest;
2931
import com.google.cloud.tpu.v2.GetNodeRequest;
32+
import com.google.cloud.tpu.v2.ListNodesRequest;
3033
import com.google.cloud.tpu.v2.Node;
3134
import com.google.cloud.tpu.v2.TpuClient;
3235
import com.google.cloud.tpu.v2.TpuSettings;
3336
import java.io.ByteArrayOutputStream;
3437
import java.io.IOException;
3538
import java.io.PrintStream;
39+
import java.util.Arrays;
40+
import java.util.List;
3641
import java.util.concurrent.ExecutionException;
3742
import org.junit.jupiter.api.BeforeAll;
3843
import org.junit.jupiter.api.Test;
@@ -47,6 +52,8 @@ public class TpuVmIT {
4752
private static final String PROJECT_ID = "project-id";
4853
private static final String ZONE = "asia-east1-c";
4954
private static final String NODE_NAME = "test-tpu";
55+
private static final String TPU_TYPE = "v2-8";
56+
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1";
5057
private static ByteArrayOutputStream bout;
5158

5259
@BeforeAll
@@ -55,21 +62,45 @@ public static void setUp() {
5562
System.setOut(new PrintStream(bout));
5663
}
5764

65+
@Test
66+
public void testCreateTpuVm() throws Exception {
67+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
68+
Node mockNode = mock(Node.class);
69+
TpuClient mockTpuClient = mock(TpuClient.class);
70+
OperationFuture mockFuture = mock(OperationFuture.class);
71+
72+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
73+
.thenReturn(mockTpuClient);
74+
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
75+
.thenReturn(mockFuture);
76+
when(mockFuture.get()).thenReturn(mockNode);
77+
78+
Node returnedNode = CreateTpuVm.createTpuVm(
79+
PROJECT_ID, ZONE, NODE_NAME,
80+
TPU_TYPE, TPU_SOFTWARE_VERSION);
81+
82+
verify(mockTpuClient, times(1))
83+
.createNodeAsync(any(CreateNodeRequest.class));
84+
verify(mockFuture, times(1)).get();
85+
assertEquals(returnedNode, mockNode);
86+
}
87+
}
88+
5889
@Test
5990
public void testGetTpuVm() throws IOException {
6091
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
6192
Node mockNode = mock(Node.class);
6293
TpuClient mockClient = mock(TpuClient.class);
63-
GetTpuVm mockGetTpuVm = mock(GetTpuVm.class);
6494

6595
mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
6696
when(mockClient.getNode(any(GetNodeRequest.class))).thenReturn(mockNode);
6797

6898
Node returnedNode = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
6999

70-
verify(mockGetTpuVm, times(1))
71-
.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
100+
verify(mockClient, times(1))
101+
.getNode(any(GetNodeRequest.class));
72102
assertThat(returnedNode).isEqualTo(mockNode);
103+
verify(mockClient, times(1)).close();
73104
}
74105
}
75106

@@ -91,4 +122,27 @@ public void testDeleteTpuVm() throws IOException, ExecutionException, Interrupte
91122
verify(mockTpuClient, times(1)).deleteNodeAsync(any(DeleteNodeRequest.class));
92123
}
93124
}
125+
126+
@Test
127+
public void testListTpuVm() throws IOException {
128+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
129+
Node mockNode1 = mock(Node.class);
130+
Node mockNode2 = mock(Node.class);
131+
List<Node> mockListNodes = Arrays.asList(mockNode1, mockNode2);
132+
133+
TpuClient mockTpuClient = mock(TpuClient.class);
134+
mockedTpuClient.when(TpuClient::create).thenReturn(mockTpuClient);
135+
TpuClient.ListNodesPagedResponse mockListNodesResponse =
136+
mock(TpuClient.ListNodesPagedResponse.class);
137+
when(mockTpuClient.listNodes(any(ListNodesRequest.class))).thenReturn(mockListNodesResponse);
138+
TpuClient.ListNodesPage mockListNodesPage = mock(TpuClient.ListNodesPage.class);
139+
when(mockListNodesResponse.getPage()).thenReturn(mockListNodesPage);
140+
when(mockListNodesPage.getValues()).thenReturn(mockListNodes);
141+
142+
TpuClient.ListNodesPage returnedListNodes = ListTpuVms.listTpuVms(PROJECT_ID, ZONE);
143+
144+
assertThat(returnedListNodes.getValues()).isEqualTo(mockListNodes);
145+
verify(mockTpuClient, times(1)).listNodes(any(ListNodesRequest.class));
146+
}
147+
}
94148
}

0 commit comments

Comments
 (0)