Skip to content

Commit 3f7c7bc

Browse files
mollyheamazon authored and rsareddy0329 committed
Cluster-stack template agnostic change (#245)
* decouple template from src code
* remove field validator from SDK pydantic model, fix minor parsing problem with list, update kubernetes_version type from str to float
* change type handler from class to module functions, change some public function to private, update unit tests
* cluster-stack template agnostic change
* update unit tests
* update integ test
* resolve circular import for cluster_stack
* resolve rebase merge conflict
* rename to_domain to to_config for cluster_stack
* increase timeout for endpoint integ test from 15min to 20min
1 parent 72c14f7 commit 3f7c7bc

File tree

17 files changed

+239
-587
lines changed

17 files changed

+239
-587
lines changed

hyperpod-cluster-stack-template/hyperpod_cluster_stack_template/v1_0/model.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pydantic import BaseModel, Field, field_validator
22
from typing import Optional, Literal, List, Any, Union
3+
from sagemaker.hyperpod.common.utils import region_to_az_ids
34

45
class ClusterStackBase(BaseModel):
56
resource_name_prefix: Optional[str] = Field("hyp-eks-stack", description="Prefix to be used for all resources. A 4-digit UUID will be added to prefix during submission")
@@ -56,4 +57,77 @@ class ClusterStackBase(BaseModel):
5657
def validate_kubernetes_version(cls, v):
5758
if v is not None:
5859
return str(v)
59-
return v
60+
return v
61+
62+
def to_config(self, region: str = None):
63+
"""Convert CLI model to SDK configuration for cluster stack creation.
64+
65+
Transforms the CLI model instance into a configuration dictionary that can be used
66+
to instantiate the HpClusterStack SDK class. Applies necessary transformations
67+
including AZ configuration, UUID generation, and field restructuring.
68+
69+
Args:
70+
region (str, optional): AWS region for AZ configuration. If provided,
71+
automatically sets availability_zone_ids and fsx_availability_zone_id
72+
when not already specified.
73+
74+
Returns:
75+
dict: Configuration dictionary ready for HpClusterStack instantiation.
76+
Contains all transformed parameters with defaults applied.
77+
78+
Example:
79+
>>> cli_model = ClusterStackBase(hyperpod_cluster_name="my-cluster")
80+
>>> config = cli_model.to_config(region="us-west-2")
81+
>>> sdk_instance = HpClusterStack(**config)
82+
"""
83+
import uuid
84+
85+
# Convert model to dict and apply transformations
86+
config = self.model_dump(exclude_none=True)
87+
88+
# Prepare CFN arrays from numbered fields
89+
instance_group_settings = []
90+
rig_settings = []
91+
for i in range(1, 21):
92+
ig_key = f'instance_group_settings{i}'
93+
rig_key = f'rig_settings{i}'
94+
if ig_key in config:
95+
instance_group_settings.append(config.pop(ig_key))
96+
if rig_key in config:
97+
rig_settings.append(config.pop(rig_key))
98+
99+
# Add arrays to config
100+
if instance_group_settings:
101+
config['instance_group_settings'] = instance_group_settings
102+
if rig_settings:
103+
config['rig_settings'] = rig_settings
104+
105+
# Add default AZ configuration if not provided
106+
if region and (not config.get('availability_zone_ids') or not config.get('fsx_availability_zone_id')):
107+
all_az_ids = region_to_az_ids(region)
108+
default_az_config = {
109+
'availability_zone_ids': all_az_ids[:2], # First 2 AZs
110+
'fsx_availability_zone_id': all_az_ids[0] # First AZ
111+
}
112+
if not config.get('availability_zone_ids'):
113+
config['availability_zone_ids'] = default_az_config['availability_zone_ids']
114+
if not config.get('fsx_availability_zone_id'):
115+
config['fsx_availability_zone_id'] = default_az_config['fsx_availability_zone_id']
116+
117+
# Append 4-digit UUID to resource_name_prefix
118+
if config.get('resource_name_prefix'):
119+
config['resource_name_prefix'] = f"{config['resource_name_prefix']}-{str(uuid.uuid4())[:4]}"
120+
121+
# Set fixed defaults
122+
defaults = {
123+
'custom_bucket_name': 'sagemaker-hyperpod-cluster-stack-bucket',
124+
'github_raw_url': 'https://raw.githubusercontent.com/aws-samples/awsome-distributed-training/refs/heads/main/1.architectures/7.sagemaker-hyperpod-eks/LifecycleScripts/base-config/on_create.sh',
125+
'helm_repo_url': 'https://github.com/aws/sagemaker-hyperpod-cli.git',
126+
'helm_repo_path': 'helm_chart/HyperPodHelmChart'
127+
}
128+
129+
for key, default_value in defaults.items():
130+
if key not in config:
131+
config[key] = default_value
132+
133+
return config

src/sagemaker/hyperpod/cli/commands/cluster_stack.py

Lines changed: 33 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
from sagemaker.hyperpod.common.telemetry.constants import Feature
1919
from sagemaker.hyperpod.common.utils import setup_logging
2020
from sagemaker.hyperpod.cli.utils import convert_datetimes
21+
from sagemaker.hyperpod.cli.init_utils import _filter_cli_metadata_fields
22+
from sagemaker.hyperpod.cli.init_utils import load_config
23+
from sagemaker.hyperpod.cli.constants.init_constants import TEMPLATES
24+
from pathlib import Path
2125
from sagemaker.hyperpod.cli.cluster_stack_utils import (
2226
StackNotFoundError,
2327
delete_stack_with_confirmation
@@ -67,86 +71,42 @@ def create_cluster_stack(config_file, region, debug):
6771
# Create with debug logging
6872
hyp create hyp-cluster cluster-config.yaml my-stack-name --debug
6973
"""
70-
create_cluster_stack_helper(config_file, region, debug)
71-
72-
def create_cluster_stack_helper(config_file: str, region: Optional[str] = None, debug: bool = False) -> None:
73-
"""Helper function to create a HyperPod cluster stack.
74-
75-
**Parameters:**
76-
77-
.. list-table::
78-
:header-rows: 1
79-
:widths: 20 20 60
80-
81-
* - Parameter
82-
- Type
83-
- Description
84-
* - config_file
85-
- str
86-
- Path to the YAML configuration file containing cluster stack settings
87-
* - region
88-
- str, optional
89-
- AWS region where the cluster stack will be created
90-
* - debug
91-
- bool
92-
- Enable debug logging for detailed error information
93-
94-
**Raises:**
95-
96-
ClickException: When cluster stack creation fails or configuration is invalid
97-
"""
9874
try:
9975
# Validate the config file path
10076
if not os.path.exists(config_file):
10177
logger.error(f"Config file not found: {config_file}")
10278
return
10379

104-
# Load the configuration from the YAML file
105-
import yaml
106-
import uuid
107-
with open(config_file, 'r') as f:
108-
config_data = yaml.safe_load(f)
109-
110-
# Filter out template and namespace fields
111-
filtered_config = {}
112-
for k, v in config_data.items():
113-
if k not in ('template', 'namespace') and v is not None:
114-
# Append 4-digit UUID to resource_name_prefix
115-
if k == 'resource_name_prefix' and v:
116-
v = f"{v}-{str(uuid.uuid4())[:4]}"
117-
filtered_config[k] = v
118-
119-
# Create the HpClusterStack object
120-
# Ensure fixed defaults are always set
121-
if 'custom_bucket_name' not in filtered_config:
122-
filtered_config['custom_bucket_name'] = 'sagemaker-hyperpod-cluster-stack-bucket'
123-
if 'github_raw_url' not in filtered_config:
124-
filtered_config['github_raw_url'] = 'https://raw.githubusercontent.com/aws-samples/awsome-distributed-training/refs/heads/main/1.architectures/7.sagemaker-hyperpod-eks/LifecycleScripts/base-config/on_create.sh'
125-
if 'helm_repo_url' not in filtered_config:
126-
filtered_config['helm_repo_url'] = 'https://github.com/aws/sagemaker-hyperpod-cli.git'
127-
if 'helm_repo_path' not in filtered_config:
128-
filtered_config['helm_repo_path'] = 'helm_chart/HyperPodHelmChart'
129-
130-
cluster_stack = HpClusterStack(**filtered_config)
80+
# Load config to get template and version
81+
82+
config_dir = Path(config_file).parent
83+
data, template, version = load_config(config_dir)
84+
85+
# Get model from registry
86+
registry = TEMPLATES[template]["registry"]
87+
model_class = registry.get(str(version))
88+
89+
if model_class:
90+
# Filter out CLI metadata fields
91+
filtered_config = _filter_cli_metadata_fields(data)
13192

132-
# Log the configuration
133-
logger.info("Creating HyperPod cluster stack with the following configuration:")
134-
for key, value in filtered_config.items():
135-
if value is not None:
136-
logger.info(f" {key}: {value}")
93+
# Create model instance and convert it to an SDK config dict
94+
model_instance = model_class(**filtered_config)
95+
config = model_instance.to_config(region=region)
13796

138-
# Create the cluster stack
139-
stack_id = cluster_stack.create(region)
97+
# Create the cluster stack
98+
stack_id = HpClusterStack(**config).create(region)
14099

141-
logger.info(f"Stack creation initiated successfully with ID: {stack_id}")
142-
logger.info("You can monitor the stack creation in the AWS CloudFormation console.")
100+
logger.info(f"Stack creation initiated successfully with ID: {stack_id}")
101+
logger.info("You can monitor the stack creation in the AWS CloudFormation console.")
143102

144103
except Exception as e:
145104
logger.error(f"Failed to create cluster stack: {e}")
146105
if debug:
147106
logger.exception("Detailed error information:")
148107
raise click.ClickException(str(e))
149108

109+
150110
@click.command("cluster-stack")
151111
@click.argument("stack-name", required=True)
152112
@click.option("--region", help="AWS region")
@@ -223,6 +183,7 @@ def describe_cluster_stack(stack_name: str, debug: bool, region: str) -> None:
223183

224184
raise click.ClickException(str(e))
225185

186+
226187
@click.command("cluster-stack")
227188
@click.option("--region", help="AWS region")
228189
@click.option("--debug", is_flag=True, help="Enable debug logging")
@@ -294,10 +255,11 @@ def list_cluster_stacks(region, debug, status):
294255

295256
raise click.ClickException(str(e))
296257

258+
297259
@click.command("cluster-stack")
298260
@click.argument("stack-name", required=True)
299261
@click.option("--retain-resources", help="Comma-separated list of logical resource IDs to retain during deletion (only works on DELETE_FAILED stacks). Resource names are shown in failed deletion output, or use AWS CLI: 'aws cloudformation list-stack-resources --stack-name STACK_NAME --region REGION'")
300-
@click.option("--region", required=True, help="AWS region (required)")
262+
@click.option("--region", required=True, help="AWS region")
301263
@click.option("--debug", is_flag=True, help="Enable debug logging")
302264
@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "delete_cluster_stack_cli")
303265
def delete_cluster_stack(stack_name: str, retain_resources: str, region: str, debug: bool) -> None:
@@ -314,6 +276,10 @@ def delete_cluster_stack(stack_name: str, retain_resources: str, region: str, de
314276
# Delete a cluster stack
315277
hyp delete cluster-stack my-stack-name --region us-west-2
316278
279+
# Delete with retained resources (only works on DELETE_FAILED stacks)
280+
hyp delete cluster-stack my-stack-name --retain-resources S3Bucket-TrainingData,EFSFileSystem-Models --region us-west-2
281+
hyp delete cluster-stack my-stack-name --region us-west-2
282+
317283
# Delete with retained resources (only works on DELETE_FAILED stacks)
318284
hyp delete cluster-stack my-stack-name --retain-resources S3Bucket-TrainingData,EFSFileSystem-Models --region us-west-2
319285
"""
@@ -329,7 +295,7 @@ def delete_cluster_stack(stack_name: str, retain_resources: str, region: str, de
329295
confirm_callback=lambda msg: click.confirm("Continue?", default=False),
330296
success_callback=lambda msg: click.echo(f"✓ {msg}")
331297
)
332-
298+
333299
except StackNotFoundError:
334300
click.secho(f"❌ Stack '{stack_name}' not found", fg='red')
335301
except click.ClickException:
@@ -341,6 +307,7 @@ def delete_cluster_stack(stack_name: str, retain_resources: str, region: str, de
341307
logger.exception("Detailed error information:")
342308
raise click.ClickException(str(e))
343309

310+
344311
@click.command("cluster")
345312
@click.option("--cluster-name", required=True, help="The name of the cluster to update")
346313
@click.option("--instance-groups", help="Instance Groups JSON string")

src/sagemaker/hyperpod/cli/commands/init.py

Lines changed: 26 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,8 @@
88
from sagemaker.hyperpod.cli.constants.init_constants import (
99
USAGE_GUIDE_TEXT_CFN,
1010
USAGE_GUIDE_TEXT_CRD,
11-
CFN,
12-
CRD
11+
CFN
1312
)
14-
from sagemaker.hyperpod.common.config import Metadata
15-
from sagemaker.hyperpod.training.hyperpod_pytorch_job import HyperPodPytorchJob
1613
from sagemaker.hyperpod.cluster_management.hp_cluster_stack import HpClusterStack
1714
from sagemaker.hyperpod.cli.init_utils import (
1815
generate_click_command,
@@ -25,8 +22,7 @@
2522
display_validation_results,
2623
build_config_from_schema,
2724
save_template,
28-
get_default_version_for_template,
29-
add_default_az_ids_to_config,
25+
get_default_version_for_template
3026
)
3127
from sagemaker.hyperpod.common.utils import get_aws_default_region
3228

@@ -265,7 +261,7 @@ def validate():
265261

266262

267263
@click.command(name="_default_create")
268-
@click.option("--region", "-r", default=None, help="Region, default to your region in aws configure")
264+
@click.option("--region", "-r", default=None, help="Region to create cluster stack for, default to your region in aws configure. Not available for other templates.")
269265
def _default_create(region):
270266
"""
271267
Validate configuration and render template files for deployment.
@@ -300,6 +296,11 @@ def _default_create(region):
300296
# 1) Load config to determine template type
301297
data, template, version = load_config_and_validate(dir_path)
302298

299+
# Check if region flag is used for non-cluster-stack templates
300+
if region and template != "cluster-stack":
301+
click.secho(f"❌ --region flag is only available for cluster-stack template, not for {template}.", fg="red")
302+
sys.exit(1)
303+
303304
# 2) Determine correct jinja file based on template type
304305
info = TEMPLATES[template]
305306
schema_type = info["schema_type"]
@@ -327,27 +328,7 @@ def _default_create(region):
327328
try:
328329
template_source = jinja_file.read_text()
329330
tpl = Template(template_source)
330-
331-
# For CFN templates, prepare arrays for Jinja template
332-
if schema_type == CFN:
333-
# Prepare instance_group_settings array
334-
instance_group_settings = []
335-
rig_settings = []
336-
for i in range(1, 21):
337-
ig_key = f'instance_group_settings{i}'
338-
rig_key = f'rig_settings{i}'
339-
if ig_key in data:
340-
instance_group_settings.append(data[ig_key])
341-
if rig_key in data:
342-
rig_settings.append(data[rig_key])
343-
344-
# Add arrays to template context
345-
template_data = dict(data)
346-
template_data['instance_group_settings'] = instance_group_settings
347-
template_data['rig_settings'] = rig_settings
348-
rendered = tpl.render(**template_data)
349-
else:
350-
rendered = tpl.render(**data)
331+
rendered = tpl.render(**data)
351332
except Exception as e:
352333
click.secho(f"❌ Failed to render template: {e}", fg="red")
353334
sys.exit(1)
@@ -375,27 +356,26 @@ def _default_create(region):
375356
region = get_aws_default_region()
376357
click.secho(f"Submitting to default region: {region}.", fg="yellow")
377358

378-
if schema_type == CFN:
379-
add_default_az_ids_to_config(out_dir, region)
380-
381-
from sagemaker.hyperpod.cli.commands.cluster_stack import create_cluster_stack_helper
382-
create_cluster_stack_helper(config_file=f"{out_dir}/config.yaml",
383-
region=region)
384-
else:
385-
dir_path = Path(".").resolve()
386-
data, template, version = load_config(dir_path)
387-
namespace = data.get("namespace", "default")
388-
registry = TEMPLATES[template]["registry"]
389-
model = registry.get(str(version))
390-
if model:
391-
# Filter out CLI metadata fields before passing to model
392-
from sagemaker.hyperpod.cli.init_utils import _filter_cli_metadata_fields
393-
filtered_config = _filter_cli_metadata_fields(data)
394-
flat = model(**filtered_config)
359+
# Unified pattern for all templates
360+
dir_path = Path(".").resolve()
361+
data, template, version = load_config(dir_path)
362+
registry = TEMPLATES[template]["registry"]
363+
model = registry.get(str(version))
364+
if model:
365+
# Filter out CLI metadata fields before passing to model
366+
from sagemaker.hyperpod.cli.init_utils import _filter_cli_metadata_fields
367+
filtered_config = _filter_cli_metadata_fields(data)
368+
flat = model(**filtered_config)
369+
370+
# Pass region to to_config for cluster stack template
371+
if template == "cluster-stack":
372+
config = flat.to_config(region=region)
373+
HpClusterStack(**config).create(region)
374+
else:
395375
domain = flat.to_domain()
396376
domain.create()
397377

398378

399379
except Exception as e:
400380
click.secho(f"❌ Failed to submit the command: {e}", fg="red")
401-
sys.exit(1)
381+
sys.exit(1)

0 commit comments

Comments
 (0)