1717package tpu ;
1818
1919import static com .google .common .truth .Truth .assertThat ;
20- import static com .google .common .truth .Truth .assertWithMessage ;
20+ import static org .junit .Assert .assertEquals ;
21+ import static org .mockito .Mockito .any ;
22+ import static org .mockito .Mockito .mock ;
23+ import static org .mockito .Mockito .mockStatic ;
24+ import static org .mockito .Mockito .times ;
25+ import static org .mockito .Mockito .verify ;
26+ import static org .mockito .Mockito .when ;
2127
28+ import com .google .api .gax .longrunning .OperationFuture ;
29+ import com .google .cloud .tpu .v2alpha1 .CreateQueuedResourceRequest ;
30+ import com .google .cloud .tpu .v2alpha1 .DeleteQueuedResourceRequest ;
31+ import com .google .cloud .tpu .v2alpha1 .GetQueuedResourceRequest ;
2232import com .google .cloud .tpu .v2alpha1 .QueuedResource ;
23- import java .util .UUID ;
24- import java .util .concurrent .TimeUnit ;
33+ import com .google .cloud .tpu .v2alpha1 .TpuClient ;
34+ import com .google .cloud .tpu .v2alpha1 .TpuSettings ;
35+ import java .io .ByteArrayOutputStream ;
36+ import java .io .IOException ;
37+ import java .io .PrintStream ;
38+ import org .junit .Before ;
2539import org .junit .Test ;
26- import org .junit .jupiter .api .AfterAll ;
27- import org .junit .jupiter .api .BeforeAll ;
2840import org .junit .jupiter .api .Timeout ;
2941import org .junit .runner .RunWith ;
3042import org .junit .runners .JUnit4 ;
43+ import org .mockito .MockedStatic ;
3144
3245@ RunWith (JUnit4 .class )
33- @ Timeout (value = 6 , unit = TimeUnit . MINUTES )
46+ @ Timeout (value = 3 )
3447public class QueuedResourceIT {
35- private static final String PROJECT_ID = System . getenv ( "GOOGLE_CLOUD_PROJECT" ) ;
48+ private static final String PROJECT_ID = "project-id" ;
3649 private static final String ZONE = "europe-west4-a" ;
37- private static final String NODE_NAME = "test-tpu-queued-resource-network-" + UUID . randomUUID () ;
50+ private static final String NODE_NAME = "test-tpu" ;
3851 private static final String TPU_TYPE = "v2-8" ;
3952 private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1" ;
40- private static final String QUEUED_RESOURCE_NAME = "queued-resource-network-" + UUID . randomUUID () ;
53+ private static final String QUEUED_RESOURCE_NAME = "queued-resource" ;
4154 private static final String NETWORK_NAME = "default" ;
55+ private ByteArrayOutputStream bout ;
4256
43- public static void requireEnvVar (String envVarName ) {
44- assertWithMessage (String .format ("Missing environment variable '%s' " , envVarName ))
45- .that (System .getenv (envVarName )).isNotEmpty ();
57+ @ Before
58+ public void setUp () {
59+ bout = new ByteArrayOutputStream ();
60+ System .setOut (new PrintStream (bout ));
4661 }
4762
48- @ BeforeAll
49- public static void setUp () {
50- requireEnvVar ("GOOGLE_APPLICATION_CREDENTIALS" );
51- requireEnvVar ("GOOGLE_CLOUD_PROJECT" );
63+ @ Test
64+ public void testCreateQueuedResourceWithSpecifiedNetwork () throws Exception {
65+ try (MockedStatic <TpuClient > mockedTpuClient = mockStatic (TpuClient .class )) {
66+ QueuedResource mockQueuedResource = mock (QueuedResource .class );
67+ TpuClient mockTpuClient = mock (TpuClient .class );
68+ OperationFuture mockFuture = mock (OperationFuture .class );
69+
70+ mockedTpuClient .when (() -> TpuClient .create (any (TpuSettings .class )))
71+ .thenReturn (mockTpuClient );
72+ when (mockTpuClient .createQueuedResourceAsync (any (CreateQueuedResourceRequest .class )))
73+ .thenReturn (mockFuture );
74+ when (mockFuture .get ()).thenReturn (mockQueuedResource );
75+
76+ QueuedResource returnedQueuedResource =
77+ CreateQueuedResourceWithNetwork .createQueuedResourceWithNetwork (
78+ PROJECT_ID , ZONE , QUEUED_RESOURCE_NAME , NODE_NAME ,
79+ TPU_TYPE , TPU_SOFTWARE_VERSION , NETWORK_NAME );
80+
81+ verify (mockTpuClient , times (1 ))
82+ .createQueuedResourceAsync (any (CreateQueuedResourceRequest .class ));
83+ verify (mockFuture , times (1 )).get ();
84+ assertEquals (returnedQueuedResource , mockQueuedResource );
85+ }
5286 }
5387
54- @ AfterAll
55- public static void cleanup () {
56- DeleteForceQueuedResource .deleteForceQueuedResource (PROJECT_ID , ZONE , QUEUED_RESOURCE_NAME );
88+ @ Test
89+ public void testGetQueuedResource () throws IOException {
90+ try (MockedStatic <TpuClient > mockedTpuClient = mockStatic (TpuClient .class )) {
91+ TpuClient mockClient = mock (TpuClient .class );
92+ GetQueuedResource mockGetQueuedResource = mock (GetQueuedResource .class );
93+ QueuedResource mockQueuedResource = mock (QueuedResource .class );
94+
95+ mockedTpuClient .when (TpuClient ::create ).thenReturn (mockClient );
96+ when (mockClient .getQueuedResource (any (GetQueuedResourceRequest .class )))
97+ .thenReturn (mockQueuedResource );
98+
99+ QueuedResource returnedQueuedResource =
100+ GetQueuedResource .getQueuedResource (PROJECT_ID , ZONE , NODE_NAME );
101+
102+ verify (mockGetQueuedResource , times (1 ))
103+ .getQueuedResource (PROJECT_ID , ZONE , NODE_NAME );
104+ assertEquals (returnedQueuedResource , mockQueuedResource );
105+ }
57106 }
58107
59108 @ Test
60- public void testCreateQueuedResourceWithSpecifiedNetwork () throws Exception {
109+ public void testDeleteTpuVm () {
110+ try (MockedStatic <TpuClient > mockedTpuClient = mockStatic (TpuClient .class )) {
111+ TpuClient mockTpuClient = mock (TpuClient .class );
112+ OperationFuture mockFuture = mock (OperationFuture .class );
113+
114+ mockedTpuClient .when (() -> TpuClient .create (any (TpuSettings .class )))
115+ .thenReturn (mockTpuClient );
116+ when (mockTpuClient .deleteQueuedResourceAsync (any (DeleteQueuedResourceRequest .class )))
117+ .thenReturn (mockFuture );
61118
62- QueuedResource queuedResource = CreateQueuedResourceWithNetwork .createQueuedResourceWithNetwork (
63- PROJECT_ID , ZONE , QUEUED_RESOURCE_NAME , NODE_NAME ,
64- TPU_TYPE , TPU_SOFTWARE_VERSION , NETWORK_NAME );
119+ DeleteForceQueuedResource .deleteForceQueuedResource (PROJECT_ID , ZONE , QUEUED_RESOURCE_NAME );
120+ String output = bout .toString ();
65121
66- assertThat (queuedResource .getTpu ().getNodeSpec (0 ).getNode ().getName ()).isEqualTo (NODE_NAME );
67- assertThat (queuedResource .getTpu ().getNodeSpec (0 ).getNode ().getNetworkConfig ().getNetwork ()
68- .contains (NETWORK_NAME ));
69- assertThat (queuedResource .getTpu ().getNodeSpec (0 ).getNode ().getNetworkConfig ().getSubnetwork ()
70- .contains (NETWORK_NAME ));
122+ assertThat (output ).contains ("Deleted Queued Resource:" );
123+ verify (mockTpuClient , times (1 ))
124+ .deleteQueuedResourceAsync (any (DeleteQueuedResourceRequest .class ));
125+ }
71126 }
72127}
0 commit comments