|
17 | 17 | package tpu; |
18 | 18 |
|
19 | 19 | import static com.google.common.truth.Truth.assertThat; |
| 20 | +import static org.junit.Assert.assertEquals; |
20 | 21 | import static org.mockito.Mockito.any; |
21 | 22 | import static org.mockito.Mockito.mock; |
22 | 23 | import static org.mockito.Mockito.mockStatic; |
|
27 | 28 | import com.google.api.gax.longrunning.OperationFuture; |
28 | 29 | import com.google.cloud.tpu.v2alpha1.CreateQueuedResourceRequest; |
29 | 30 | import com.google.cloud.tpu.v2alpha1.DeleteQueuedResourceRequest; |
| 31 | +import com.google.cloud.tpu.v2alpha1.GetQueuedResourceRequest; |
30 | 32 | import com.google.cloud.tpu.v2alpha1.QueuedResource; |
31 | | -import com.google.cloud.tpu.v2alpha1.QueuedResourceName; |
32 | 33 | import com.google.cloud.tpu.v2alpha1.TpuClient; |
33 | 34 | import com.google.cloud.tpu.v2alpha1.TpuSettings; |
34 | 35 | import java.io.ByteArrayOutputStream; |
@@ -62,48 +63,59 @@ public void setUp() { |
62 | 63 | @Test |
63 | 64 | public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception { |
64 | 65 | try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) { |
| 66 | + QueuedResource mockQueuedResource = mock(QueuedResource.class); |
65 | 67 | TpuClient mockTpuClient = mock(TpuClient.class); |
| 68 | + OperationFuture mockFuture = mock(OperationFuture.class); |
| 69 | + |
66 | 70 | mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class))) |
67 | 71 | .thenReturn(mockTpuClient); |
68 | | - |
69 | | - OperationFuture mockFuture = mock(OperationFuture.class); |
70 | 72 | when(mockTpuClient.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class))) |
71 | 73 | .thenReturn(mockFuture); |
72 | | - CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork( |
| 74 | + when(mockFuture.get()).thenReturn(mockQueuedResource); |
| 75 | + |
| 76 | + QueuedResource returnedQueuedResource = |
| 77 | + CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork( |
73 | 78 | PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME, |
74 | 79 | TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME); |
75 | 80 |
|
76 | 81 | verify(mockTpuClient, times(1)) |
77 | 82 | .createQueuedResourceAsync(any(CreateQueuedResourceRequest.class)); |
| 83 | + verify(mockFuture, times(1)).get(); |
| 84 | + assertEquals(returnedQueuedResource, mockQueuedResource); |
78 | 85 | } |
79 | 86 | } |
80 | 87 |
|
81 | 88 | @Test |
82 | 89 | public void testGetQueuedResource() throws IOException { |
83 | 90 | try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) { |
84 | | - QueuedResource mockQueuedResource = mock(QueuedResource.class); |
85 | | - mockedTpuClient.when(TpuClient::create).thenReturn(mock(TpuClient.class)); |
86 | | - when(mock(TpuClient.class) |
87 | | - .getQueuedResource(any(QueuedResourceName.class))).thenReturn(mockQueuedResource); |
| 91 | + TpuClient mockClient = mock(TpuClient.class); |
88 | 92 | GetQueuedResource mockGetQueuedResource = mock(GetQueuedResource.class); |
| 93 | + QueuedResource mockQueuedResource = mock(QueuedResource.class); |
| 94 | + |
| 95 | + mockedTpuClient.when(TpuClient::create).thenReturn(mockClient); |
| 96 | + when(mockClient.getQueuedResource(any(GetQueuedResourceRequest.class))) |
| 97 | + .thenReturn(mockQueuedResource); |
89 | 98 |
|
90 | | - GetQueuedResource.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME); |
| 99 | + QueuedResource returnedQueuedResource = |
| 100 | + GetQueuedResource.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME); |
91 | 101 |
|
92 | 102 | verify(mockGetQueuedResource, times(1)) |
93 | 103 | .getQueuedResource(PROJECT_ID, ZONE, NODE_NAME); |
| 104 | + assertEquals(returnedQueuedResource, mockQueuedResource); |
94 | 105 | } |
95 | 106 | } |
96 | 107 |
|
97 | 108 | @Test |
98 | 109 | public void testDeleteTpuVm() { |
99 | 110 | try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) { |
100 | 111 | TpuClient mockTpuClient = mock(TpuClient.class); |
| 112 | + OperationFuture mockFuture = mock(OperationFuture.class); |
| 113 | + |
101 | 114 | mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class))) |
102 | 115 | .thenReturn(mockTpuClient); |
103 | | - |
104 | | - OperationFuture mockFuture = mock(OperationFuture.class); |
105 | 116 | when(mockTpuClient.deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class))) |
106 | 117 | .thenReturn(mockFuture); |
| 118 | + |
107 | 119 | DeleteForceQueuedResource.deleteForceQueuedResource(PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME); |
108 | 120 | String output = bout.toString(); |
109 | 121 |
|
|
0 commit comments