1717package tpu ;
1818
1919import static com .google .common .truth .Truth .assertThat ;
20- import static com .google .common .truth .Truth .assertWithMessage ;
21- import static org .junit .Assert .assertNotNull ;
20+ import static org .junit .Assert .assertEquals ;
21+ import static org .mockito .Mockito .any ;
22+ import static org .mockito .Mockito .mock ;
23+ import static org .mockito .Mockito .mockStatic ;
24+ import static org .mockito .Mockito .times ;
25+ import static org .mockito .Mockito .verify ;
26+ import static org .mockito .Mockito .when ;
2227
23- import com .google .api .gax .rpc .NotFoundException ;
28+ import com .google .api .gax .longrunning .OperationFuture ;
29+ import com .google .cloud .tpu .v2 .AcceleratorConfig ;
30+ import com .google .cloud .tpu .v2 .CreateNodeRequest ;
31+ import com .google .cloud .tpu .v2 .DeleteNodeRequest ;
32+ import com .google .cloud .tpu .v2 .GetNodeRequest ;
2433import com .google .cloud .tpu .v2 .Node ;
34+ import com .google .cloud .tpu .v2 .TpuClient ;
35+ import com .google .cloud .tpu .v2 .TpuSettings ;
36+ import java .io .ByteArrayOutputStream ;
2537import java .io .IOException ;
26- import java .util . UUID ;
38+ import java .io . PrintStream ;
2739import java .util .concurrent .ExecutionException ;
28- import java .util .concurrent .TimeUnit ;
29- import org .junit .jupiter .api .AfterAll ;
30- import org .junit .jupiter .api .Assertions ;
3140import org .junit .jupiter .api .BeforeAll ;
32- import org .junit .jupiter .api .MethodOrderer ;
33- import org .junit .jupiter .api .Order ;
3441import org .junit .jupiter .api .Test ;
35- import org .junit .jupiter .api .TestMethodOrder ;
3642import org .junit .jupiter .api .Timeout ;
3743import org .junit .runner .RunWith ;
3844import org .junit .runners .JUnit4 ;
45+ import org .mockito .MockedStatic ;
3946
4047@ RunWith (JUnit4 .class )
41- @ Timeout (value = 15 , unit = TimeUnit .MINUTES )
42- @ TestMethodOrder (MethodOrderer .OrderAnnotation .class )
48+ @ Timeout (value = 3 )
4349public class TpuVmIT {
44- private static final String PROJECT_ID = System . getenv ( "GOOGLE_CLOUD_PROJECT" ) ;
50+ private static final String PROJECT_ID = "project-id" ;
4551 private static final String ZONE = "asia-east1-c" ;
46- private static final String NODE_NAME = "test-tpu-" + UUID . randomUUID () ;
52+ private static final String NODE_NAME = "test-tpu" ;
4753 private static final String TPU_TYPE = "v2-8" ;
48- private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1" ;
49- private static final String NODE_PATH_NAME =
50- String .format ("projects/%s/locations/%s/nodes/%s" , PROJECT_ID , ZONE , NODE_NAME );
51-
52- public static void requireEnvVar (String envVarName ) {
53- assertWithMessage (String .format ("Missing environment variable '%s' " , envVarName ))
54- .that (System .getenv (envVarName )).isNotEmpty ();
55- }
54+ private static final AcceleratorConfig .Type ACCELERATOR_TYPE = AcceleratorConfig .Type .V2 ;
55+ private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1" ;
56+ private static final String TOPOLOGY = "2x2" ;
57+ private static ByteArrayOutputStream bout ;
5658
5759 @ BeforeAll
5860 public static void setUp () {
59- requireEnvVar ( "GOOGLE_APPLICATION_CREDENTIALS" );
60- requireEnvVar ( "GOOGLE_CLOUD_PROJECT" );
61+ bout = new ByteArrayOutputStream ( );
62+ System . setOut ( new PrintStream ( bout ) );
6163 }
6264
63- @ AfterAll
64- public static void cleanup () throws Exception {
65- DeleteTpuVm .deleteTpuVm (PROJECT_ID , ZONE , NODE_NAME );
65+ @ Test
66+ public void testCreateTpuVm () throws Exception {
67+ try (MockedStatic <TpuClient > mockedTpuClient = mockStatic (TpuClient .class )) {
68+ Node mockNode = mock (Node .class );
69+ TpuClient mockTpuClient = mock (TpuClient .class );
70+ OperationFuture mockFuture = mock (OperationFuture .class );
71+
72+ mockedTpuClient .when (() -> TpuClient .create (any (TpuSettings .class )))
73+ .thenReturn (mockTpuClient );
74+ when (mockTpuClient .createNodeAsync (any (CreateNodeRequest .class )))
75+ .thenReturn (mockFuture );
76+ when (mockFuture .get ()).thenReturn (mockNode );
77+
78+ Node returnedNode = CreateTpuVm .createTpuVm (
79+ PROJECT_ID , ZONE , NODE_NAME ,
80+ TPU_TYPE , TPU_SOFTWARE_VERSION );
6681
67- // Test that TPUs is deleted
68- Assertions .assertThrows (
69- NotFoundException .class ,
70- () -> GetTpuVm .getTpuVm (PROJECT_ID , ZONE , NODE_NAME ));
82+ verify (mockTpuClient , times (1 ))
83+ .createNodeAsync (any (CreateNodeRequest .class ));
84+ verify (mockFuture , times (1 )).get ();
85+ assertEquals (returnedNode , mockNode );
86+ }
7187 }
7288
7389 @ Test
74- @ Order (1 )
75- public void testCreateTpuVm () throws IOException , ExecutionException , InterruptedException {
76- Node node = CreateTpuVm .createTpuVm (
77- PROJECT_ID , ZONE , NODE_NAME , TPU_TYPE , TPU_SOFTWARE_VERSION );
78-
79- assertNotNull (node );
80- assertThat (node .getName ().equals (NODE_NAME ));
81- assertThat (node .getAcceleratorType ().equals (TPU_TYPE ));
90+ public void testGetTpuVm () throws IOException {
91+ try (MockedStatic <TpuClient > mockedTpuClient = mockStatic (TpuClient .class )) {
92+ Node mockNode = mock (Node .class );
93+ TpuClient mockClient = mock (TpuClient .class );
94+
95+ mockedTpuClient .when (TpuClient ::create ).thenReturn (mockClient );
96+ when (mockClient .getNode (any (GetNodeRequest .class ))).thenReturn (mockNode );
97+
98+ Node returnedNode = GetTpuVm .getTpuVm (PROJECT_ID , ZONE , NODE_NAME );
99+
100+ verify (mockClient , times (1 ))
101+ .getNode (any (GetNodeRequest .class ));
102+ assertThat (returnedNode ).isEqualTo (mockNode );
103+ verify (mockClient , times (1 )).close ();
104+ }
82105 }
83106
84107 @ Test
85- @ Order (2 )
86- public void testGetTpuVm () throws IOException {
87- Node node = GetTpuVm .getTpuVm (PROJECT_ID , ZONE , NODE_NAME );
108+ public void testDeleteTpuVm () throws IOException , ExecutionException , InterruptedException {
109+ try (MockedStatic <TpuClient > mockedTpuClient = mockStatic (TpuClient .class )) {
110+ TpuClient mockTpuClient = mock (TpuClient .class );
111+ OperationFuture mockFuture = mock (OperationFuture .class );
112+
113+ mockedTpuClient .when (() -> TpuClient .create (any (TpuSettings .class )))
114+ .thenReturn (mockTpuClient );
115+ when (mockTpuClient .deleteNodeAsync (any (DeleteNodeRequest .class )))
116+ .thenReturn (mockFuture );
117+
118+ DeleteTpuVm .deleteTpuVm (PROJECT_ID , ZONE , NODE_NAME );
119+ String output = bout .toString ();
120+
121+ assertThat (output ).contains ("TPU VM deleted" );
122+ verify (mockTpuClient , times (1 )).deleteNodeAsync (any (DeleteNodeRequest .class ));
123+ }
124+ }
125+
126+ @ Test
127+ public void testCreateTpuVmWithTopologyFlag ()
128+ throws IOException , ExecutionException , InterruptedException {
129+ try (MockedStatic <TpuClient > mockedTpuClient = mockStatic (TpuClient .class )) {
130+ Node mockNode = mock (Node .class );
131+ TpuClient mockTpuClient = mock (TpuClient .class );
132+ OperationFuture mockFuture = mock (OperationFuture .class );
133+
134+ mockedTpuClient .when (() -> TpuClient .create (any (TpuSettings .class )))
135+ .thenReturn (mockTpuClient );
136+ when (mockTpuClient .createNodeAsync (any (CreateNodeRequest .class )))
137+ .thenReturn (mockFuture );
138+ when (mockFuture .get ()).thenReturn (mockNode );
139+ Node returnedNode = CreateTpuWithTopologyFlag .createTpuWithTopologyFlag (
140+ PROJECT_ID , ZONE , NODE_NAME , ACCELERATOR_TYPE ,
141+ TPU_SOFTWARE_VERSION , TOPOLOGY );
88142
89- assertNotNull (node );
90- assertThat (node .getName ()).isEqualTo (NODE_PATH_NAME );
143+ verify (mockTpuClient , times (1 ))
144+ .createNodeAsync (any (CreateNodeRequest .class ));
145+ verify (mockFuture , times (1 )).get ();
146+ assertEquals (returnedNode , mockNode );
147+ }
91148 }
92149}
0 commit comments