1616
1717package tpu ;
1818
19- import static com .google .cloud .tpu .v2 .Node .State .READY ;
20- import static com .google .cloud .tpu .v2 .Node .State .STOPPED ;
2119import static com .google .common .truth .Truth .assertThat ;
2220import static com .google .common .truth .Truth .assertWithMessage ;
2321import static org .junit .Assert .assertNotNull ;
2422
2523import com .google .api .gax .rpc .NotFoundException ;
2624import com .google .cloud .tpu .v2 .Node ;
27- import com .google .cloud .tpu .v2 .TpuClient ;
28- import java .io .ByteArrayOutputStream ;
2925import java .io .IOException ;
30- import java .io .PrintStream ;
3126import java .util .UUID ;
3227import java .util .concurrent .ExecutionException ;
3328import java .util .concurrent .TimeUnit ;
34- import org .junit .Assert ;
3529import org .junit .jupiter .api .AfterAll ;
3630import org .junit .jupiter .api .Assertions ;
3731import org .junit .jupiter .api .BeforeAll ;
4438import org .junit .runners .JUnit4 ;
4539
4640@ RunWith (JUnit4 .class )
47- @ Timeout (value = 25 , unit = TimeUnit .MINUTES )
41+ @ Timeout (value = 15 , unit = TimeUnit .MINUTES )
4842@ TestMethodOrder (MethodOrderer . OrderAnnotation . class )
4943public class TpuVmIT {
5044 private static final String PROJECT_ID = System .getenv ("GOOGLE_CLOUD_PROJECT" );
51- private static final String ZONE = "europe-west4 -a" ;
45+ private static final String ZONE = "us-central1 -a" ;
5246 static String javaVersion = System .getProperty ("java.version" ).substring (0 , 2 );
5347 private static final String NODE_NAME = "test-tpu-" + javaVersion + "-"
5448 + UUID .randomUUID ().toString ().substring (0 , 8 );
55- private static final String TPU_TYPE = "v2 -8" ;
49+ private static final String TPU_TYPE = "v3 -8" ;
5650 private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1" ;
5751 private static final String NODE_PATH_NAME =
5852 String .format ("projects/%s/locations/%s/nodes/%s" , PROJECT_ID , ZONE , NODE_NAME );
@@ -85,14 +79,13 @@ public static void cleanup() throws Exception {
8579 @ Test
8680 @ Order (1 )
8781 public void testCreateTpuVm () throws IOException , ExecutionException , InterruptedException {
88- final PrintStream out = System .out ;
89- ByteArrayOutputStream stdOut = new ByteArrayOutputStream ();
90- System .setOut (new PrintStream (stdOut ));
91- CreateTpuVm .createTpuVm (PROJECT_ID , ZONE , NODE_NAME , TPU_TYPE , TPU_SOFTWARE_VERSION );
9282
93- assertThat (stdOut .toString ()).contains ("TPU VM created: " + NODE_PATH_NAME );
94- stdOut .close ();
95- System .setOut (out );
83+ Node node = CreateTpuVm .createTpuVm (
84+ PROJECT_ID , ZONE , NODE_NAME , TPU_TYPE , TPU_SOFTWARE_VERSION );
85+
86+ assertNotNull (node );
87+ assertThat (node .getName ().equals (NODE_NAME ));
88+ assertThat (node .getAcceleratorType ().equals (TPU_TYPE ));
9689 }
9790
9891 @ Test
@@ -103,33 +96,4 @@ public void testGetTpuVm() throws IOException {
10396 assertNotNull (node );
10497 assertThat (node .getName ()).isEqualTo (NODE_PATH_NAME );
10598 }
106-
107- @ Test
108- @ Order (2 )
109- public void testListTpuVm () throws IOException {
110- TpuClient .ListNodesPagedResponse nodesList = ListTpuVms .listTpuVms (PROJECT_ID , ZONE );
111-
112- assertNotNull (nodesList );
113- for (Node node : nodesList .iterateAll ()) {
114- Assert .assertTrue (node .getName ().contains ("test-tpu" ));
115- }
116- }
117-
118- @ Test
119- @ Order (2 )
120- public void testStopTpuVm () throws IOException , ExecutionException , InterruptedException {
121- StopTpuVm .stopTpuVm (PROJECT_ID , ZONE , NODE_NAME );
122- Node node = GetTpuVm .getTpuVm (PROJECT_ID , ZONE , NODE_NAME );
123-
124- assertThat (node .getState ()).isEqualTo (STOPPED );
125- }
126-
127- @ Test
128- @ Order (3 )
129- public void testStartTpuVm () throws IOException , ExecutionException , InterruptedException {
130- StartTpuVm .startTpuVm (PROJECT_ID , ZONE , NODE_NAME );
131- Node node = GetTpuVm .getTpuVm (PROJECT_ID , ZONE , NODE_NAME );
132-
133- assertThat (node .getState ()).isEqualTo (READY );
134- }
13599}
0 commit comments