1717package  tpu ;
1818
1919import  static  com .google .common .truth .Truth .assertThat ;
20+ import  static  org .junit .Assert .assertEquals ;
2021import  static  org .mockito .Mockito .any ;
2122import  static  org .mockito .Mockito .mock ;
2223import  static  org .mockito .Mockito .mockStatic ;
2526import  static  org .mockito .Mockito .when ;
2627
2728import  com .google .api .gax .longrunning .OperationFuture ;
29+ import  com .google .cloud .tpu .v2 .CreateNodeRequest ;
2830import  com .google .cloud .tpu .v2 .DeleteNodeRequest ;
2931import  com .google .cloud .tpu .v2 .GetNodeRequest ;
32+ import  com .google .cloud .tpu .v2 .ListNodesRequest ;
3033import  com .google .cloud .tpu .v2 .Node ;
3134import  com .google .cloud .tpu .v2 .TpuClient ;
3235import  com .google .cloud .tpu .v2 .TpuSettings ;
3336import  java .io .ByteArrayOutputStream ;
3437import  java .io .IOException ;
3538import  java .io .PrintStream ;
39+ import  java .util .Arrays ;
40+ import  java .util .List ;
3641import  java .util .concurrent .ExecutionException ;
3742import  org .junit .jupiter .api .BeforeAll ;
3843import  org .junit .jupiter .api .Test ;
@@ -47,6 +52,8 @@ public class TpuVmIT {
4752  private  static  final  String  PROJECT_ID  = "project-id" ;
4853  private  static  final  String  ZONE  = "asia-east1-c" ;
4954  private  static  final  String  NODE_NAME  = "test-tpu" ;
55+   private  static  final  String  TPU_TYPE  = "v2-8" ;
56+   private  static  final  String  TPU_SOFTWARE_VERSION  = "tpu-vm-tf-2.12.1" ;
5057  private  static  ByteArrayOutputStream  bout ;
5158
5259  @ BeforeAll 
@@ -55,21 +62,45 @@ public static void setUp() {
5562    System .setOut (new  PrintStream (bout ));
5663  }
5764
65+   @ Test 
66+   public  void  testCreateTpuVm () throws  Exception  {
67+     try  (MockedStatic <TpuClient > mockedTpuClient  = mockStatic (TpuClient .class )) {
68+       Node  mockNode  = mock (Node .class );
69+       TpuClient  mockTpuClient  = mock (TpuClient .class );
70+       OperationFuture  mockFuture  = mock (OperationFuture .class );
71+ 
72+       mockedTpuClient .when (() -> TpuClient .create (any (TpuSettings .class )))
73+           .thenReturn (mockTpuClient );
74+       when (mockTpuClient .createNodeAsync (any (CreateNodeRequest .class )))
75+           .thenReturn (mockFuture );
76+       when (mockFuture .get ()).thenReturn (mockNode );
77+ 
78+       Node  returnedNode  = CreateTpuVm .createTpuVm (
79+           PROJECT_ID , ZONE , NODE_NAME ,
80+           TPU_TYPE , TPU_SOFTWARE_VERSION );
81+ 
82+       verify (mockTpuClient , times (1 ))
83+           .createNodeAsync (any (CreateNodeRequest .class ));
84+       verify (mockFuture , times (1 )).get ();
85+       assertEquals (returnedNode , mockNode );
86+     }
87+   }
88+ 
5889  @ Test 
5990  public  void  testGetTpuVm () throws  IOException  {
6091    try  (MockedStatic <TpuClient > mockedTpuClient  = mockStatic (TpuClient .class )) {
6192      Node  mockNode  = mock (Node .class );
6293      TpuClient  mockClient  = mock (TpuClient .class );
63-       GetTpuVm  mockGetTpuVm  = mock (GetTpuVm .class );
6494
6595      mockedTpuClient .when (TpuClient ::create ).thenReturn (mockClient );
6696      when (mockClient .getNode (any (GetNodeRequest .class ))).thenReturn (mockNode );
6797
6898      Node  returnedNode  = GetTpuVm .getTpuVm (PROJECT_ID , ZONE , NODE_NAME );
6999
70-       verify (mockGetTpuVm , times (1 ))
71-           .getTpuVm ( PROJECT_ID ,  ZONE ,  NODE_NAME );
100+       verify (mockClient , times (1 ))
101+           .getNode ( any ( GetNodeRequest . class ) );
72102      assertThat (returnedNode ).isEqualTo (mockNode );
103+       verify (mockClient , times (1 )).close ();
73104    }
74105  }
75106
@@ -91,4 +122,27 @@ public void testDeleteTpuVm() throws IOException, ExecutionException, Interrupte
91122      verify (mockTpuClient , times (1 )).deleteNodeAsync (any (DeleteNodeRequest .class ));
92123    }
93124  }
125+ 
126+   @ Test 
127+   public  void  testListTpuVm () throws  IOException  {
128+     try  (MockedStatic <TpuClient > mockedTpuClient  = mockStatic (TpuClient .class )) {
129+       Node  mockNode1  = mock (Node .class );
130+       Node  mockNode2  = mock (Node .class );
131+       List <Node > mockListNodes  = Arrays .asList (mockNode1 , mockNode2 );
132+ 
133+       TpuClient  mockTpuClient  = mock (TpuClient .class );
134+       mockedTpuClient .when (TpuClient ::create ).thenReturn (mockTpuClient );
135+       TpuClient .ListNodesPagedResponse  mockListNodesResponse  =
136+           mock (TpuClient .ListNodesPagedResponse .class );
137+       when (mockTpuClient .listNodes (any (ListNodesRequest .class ))).thenReturn (mockListNodesResponse );
138+       TpuClient .ListNodesPage  mockListNodesPage  = mock (TpuClient .ListNodesPage .class );
139+       when (mockListNodesResponse .getPage ()).thenReturn (mockListNodesPage );
140+       when (mockListNodesPage .getValues ()).thenReturn (mockListNodes );
141+ 
142+       TpuClient .ListNodesPage  returnedListNodes  = ListTpuVms .listTpuVms (PROJECT_ID , ZONE );
143+ 
144+       assertThat (returnedListNodes .getValues ()).isEqualTo (mockListNodes );
145+       verify (mockTpuClient , times (1 )).listNodes (any (ListNodesRequest .class ));
146+     }
147+   }
94148}
0 commit comments