Skip to content

Commit e4d6749

Browse files
Add checkpoint configuration file for MTC (#465)
1 parent ad5efe4 commit e4d6749

File tree

5 files changed

+167
-1
lines changed

5 files changed

+167
-1
lines changed

src/xpk/commands/cluster.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
run_gke_node_pool_create_command,
5959
)
6060
from ..core.ray import install_ray_cluster
61+
from ..core.mtc import install_mtc_on_cluster
6162
from ..core.resources import create_cluster_configmaps
6263
from ..core.storage import install_storage_crd
6364
from ..core.system_characteristics import (
@@ -279,6 +280,12 @@ def cluster_create(args) -> None:
279280
xpk_print('Installation of RayCluster failed.')
280281
xpk_exit(return_code)
281282

283+
if hasattr(args, 'enable_mtc') and args.enable_mtc:
284+
return_code = install_mtc_on_cluster(args, system)
285+
if return_code != 0:
286+
xpk_print('Installation of MTC failed.')
287+
xpk_exit(return_code)
288+
282289
xpk_print('GKE commands done! Resources are created.')
283290
xpk_print(
284291
'See your GKE Cluster here:'
@@ -815,6 +822,7 @@ def run_gke_cluster_create_command(
815822
addons = []
816823
if args.enable_gcsfuse_csi_driver:
817824
addons.append('GcsFuseCsiDriver')
825+
818826
if args.enable_gcpfilestore_csi_driver:
819827
addons.append('GcpFilestoreCsiDriver')
820828

@@ -824,6 +832,9 @@ def run_gke_cluster_create_command(
824832
if args.enable_pd_csi_driver:
825833
addons.append('GcePersistentDiskCsiDriver')
826834

835+
if hasattr(args, 'enable_mtc') and args.enable_mtc:
836+
addons.append('HighScaleCheckpointing')
837+
827838
if len(addons) > 0:
828839
addons_str = ','.join(addons)
829840
command += f' --addons={addons_str}'

src/xpk/core/mtc.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""
2+
Copyright 2024 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from ..utils.console import xpk_exit, xpk_print
18+
from ..utils import templates
19+
from ..utils.kubectl import apply_kubectl_manifest
20+
from ..core.cluster import setup_k8s_env
21+
22+
MTC_CPC_PATH = "/../templates/mtc-cpc.yaml"
23+
24+
25+
def create_mtc_cpc(
    mtc_gcs_bucket: str,
    mtc_machine_type: str,
    mtc_toleration_key: str,
    mtc_ramdisk_size: str,
) -> dict:
  """Build the MTC CheckpointConfiguration manifest from its template.

  Loads the template referenced by MTC_CPC_PATH and fills in the `spec`
  section with the supplied values.

  Args:
    mtc_gcs_bucket: value written to `spec.cloudStorageBucketName`.
    mtc_machine_type: value written to the
      `node.kubernetes.io/instance-type` entry of `spec.nodeSelector`.
    mtc_toleration_key: value written to the `key` of the first entry in
      `spec.tolerations`.
    mtc_ramdisk_size: value written to `spec.inMemoryVolumeSize`.

  Returns:
    The loaded template with the four fields above populated.
  """
  manifest = templates.load(MTC_CPC_PATH)
  spec = manifest["spec"]

  spec["cloudStorageBucketName"] = mtc_gcs_bucket
  spec["nodeSelector"]["node.kubernetes.io/instance-type"] = mtc_machine_type
  spec["tolerations"][0]["key"] = mtc_toleration_key
  spec["inMemoryVolumeSize"] = mtc_ramdisk_size

  return manifest
41+
42+
43+
def install_mtc_on_cluster(args, system) -> int:
  """Install MTC on the cluster.

  Validates the MTC-related arguments, builds the CheckpointConfiguration
  manifest and applies it to the cluster.

  Args:
    args: user provided arguments for running the command. Reads
      `mtc_gcs_bucket`, `mtc_ramdisk_size` and `mtc_toleration_key`; the
      bucket (scheme stripped) and the toleration key (default filled in)
      are written back onto `args`.
    system: system characteristics; only `gce_machine_type` is read here.

  Returns:
    return code of applying the manifest.
  """
  gcs_scheme = "gs://"

  bucket = args.mtc_gcs_bucket
  # Strip only a leading scheme; a blanket str.replace would also rewrite any
  # later occurrence of "gs://" inside the value.
  if bucket is not None and bucket.startswith(gcs_scheme):
    bucket = bucket[len(gcs_scheme) :]
  # Reject both a missing value and one that is empty once the scheme is
  # removed (e.g. "" or "gs://"), so an empty name never reaches the manifest.
  if not bucket:
    xpk_print("MTC GCS bucket is required.")
    xpk_exit(1)
  args.mtc_gcs_bucket = bucket

  if not args.mtc_ramdisk_size:
    xpk_print("MTC ramdisk size is required.")
    xpk_exit(1)

  if args.mtc_toleration_key is None:
    args.mtc_toleration_key = "google.com/tpu"

  mtc_checkpoint_configuration_crd_data = create_mtc_cpc(
      args.mtc_gcs_bucket,
      system.gce_machine_type,
      args.mtc_toleration_key,
      args.mtc_ramdisk_size,
  )
  xpk_print("Applying MTC Checkpoint Configuration")
  k8s_api_client = setup_k8s_env(args)
  # apply_kubectl_manifest takes a list of manifest objects.
  return apply_kubectl_manifest(
      k8s_api_client, [mtc_checkpoint_configuration_crd_data]
  )

src/xpk/parser/cluster.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,12 @@ def set_cluster_create_parser(cluster_create_parser: ArgumentParser):
189189
cluster_create_tensorboard_arguments
190190
)
191191

192+
### MTC arguments specific to "cluster create"
193+
cluster_create_mtc_arguments = cluster_create_parser.add_argument_group(
194+
'Optional MTC Arguments',
195+
'Arguments for configuring MTC in cluster create.',
196+
)
197+
add_shared_cluster_create_mtc_arguments(cluster_create_mtc_arguments)
192198
cluster_create_parser.set_defaults(func=cluster_create)
193199

194200

@@ -244,6 +250,14 @@ def set_cluster_create_pathways_parser(
244250
cluster_create_pathways_tensorboard_arguments
245251
)
246252

253+
### MTC arguments specific to "cluster create"
254+
cluster_create_mtc_arguments = (
255+
cluster_create_pathways_parser.add_argument_group(
256+
'Optional MTC Arguments',
257+
'Arguments for configuring MTC in cluster create.',
258+
)
259+
)
260+
add_shared_cluster_create_mtc_arguments(cluster_create_mtc_arguments)
247261
cluster_create_pathways_parser.set_defaults(func=cluster_create_pathways)
248262

249263

@@ -313,6 +327,12 @@ def set_cluster_create_ray_parser(cluster_create_ray_parser: ArgumentParser):
313327
cluster_create_ray_tensorboard_arguments
314328
)
315329

330+
### MTC arguments specific to "cluster create"
331+
cluster_create_mtc_arguments = cluster_create_ray_parser.add_argument_group(
332+
'Optional MTC Arguments',
333+
'Arguments for configuring MTC in cluster create.',
334+
)
335+
add_shared_cluster_create_mtc_arguments(cluster_create_mtc_arguments)
316336
cluster_create_ray_parser.set_defaults(func=cluster_create_ray_cluster)
317337

318338

@@ -706,3 +726,43 @@ def add_shared_cluster_create_capacity_arguments(parser: ArgumentParser):
706726
' See `--reservation` or `--on-demand` for other capacity types.'
707727
),
708728
)
729+
730+
731+
def add_shared_cluster_create_mtc_arguments(parser: ArgumentParser):
  """Add the shared Multi-tier Checkpointing (MTC) arguments to a parser.

  Registers `--enable-mtc`, `--mtc-ramdisk-size`, `--mtc-gcs-bucket` and
  `--mtc-toleration-key`. The value options default to None; the flag
  defaults to False.

  Args:
    parser: the parser (or argument group) the MTC arguments are added to.
  """
  parser.add_argument(
      '--enable-mtc',
      action='store_true',
      help='Enable MTC on the cluster.',
  )
  parser.add_argument(
      '--mtc-ramdisk-size',
      type=str,
      default=None,
      help=(
          '(Required if --enable-mtc is true) The size of the RAM disk to be'
          ' used for multi-tier checkpointing. e.g. "64Mi"'
      ),
  )
  parser.add_argument(
      '--mtc-gcs-bucket',
      type=str,
      default=None,
      help=(
          '(Required if --enable-mtc is true) The GCS bucket to be used for'
          ' multi-tier checkpointing.'
      ),
  )
  parser.add_argument(
      '--mtc-toleration-key',
      type=str,
      default=None,
      help=(
          '(Optional) The toleration key to be used for multi-tier'
          ' checkpointing. By default, it is set to "google.com/tpu".'
      ),
  )

src/xpk/templates/mtc-cpc.yaml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
apiVersion: checkpointing.gke.io/v1
2+
kind: CheckpointConfiguration
3+
metadata:
4+
name: my-checkpointconfiguration
5+
spec:
6+
cloudStorageBucketName:
7+
# The nodeSelector field below is optional
8+
nodeSelector:
9+
node.kubernetes.io/instance-type:
10+
# The tolerations field below is optional
11+
tolerations:
12+
- key:
13+
operator: Exists
14+
effect: NoSchedule
15+
inMemoryVolumeSize:

src/xpk/utils/kubectl.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@
2020
from .console import xpk_print
2121

2222

23-
def apply_kubectl_manifest(client, manifest):
23+
def apply_kubectl_manifest(client, manifest) -> int:
2424
xpk_print('Applying manifest')
2525
dynamic_client = DynamicClient(client)
2626

27+
status_code = 0
2728
for obj in manifest:
2829
api_version = obj['apiVersion']
2930
kind = obj['kind']
@@ -55,3 +56,5 @@ def apply_kubectl_manifest(client, manifest):
5556
)
5657
else:
5758
xpk_print(f'Error applying {kind}: {e}')
59+
status_code = 1
60+
return status_code

0 commit comments

Comments
 (0)