1616
1717package tpu ;
1818
19- import static com .google .common .truth .Truth .assertThat ;
20- import static com .google .common .truth .Truth .assertWithMessage ;
21- import static org .junit .Assert .assertNotNull ;
19+ import static org .mockito .Mockito .any ;
20+ import static org .mockito .Mockito .mock ;
21+ import static org .mockito .Mockito .times ;
22+ import static org .mockito .Mockito .verify ;
23+ import static org .mockito .Mockito .when ;
2224
23- import com .google .api .gax .rpc .NotFoundException ;
25+ import com .google .api .gax .longrunning .OperationFuture ;
26+ import com .google .cloud .tpu .v2 .CreateNodeRequest ;
27+ import com .google .cloud .tpu .v2 .DeleteNodeRequest ;
2428import com .google .cloud .tpu .v2 .Node ;
29+ import com .google .cloud .tpu .v2 .NodeName ;
30+ import com .google .cloud .tpu .v2 .TpuClient ;
31+ import com .google .cloud .tpu .v2 .TpuSettings ;
2532import java .io .IOException ;
26- import java .util .UUID ;
2733import java .util .concurrent .ExecutionException ;
2834import java .util .concurrent .TimeUnit ;
29- import org .junit .jupiter .api .AfterAll ;
30- import org .junit .jupiter .api .Assertions ;
31- import org .junit .jupiter .api .BeforeAll ;
3235import org .junit .jupiter .api .MethodOrderer ;
3336import org .junit .jupiter .api .Order ;
3437import org .junit .jupiter .api .Test ;
3538import org .junit .jupiter .api .TestMethodOrder ;
3639import org .junit .jupiter .api .Timeout ;
3740import org .junit .runner .RunWith ;
38- import org .junit .runners .JUnit4 ;
41+ import org .mockito .MockedStatic ;
42+ import org .mockito .Mockito ;
43+ import org .powermock .modules .junit4 .PowerMockRunner ;
3944
40- @ RunWith (JUnit4 .class )
41- @ Timeout (value = 15 , unit = TimeUnit .MINUTES )
45+ @ RunWith (PowerMockRunner .class )
46+ @ Timeout (value = 3 , unit = TimeUnit .MINUTES )
4247@ TestMethodOrder (MethodOrderer .OrderAnnotation .class )
4348public class TpuVmIT {
44- private static final String PROJECT_ID = System . getenv ( "GOOGLE_CLOUD_PROJECT" ) ;
49+ private static final String PROJECT_ID = "project-id" ;
4550 private static final String ZONE = "asia-east1-c" ;
46- private static final String NODE_NAME = "test-tpu-" + UUID . randomUUID () ;
51+ private static final String NODE_NAME = "test-tpu" ;
4752 private static final String TPU_TYPE = "v2-8" ;
4853 private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1" ;
49- private static final String NODE_PATH_NAME =
50- String .format ("projects/%s/locations/%s/nodes/%s" , PROJECT_ID , ZONE , NODE_NAME );
5154
52- public static void requireEnvVar (String envVarName ) {
53- assertWithMessage (String .format ("Missing environment variable '%s' " , envVarName ))
54- .that (System .getenv (envVarName )).isNotEmpty ();
55- }
56-
57- @ BeforeAll
58- public static void setUp () {
59- requireEnvVar ("GOOGLE_APPLICATION_CREDENTIALS" );
60- requireEnvVar ("GOOGLE_CLOUD_PROJECT" );
61- }
55+ @ Test
56+ @ Order (1 )
57+ public void testCreateTpuVm () throws IOException , ExecutionException , InterruptedException {
58+ TpuClient mockTpuClient = mock (TpuClient .class );
59+ try (MockedStatic <TpuClient > mockedTpuClient = Mockito .mockStatic (TpuClient .class )) {
60+ mockedTpuClient .when (() -> TpuClient .create (any (TpuSettings .class )))
61+ .thenReturn (mockTpuClient );
6262
63- @ AfterAll
64- public static void cleanup () throws Exception {
65- DeleteTpuVm .deleteTpuVm (PROJECT_ID , ZONE , NODE_NAME );
63+ OperationFuture mockFuture = mock (OperationFuture .class );
64+ when (mockTpuClient .createNodeAsync (any (CreateNodeRequest .class )))
65+ .thenReturn (mockFuture );
66+ CreateTpuVm .createTpuVm (PROJECT_ID , ZONE , NODE_NAME , TPU_TYPE , TPU_SOFTWARE_VERSION );
6667
67- // Test that TPUs is deleted
68- Assertions .assertThrows (
69- NotFoundException .class ,
70- () -> GetTpuVm .getTpuVm (PROJECT_ID , ZONE , NODE_NAME ));
68+ verify (mockTpuClient , times (1 )).createNodeAsync (any (CreateNodeRequest .class ));
69+ }
7170 }
7271
7372 @ Test
74- @ Order (1 )
75- public void testCreateTpuVm () throws IOException , ExecutionException , InterruptedException {
73+ public void testGetTpuVm () throws IOException {
74+ GetTpuVm mockGetTpuVm = mock (GetTpuVm .class );
75+ Node mockNode = mock (Node .class );
76+ try (MockedStatic <TpuClient > mockedTpuClient = Mockito .mockStatic (TpuClient .class )) {
77+ mockedTpuClient .when (TpuClient ::create ).thenReturn (mock (TpuClient .class ));
78+ when (mock (TpuClient .class ).getNode (any (NodeName .class ))).thenReturn (mockNode );
7679
77- Node node = CreateTpuVm .createTpuVm (
78- PROJECT_ID , ZONE , NODE_NAME , TPU_TYPE , TPU_SOFTWARE_VERSION );
80+ GetTpuVm .getTpuVm (PROJECT_ID , ZONE , NODE_NAME );
7981
80- assertNotNull (node );
81- assertThat (node .getName ().equals (NODE_NAME ));
82- assertThat (node .getAcceleratorType ().equals (TPU_TYPE ));
82+ // Assertions
83+ verify (mockGetTpuVm , times (1 ))
84+ .getTpuVm (PROJECT_ID , ZONE , NODE_NAME );
85+ }
8386 }
8487
8588 @ Test
86- @ Order (2 )
87- public void testGetTpuVm () throws IOException {
88- Node node = GetTpuVm .getTpuVm (PROJECT_ID , ZONE , NODE_NAME );
89+ public void testDeleteTpuVm () throws IOException , ExecutionException , InterruptedException {
90+ TpuClient mockTpuClient = mock (TpuClient .class );
91+ try (MockedStatic <TpuClient > mockedTpuClient = Mockito .mockStatic (TpuClient .class )) {
92+ mockedTpuClient .when (() -> TpuClient .create (any (TpuSettings .class )))
93+ .thenReturn (mockTpuClient );
94+
95+ OperationFuture mockFuture = mock (OperationFuture .class );
96+ when (mockTpuClient .deleteNodeAsync (any (DeleteNodeRequest .class )))
97+ .thenReturn (mockFuture );
98+ DeleteTpuVm .deleteTpuVm (PROJECT_ID , ZONE , NODE_NAME );
8999
90- assertNotNull ( node );
91- assertThat ( node . getName ()). isEqualTo ( NODE_PATH_NAME );
100+ verify ( mockTpuClient , times ( 1 )). deleteNodeAsync ( any ( DeleteNodeRequest . class ) );
101+ }
92102 }
93103}
0 commit comments