1+ /*
2+ * Copyright 2024 Google LLC
3+ *
4+ * Licensed under the Apache License, Version 2.0 (the "License");
5+ * you may not use this file except in compliance with the License.
6+ * You may obtain a copy of the License at
7+ *
8+ * http://www.apache.org/licenses/LICENSE-2.0
9+ *
10+ * Unless required by applicable law or agreed to in writing, software
11+ * distributed under the License is distributed on an "AS IS" BASIS,
12+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ * See the License for the specific language governing permissions and
14+ * limitations under the License.
15+ */
16+
17+ package tpu ;
18+
19+ import static com .google .common .truth .Truth .assertThat ;
20+ import static com .google .common .truth .Truth .assertWithMessage ;
21+
22+ import com .google .api .gax .rpc .NotFoundException ;
23+ import java .io .ByteArrayOutputStream ;
24+ import java .io .IOException ;
25+ import java .io .PrintStream ;
26+ import java .util .UUID ;
27+ import java .util .concurrent .ExecutionException ;
28+ import java .util .concurrent .TimeUnit ;
29+ import org .junit .jupiter .api .AfterAll ;
30+ import org .junit .jupiter .api .Assertions ;
31+ import org .junit .jupiter .api .BeforeAll ;
32+ import org .junit .jupiter .api .Test ;
33+ import org .junit .jupiter .api .Timeout ;
34+ import org .junit .runner .RunWith ;
35+ import org .junit .runners .JUnit4 ;
36+
37+ @ RunWith (JUnit4 .class )
38+ @ Timeout (value = 25 , unit = TimeUnit .MINUTES )
39+ public class CreateSpotTpuVmIT {
40+ private static final String PROJECT_ID = System .getenv ("GOOGLE_CLOUD_PROJECT" );
41+ private static final String ZONE = "europe-west4-a" ;
42+ static String javaVersion = System .getProperty ("java.version" ).substring (0 , 2 );
43+ private static final String TPU_VM_NAME = "test-spot-tpu-" + javaVersion + "-"
44+ + UUID .randomUUID ().toString ().substring (0 , 8 );
45+ private static final String ACCELERATOR_TYPE = "v2-8" ;
46+ private static final String VERSION = "tpu-vm-tf-2.14.1" ;
47+ private static final String TPU_VM_PATH_NAME =
48+ String .format ("projects/%s/locations/%s/nodes/%s" , PROJECT_ID , ZONE , TPU_VM_NAME );
49+
50+ public static void requireEnvVar (String envVarName ) {
51+ assertWithMessage (String .format ("Missing environment variable '%s' " , envVarName ))
52+ .that (System .getenv (envVarName )).isNotEmpty ();
53+ }
54+
55+ @ BeforeAll
56+ public static void setUp ()
57+ throws IOException , ExecutionException , InterruptedException {
58+ requireEnvVar ("GOOGLE_APPLICATION_CREDENTIALS" );
59+ requireEnvVar ("GOOGLE_CLOUD_PROJECT" );
60+
61+ // Cleanup existing stale resources.
62+ Util .cleanUpExistingTpu ("test-spot-tpu-" + javaVersion , PROJECT_ID , ZONE );
63+ }
64+
65+ @ AfterAll
66+ public static void cleanup () throws Exception {
67+ DeleteTpuVm .deleteTpuVm (PROJECT_ID , ZONE , TPU_VM_NAME );
68+
69+ // Test that TPUs is deleted
70+ Assertions .assertThrows (
71+ NotFoundException .class ,
72+ () -> GetTpuVm .getTpuVm (PROJECT_ID , ZONE , TPU_VM_NAME ));
73+ }
74+
75+ @ Test
76+ public void testCreateSpotTpuVm () throws IOException , ExecutionException , InterruptedException {
77+ final PrintStream out = System .out ;
78+ ByteArrayOutputStream stdOut = new ByteArrayOutputStream ();
79+ System .setOut (new PrintStream (stdOut ));
80+ CreateSpotTpuVm .createSpotTpuVm (PROJECT_ID , ZONE , TPU_VM_NAME , ACCELERATOR_TYPE , VERSION );
81+
82+ assertThat (stdOut .toString ()).contains ("TPU VM created: " + TPU_VM_PATH_NAME );
83+ stdOut .close ();
84+ System .setOut (out );
85+ }
86+ }
0 commit comments