1+ /*
2+ * Copyright 2024 Google LLC
3+ *
4+ * Licensed under the Apache License, Version 2.0 (the "License");
5+ * you may not use this file except in compliance with the License.
6+ * You may obtain a copy of the License at
7+ *
8+ * http://www.apache.org/licenses/LICENSE-2.0
9+ *
10+ * Unless required by applicable law or agreed to in writing, software
11+ * distributed under the License is distributed on an "AS IS" BASIS,
12+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ * See the License for the specific language governing permissions and
14+ * limitations under the License.
15+ */
16+
17+ package tpu ;
18+
19+ //[START tpu_vm_create_topology]
20+ import com .google .api .gax .longrunning .OperationTimedPollAlgorithm ;
21+ import com .google .api .gax .retrying .RetrySettings ;
22+ import com .google .cloud .tpu .v2 .AcceleratorConfig ;
23+ import com .google .cloud .tpu .v2 .AcceleratorConfig .Type ;
24+ import com .google .cloud .tpu .v2 .CreateNodeRequest ;
25+ import com .google .cloud .tpu .v2 .Node ;
26+ import com .google .cloud .tpu .v2 .TpuClient ;
27+ import com .google .cloud .tpu .v2 .TpuSettings ;
28+ import java .io .IOException ;
29+ import java .util .concurrent .ExecutionException ;
30+ import org .threeten .bp .Duration ;
31+
32+ public class CreateTpuWithTopologyFlag {
33+
34+ public static void main (String [] args )
35+ throws IOException , ExecutionException , InterruptedException {
36+ // TODO(developer): Replace these variables before running the sample.
37+ // Project ID or project number of the Google Cloud project you want to create a node.
38+ String projectId = "YOUR_PROJECT_ID" ;
39+ // The zone in which to create the TPU.
40+ // For more information about supported TPU types for specific zones,
41+ // see https://cloud.google.com/tpu/docs/regions-zones
42+ String zone = "europe-west4-a" ;
43+ // The name for your TPU.
44+ String nodeName = "YOUR_TPY_NAME" ;
45+ // The version of the Cloud TPU you want to create.
46+ // Available options: TYPE_UNSPECIFIED = 0, V2 = 2, V3 = 4, V4 = 7
47+ Type tpuVersion = AcceleratorConfig .Type .V2 ;
48+ // Software version that specifies the version of the TPU runtime to install.
49+ // For more information, see https://cloud.google.com/tpu/docs/runtimes
50+ String tpuSoftwareVersion = "tpu-vm-tf-2.17.0-pod-pjrt" ;
51+ // The physical topology of your TPU slice.
52+ // For more information about topology for each TPU version,
53+ // see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions.
54+ String topology = "2x2" ;
55+
56+ createTpuWithTopologyFlag (projectId , zone , nodeName , tpuVersion , tpuSoftwareVersion , topology );
57+ }
58+
59+ // Creates a TPU VM with the specified name, zone, version and topology.
60+ public static Node createTpuWithTopologyFlag (String projectId , String zone , String nodeName ,
61+ Type tpuVersion , String tpuSoftwareVersion , String topology )
62+ throws IOException , ExecutionException , InterruptedException {
63+ // With these settings the client library handles the Operation's polling mechanism
64+ // and prevent CancellationException error
65+ TpuSettings .Builder clientSettings =
66+ TpuSettings .newBuilder ();
67+ clientSettings
68+ .createNodeOperationSettings ()
69+ .setPollingAlgorithm (
70+ OperationTimedPollAlgorithm .create (
71+ RetrySettings .newBuilder ()
72+ .setInitialRetryDelay (Duration .ofMillis (5000L ))
73+ .setRetryDelayMultiplier (1.5 )
74+ .setMaxRetryDelay (Duration .ofMillis (45000L ))
75+ .setInitialRpcTimeout (Duration .ZERO )
76+ .setRpcTimeoutMultiplier (1.0 )
77+ .setMaxRpcTimeout (Duration .ZERO )
78+ .setTotalTimeout (Duration .ofHours (24L ))
79+ .build ()));
80+
81+ // Initialize client that will be used to send requests. This client only needs to be created
82+ // once, and can be reused for multiple requests.
83+ try (TpuClient tpuClient = TpuClient .create (clientSettings .build ())) {
84+ String parent = String .format ("projects/%s/locations/%s" , projectId , zone );
85+
86+ Node tpuVm =
87+ Node .newBuilder ()
88+ .setName (nodeName )
89+ .setAcceleratorConfig (Node .newBuilder ()
90+ .getAcceleratorConfigBuilder ()
91+ .setType (tpuVersion )
92+ .setTopology (topology )
93+ .build ())
94+ .setRuntimeVersion (tpuSoftwareVersion )
95+ .build ();
96+
97+ CreateNodeRequest request =
98+ CreateNodeRequest .newBuilder ()
99+ .setParent (parent )
100+ .setNodeId (nodeName )
101+ .setNode (tpuVm )
102+ .build ();
103+
104+ Node response = tpuClient .createNodeAsync (request ).get ();
105+ System .out .printf ("TPU VM created: %s\n " , response .getName ());
106+ return response ;
107+ }
108+ }
109+ }
110+ //[END tpu_vm_create_topology]
0 commit comments