1616
1717package tpu ;
1818
19- import static com .google .common .truth .Truth .assertThat ;
20- import static com .google .common .truth .Truth .assertWithMessage ;
21- import static org .junit .Assert .assertNotNull ;
19+ import static org .mockito .Mockito .any ;
20+ import static org .mockito .Mockito .mock ;
21+ import static org .mockito .Mockito .mockStatic ;
22+ import static org .mockito .Mockito .times ;
23+ import static org .mockito .Mockito .verify ;
24+ import static org .mockito .Mockito .when ;
2225
23- import com .google .api .gax .rpc .NotFoundException ;
24- import com .google .cloud .tpu .v2 .Node ;
26+ import com .google .api .gax .longrunning .OperationFuture ;
27+ import com .google .cloud .tpu .v2 .CreateNodeRequest ;
28+ import com .google .cloud .tpu .v2 .DeleteNodeRequest ;
29+ import com .google .cloud .tpu .v2 .LocationName ;
2530import com .google .cloud .tpu .v2 .TpuClient ;
31+ import com .google .cloud .tpu .v2 .TpuSettings ;
2632import java .io .IOException ;
27- import java .util .UUID ;
2833import java .util .concurrent .ExecutionException ;
29- import java .util .concurrent .TimeUnit ;
30- import org .junit .Assert ;
31- import org .junit .jupiter .api .AfterAll ;
32- import org .junit .jupiter .api .Assertions ;
33- import org .junit .jupiter .api .BeforeAll ;
34- import org .junit .jupiter .api .MethodOrderer ;
35- import org .junit .jupiter .api .Order ;
36- import org .junit .jupiter .api .Test ;
37- import org .junit .jupiter .api .TestMethodOrder ;
34+ import org .junit .Test ;
3835import org .junit .jupiter .api .Timeout ;
3936import org .junit .runner .RunWith ;
4037import org .junit .runners .JUnit4 ;
38+ import org .mockito .MockedStatic ;
4139
4240@ RunWith (JUnit4 .class )
43- @ Timeout (value = 15 , unit = TimeUnit .MINUTES )
44- @ TestMethodOrder (MethodOrderer .OrderAnnotation .class )
41+ @ Timeout (value = 3 )
4542public class TpuVmIT {
46- private static final String PROJECT_ID = System . getenv ( "GOOGLE_CLOUD_PROJECT" ) ;
43+ private static final String PROJECT_ID = "project-id" ;
4744 private static final String ZONE = "asia-east1-c" ;
48- private static final String NODE_NAME = "test-tpu-" + UUID . randomUUID () ;
45+ private static final String NODE_NAME = "test-tpu" ;
4946 private static final String TPU_TYPE = "v2-8" ;
5047 private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1" ;
51- private static final String NODE_PATH_NAME =
52- String .format ("projects/%s/locations/%s/nodes/%s" , PROJECT_ID , ZONE , NODE_NAME );
53-
54- public static void requireEnvVar (String envVarName ) {
55- assertWithMessage (String .format ("Missing environment variable '%s' " , envVarName ))
56- .that (System .getenv (envVarName )).isNotEmpty ();
57- }
58-
59- @ BeforeAll
60- public static void setUp () {
61- requireEnvVar ("GOOGLE_APPLICATION_CREDENTIALS" );
62- requireEnvVar ("GOOGLE_CLOUD_PROJECT" );
63- }
64-
65- @ AfterAll
66- public static void cleanup () throws Exception {
67- DeleteTpuVm .deleteTpuVm (PROJECT_ID , ZONE , NODE_NAME );
68-
69- // Test that TPUs is deleted
70- Assertions .assertThrows (
71- NotFoundException .class ,
72- () -> GetTpuVm .getTpuVm (PROJECT_ID , ZONE , NODE_NAME ));
73- }
7448
7549 @ Test
76- @ Order (1 )
77- public void testCreateTpuVm () throws IOException , ExecutionException , InterruptedException {
50+ public void testCreateTpuVm () throws Exception {
51+ try (MockedStatic <TpuClient > mockedTpuClient = mockStatic (TpuClient .class )) {
52+ TpuClient mockTpuClient = mock (TpuClient .class );
53+ mockedTpuClient .when (() -> TpuClient .create (any (TpuSettings .class )))
54+ .thenReturn (mockTpuClient );
7855
79- Node node = CreateTpuVm .createTpuVm (
80- PROJECT_ID , ZONE , NODE_NAME , TPU_TYPE , TPU_SOFTWARE_VERSION );
56+ OperationFuture mockFuture = mock (OperationFuture .class );
57+ when (mockTpuClient .createNodeAsync (any (CreateNodeRequest .class )))
58+ .thenReturn (mockFuture );
59+ CreateTpuVm .createTpuVm (
60+ PROJECT_ID , ZONE , NODE_NAME ,
61+ TPU_TYPE , TPU_SOFTWARE_VERSION );
8162
82- assertNotNull (node );
83- assertThat (node .getName ().equals (NODE_NAME ));
84- assertThat (node .getAcceleratorType ().equals (TPU_TYPE ));
63+ verify (mockTpuClient , times (1 )).createNodeAsync (any (CreateNodeRequest .class ));
64+ }
8565 }
8666
8767 @ Test
88- @ Order (2 )
89- public void testGetTpuVm () throws IOException {
90- Node node = GetTpuVm .getTpuVm (PROJECT_ID , ZONE , NODE_NAME );
68+ public void testDeleteTpuVm () throws IOException , ExecutionException , InterruptedException {
69+ try (MockedStatic <TpuClient > mockedTpuClient = mockStatic (TpuClient .class )) {
70+ TpuClient mockTpuClient = mock (TpuClient .class );
71+ mockedTpuClient .when (() -> TpuClient .create (any (TpuSettings .class )))
72+ .thenReturn (mockTpuClient );
73+
74+ OperationFuture mockFuture = mock (OperationFuture .class );
75+ when (mockTpuClient .deleteNodeAsync (any (DeleteNodeRequest .class )))
76+ .thenReturn (mockFuture );
77+ DeleteTpuVm .deleteTpuVm (PROJECT_ID , ZONE , NODE_NAME );
9178
92- assertNotNull ( node );
93- assertThat ( node . getName ()). isEqualTo ( NODE_PATH_NAME );
79+ verify ( mockTpuClient , times ( 1 )). deleteNodeAsync ( any ( DeleteNodeRequest . class ) );
80+ }
9481 }
9582
9683 @ Test
97- @ Order (2 )
9884 public void testListTpuVm () throws IOException {
99- TpuClient .ListNodesPagedResponse nodesList = ListTpuVms .listTpuVms (PROJECT_ID , ZONE );
85+ try (MockedStatic <TpuClient > mockedTpuClient = mockStatic (TpuClient .class )) {
86+ TpuClient .ListNodesPagedResponse mockListNodes = mock (TpuClient .ListNodesPagedResponse .class );
87+ mockedTpuClient .when (TpuClient ::create ).thenReturn (mock (TpuClient .class ));
88+ when (mock (TpuClient .class ).listNodes (any (LocationName .class ))).thenReturn (mockListNodes );
89+ ListTpuVms mockListTpuVms = mock (ListTpuVms .class );
90+
91+ ListTpuVms .listTpuVms (PROJECT_ID , ZONE );
10092
101- assertNotNull (nodesList );
102- for (Node node : nodesList .iterateAll ()) {
103- Assert .assertTrue (node .getName ().contains ("test-tpu" ));
93+ verify (mockListTpuVms , times (1 )).listTpuVms (PROJECT_ID , ZONE );
10494 }
10595 }
106- }
96+ }
0 commit comments