Skip to content

Commit d9b2f9c

Browse files
Create draft create TPU with spot Sample (#12705)
1 parent 6ab96dc commit d9b2f9c

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

tpu/create_tpu_spot.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
from google.cloud.tpu_v2 import Node
17+
18+
19+
def create_tpu_with_spot(
    project_id: str,
    zone: str,
    tpu_name: str,
    tpu_type: str = "v2-8",
    runtime_version: str = "tpu-vm-tf-2.17.0-pjrt",
) -> Node:
    """Create a Cloud TPU node whose scheduling config is marked preemptible.

    Submits the create request, blocks until the long-running operation
    finishes, and prints the scheduling config of the resulting node.

    Args:
        project_id (str): The ID of the Google Cloud project.
        zone (str): The zone in which to create the TPU node.
        tpu_name (str): The ID to give the new TPU node.
        tpu_type (str, optional): The accelerator type of the TPU.
        runtime_version (str, optional): The runtime version for the TPU.

    Returns:
        Node: The node returned by the completed create operation.
    """
    # [START tpu_vm_create_spot]
    from google.cloud import tpu_v2

    # TODO (developer): Update and un-comment below lines
    # project_id = "your-project-id"
    # zone = "us-central1-b"
    # tpu_name = "tpu-name"
    # tpu_type = "v2-8"
    # runtime_version = "tpu-vm-tf-2.17.0-pjrt"

    # Describe the TPU node to create.
    # To see available runtime version use command:
    # gcloud compute tpus versions list --zone={ZONE}
    # TODO: Wait for update of library to change preemptible to spot=True
    tpu_node = tpu_v2.Node(
        accelerator_type=tpu_type,
        runtime_version=runtime_version,
        scheduling_config=tpu_v2.SchedulingConfig(preemptible=True),
    )

    create_request = tpu_v2.CreateNodeRequest(
        parent=f"projects/{project_id}/locations/{zone}",
        node_id=tpu_name,
        node=tpu_node,
    )

    tpu_client = tpu_v2.TpuClient()
    create_operation = tpu_client.create_node(request=create_request)
    print("Waiting for operation to complete...")

    # result() blocks until the long-running operation has finished.
    created_node = create_operation.result()

    print(created_node.scheduling_config)
    # Example response:
    # TODO: Update the response to include the scheduling config

    # [END tpu_vm_create_spot]
    return created_node
74+
75+
76+
if __name__ == "__main__":
    # The project comes from the environment; the zone and node name are
    # fixed values for this sample run.
    PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT")
    ZONE = "us-central1-b"
    create_tpu_with_spot(
        project_id=PROJECT_ID,
        zone=ZONE,
        tpu_name="tpu-with-spot",
    )

0 commit comments

Comments
 (0)