Skip to content

Commit 16f9356

Browse files
Fixed test
1 parent 428a2ef commit 16f9356

File tree

3 files changed

+28
-31
lines changed

3 files changed

+28
-31
lines changed

tpu/src/main/java/tpu/ListTpuVms.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
//[START tpu_vm_list]
2020
import com.google.cloud.tpu.v2.ListNodesRequest;
2121
import com.google.cloud.tpu.v2.TpuClient;
22-
import com.google.cloud.tpu.v2.TpuClient.ListNodesPagedResponse;
2322
import java.io.IOException;
2423

2524
public class ListTpuVms {
@@ -37,15 +36,16 @@ public static void main(String[] args) throws IOException {
3736
}
3837

3938
// Lists TPU VMs in the specified zone.
40-
public static ListNodesPagedResponse listTpuVms(String projectId, String zone)
39+
public static TpuClient.ListNodesPage listTpuVms(String projectId, String zone)
4140
throws IOException {
4241
// Initialize client that will be used to send requests. This client only needs to be created
4342
// once, and can be reused for multiple requests.
4443
try (TpuClient tpuClient = TpuClient.create()) {
4544
String parent = String.format("projects/%s/locations/%s", projectId, zone);
4645

4746
ListNodesRequest request = ListNodesRequest.newBuilder().setParent(parent).build();
48-
return tpuClient.listNodes(request);
47+
48+
return tpuClient.listNodes(request).getPage();
4949
}
5050
}
5151
}

tpu/src/test/java/tpu/GetTpuVmIT.java renamed to tpu/src/test/java/tpu/ListTpuVmsIT.java

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,41 +16,53 @@
1616

1717
package tpu;
1818

19+
import static com.google.common.truth.Truth.assertThat;
1920
import static org.mockito.Mockito.any;
2021
import static org.mockito.Mockito.mock;
2122
import static org.mockito.Mockito.mockStatic;
2223
import static org.mockito.Mockito.times;
2324
import static org.mockito.Mockito.verify;
2425
import static org.mockito.Mockito.when;
2526

27+
import com.google.cloud.tpu.v2.ListNodesRequest;
2628
import com.google.cloud.tpu.v2.Node;
27-
import com.google.cloud.tpu.v2.NodeName;
2829
import com.google.cloud.tpu.v2.TpuClient;
2930
import java.io.IOException;
30-
import org.junit.jupiter.api.Test;
31+
import java.util.Arrays;
32+
import java.util.List;
33+
import org.junit.Test;
3134
import org.junit.jupiter.api.Timeout;
3235
import org.junit.runner.RunWith;
3336
import org.junit.runners.JUnit4;
3437
import org.mockito.MockedStatic;
3538

3639
@RunWith(JUnit4.class)
37-
@Timeout(value = 5)
38-
public class GetTpuVmIT {
40+
@Timeout(value = 3)
41+
public class ListTpuVmsIT {
3942
private static final String PROJECT_ID = "project-id";
4043
private static final String ZONE = "asia-east1-c";
41-
private static final String NODE_NAME = "test-tpu";
4244

4345
@Test
44-
public void testGetTpuVm() throws IOException {
46+
public void testListTpuVm() throws IOException {
4547
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
46-
Node mockNode = mock(Node.class);
47-
mockedTpuClient.when(TpuClient::create).thenReturn(mock(TpuClient.class));
48-
when(mock(TpuClient.class).getNode(any(NodeName.class))).thenReturn(mockNode);
49-
GetTpuVm mockGetTpuVm = mock(GetTpuVm.class);
48+
Node mockNode1 = mock(Node.class);
49+
Node mockNode2 = mock(Node.class);
50+
List<Node> mockListNodes = Arrays.asList(mockNode1, mockNode2);
5051

51-
GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
52+
TpuClient mockTpuClient = mock(TpuClient.class);
53+
mockedTpuClient.when(TpuClient::create).thenReturn(mockTpuClient);
54+
TpuClient.ListNodesPagedResponse mockListNodesResponse =
55+
mock(TpuClient.ListNodesPagedResponse.class);
56+
when(mockTpuClient.listNodes(any(ListNodesRequest.class))).thenReturn(mockListNodesResponse);
57+
TpuClient.ListNodesPage mockListNodesPage = mock(TpuClient.ListNodesPage.class);
58+
when(mockListNodesResponse.getPage()).thenReturn(mockListNodesPage);
59+
when(mockListNodesPage.getValues()).thenReturn(mockListNodes);
5260

53-
verify(mockGetTpuVm, times(1)).getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
61+
TpuClient.ListNodesPage returnedListNodes = ListTpuVms.listTpuVms(PROJECT_ID, ZONE);
62+
63+
assertThat(returnedListNodes.getValues()).isEqualTo(mockListNodes);
64+
verify(mockTpuClient, times(1)).listNodes(any(ListNodesRequest.class));
5465
}
5566
}
56-
}
67+
}
68+

tpu/src/test/java/tpu/TpuVmIT.java

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import com.google.api.gax.longrunning.OperationFuture;
2828
import com.google.cloud.tpu.v2.DeleteNodeRequest;
2929
import com.google.cloud.tpu.v2.GetNodeRequest;
30-
import com.google.cloud.tpu.v2.LocationName;
3130
import com.google.cloud.tpu.v2.Node;
3231
import com.google.cloud.tpu.v2.TpuClient;
3332
import com.google.cloud.tpu.v2.TpuSettings;
@@ -92,18 +91,4 @@ public void testDeleteTpuVm() throws IOException, ExecutionException, Interrupte
9291
verify(mockTpuClient, times(1)).deleteNodeAsync(any(DeleteNodeRequest.class));
9392
}
9493
}
95-
96-
@Test
97-
public void testListTpuVm() throws IOException {
98-
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
99-
TpuClient.ListNodesPagedResponse mockListNodes = mock(TpuClient.ListNodesPagedResponse.class);
100-
mockedTpuClient.when(TpuClient::create).thenReturn(mock(TpuClient.class));
101-
when(mock(TpuClient.class).listNodes(any(LocationName.class))).thenReturn(mockListNodes);
102-
ListTpuVms mockListTpuVms = mock(ListTpuVms.class);
103-
104-
ListTpuVms.listTpuVms(PROJECT_ID, ZONE);
105-
106-
verify(mockListTpuVms, times(1)).listTpuVms(PROJECT_ID, ZONE);
107-
}
108-
}
10994
}

0 commit comments

Comments
 (0)