Skip to content

Commit 79b0342

Browse files
Update jinja template handling logic for inference and training (#279)
* Update jinja template handling logic for inference and training, cluster logic remaining for discussion * test inference and training all parameters * minor change to fix integ * move create_from_k8s_yaml to init_utils and init_constants, reuse create for create_from_dict for inference SDK, update template kind * Fix unit test * update create and create_from_dict for inference to share internal_create, revert unit test changes fix
1 parent c5edf2d commit 79b0342

File tree

19 files changed

+458
-323
lines changed

19 files changed

+458
-323
lines changed

hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_0/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from pydantic import BaseModel, Field, model_validator, ConfigDict
1414
from typing import Optional, List, Dict, Union, Literal
15+
import yaml
1516

1617
from sagemaker.hyperpod.inference.config.hp_endpoint_config import (
1718
Metrics,
@@ -367,4 +368,4 @@ def to_domain(self) -> HPEndpoint:
367368
worker=worker,
368369
invocation_endpoint=self.invocation_endpoint,
369370
auto_scaling_spec=auto_scaling_spec
370-
)
371+
)
Lines changed: 75 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,88 @@
1-
TEMPLATE_CONTENT = """### Please keep template file unchanged ###
2-
apiVersion: hyperpod.sagemaker.aws/v1
3-
kind: HPEndpoint
1+
TEMPLATE_CONTENT = """
2+
apiVersion: inference.sagemaker.aws.amazon.com/v1alpha1
3+
kind: InferenceEndpointConfig
44
metadata:
5-
name: "{{ endpoint_name }}"
6-
namespace: "{{ namespace }}"
5+
name: {{ metadata_name or endpoint_name }}
6+
namespace: {{ namespace }}
77
spec:
8-
instanceType: "{{ instance_type }}"
9-
modelName: "{{ model_name }}"
10-
{% if model_version is not none %} modelVersion: "{{ model_version }}"
11-
{% endif %}
12-
env:
13-
{% if env %}
14-
{% for key, val in env.items() %} - name: "{{ key }}"
15-
value: "{{ val }}"
16-
{% endfor %}{% else %} []
17-
{% endif %}
8+
endpointName: {{ endpoint_name }}
9+
instanceType: {{ instance_type }}
10+
modelName: {{ model_name }}
11+
modelVersion: {{ model_version or "" }}
12+
1813
metrics:
19-
enabled: {{ metrics_enabled }}
14+
enabled: {{ metrics_enabled or False }}
15+
2016
modelSourceConfig:
21-
modelSourceType: "{{ model_source_type }}"
22-
{% if model_location is not none %} modelLocation: "{{ model_location }}"
23-
{% endif %} prefetchEnabled: {{ prefetch_enabled }}
24-
{% if model_source_type == "s3" %} s3Storage:
25-
bucketName: "{{ s3_bucket_name }}"
26-
region: "{{ s3_region }}"
27-
{% elif model_source_type == "fsx" %} fsxStorage:
28-
dnsName: "{{ fsx_dns_name }}"
29-
fileSystemId: "{{ fsx_file_system_id }}"
30-
{% if fsx_mount_name is not none %} mountName: "{{ fsx_mount_name }}"
31-
{% endif %}{% endif %}
17+
modelSourceType: {{ model_source_type }}
18+
modelLocation: {{ model_location or "" }}
19+
prefetchEnabled: {{ prefetch_enabled or False }}
20+
{%- if model_source_type == "s3" %}
21+
s3Storage:
22+
bucketName: {{ s3_bucket_name }}
23+
region: {{ s3_region }}
24+
{%- elif model_source_type == "fsx" %}
25+
fsxStorage:
26+
dnsName: {{ fsx_dns_name }}
27+
fileSystemId: {{ fsx_file_system_id }}
28+
mountName: {{ fsx_mount_name or "" }}
29+
{%- endif %}
30+
3231
tlsConfig:
33-
{% if tls_certificate_output_s3_uri is not none %} certificateOutputS3Uri: "{{ tls_certificate_output_s3_uri }}"
34-
{% else %} {}
35-
{% endif %}
32+
tlsCertificateOutputS3Uri: {{ tls_certificate_output_s3_uri or "" }}
33+
3634
worker:
37-
image: "{{ image_uri }}"
38-
containerPort: {{ container_port }}
39-
volumeMount:
40-
name: "{{ model_volume_mount_name }}"
41-
mountPath: "{{ model_volume_mount_path }}"
35+
environmentVariables:
36+
{%- if env %}
37+
{%- for key, val in env.items() %}
38+
- name: {{ key }}
39+
value: "{{ val }}"
40+
{%- endfor %}
41+
{%- else %}
42+
[]
43+
{%- endif %}
44+
image: {{ image_uri }}
45+
modelInvocationPort:
46+
containerPort: {{ container_port }}
47+
modelVolumeMount:
48+
name: {{ model_volume_mount_name }}
49+
mountPath: {{ model_volume_mount_path }}
4250
resources:
43-
{% if resources_limits %} limits:
44-
{% for key, val in resources_limits.items() %} {{ key }}: "{{ val }}"
45-
{% endfor %}{% else %} {}
46-
{% endif %}{% if resources_requests %}
51+
{%- if resources_limits %}
52+
limits:
53+
{%- for key, val in resources_limits.items() %}
54+
{{ key }}: {{ val }}
55+
{%- endfor %}
56+
{%- else %}
57+
{}
58+
{%- endif %}
59+
{%- if resources_requests %}
4760
requests:
48-
{% for key, val in resources_requests.items() %} {{ key }}: "{{ val }}"
49-
{% endfor %}{% endif %}
61+
{%- for key, val in resources_requests.items() %}
62+
{{ key }}: {{ val }}
63+
{%- endfor %}
64+
{%- endif %}
65+
5066
autoScalingSpec:
5167
cloudWatchTrigger:
52-
{% if dimensions %} dimensions:
53-
{% for dim_key, dim_val in dimensions.items() %} - name: "{{ dim_key }}"
54-
value: "{{ dim_val }}"
55-
{% endfor %}{% else %} []
56-
{% endif %} metricCollectionPeriod: {{ metric_collection_period }}
68+
{%- if dimensions %}
69+
dimensions:
70+
{%- for dim_key, dim_val in dimensions.items() %}
71+
- name: {{ dim_key }}
72+
value: {{ dim_val }}
73+
{%- endfor %}
74+
{%- endif %}
75+
metricCollectionPeriod: {{ metric_collection_period }}
5776
metricCollectionStartTime: {{ metric_collection_start_time }}
58-
metricName: "{{ metric_name }}"
59-
metricStat: "{{ metric_stat }}"
60-
type: "{{ metric_type }}"
61-
minValue: {{ min_value }}
62-
name: "{{ cloud_watch_trigger_name }}"
63-
namespace: "{{ cloud_watch_trigger_namespace }}"
64-
targetValue: {{ target_value }}
65-
useCachedMetrics: {{ use_cached_metrics }}
66-
invocationEndpoint: "{{ invocation_endpoint }}"
77+
metricName: {{ metric_name or "" }}
78+
metricStat: {{ metric_stat }}
79+
metricType: {{ metric_type }}
80+
minValue: {{ min_value }}
81+
name: {{ cloud_watch_trigger_name or "" }}
82+
namespace: {{ cloud_watch_trigger_namespace or "" }}
83+
targetValue: {{ target_value or "" }}
84+
useCachedMetrics: {{ use_cached_metrics or False }}
85+
86+
invocationEndpoint: {{ invocation_endpoint }}
6787
6888
"""

hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_0/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from pydantic import BaseModel, Field, model_validator, ConfigDict
1414
from typing import Optional
15+
import yaml
1516

1617
# reuse the nested types
1718
from sagemaker.hyperpod.inference.config.hp_jumpstart_endpoint_config import (
@@ -114,4 +115,4 @@ def to_domain(self) -> HPJumpStartEndpoint:
114115
server=server,
115116
sage_maker_endpoint=sage_ep,
116117
tls_config=tls
117-
)
118+
)

hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_0/template.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
TEMPLATE_CONTENT = """### Please keep template file unchanged ###
1+
TEMPLATE_CONTENT = """
22
apiVersion: inference.sagemaker.aws.amazon.com/v1alpha1
33
kind: JumpStartModel
44
metadata:
5-
name: {{ model_id }}
5+
name: {{ metadata_name or endpoint_name }}
66
namespace: {{ namespace or "default" }}
77
spec:
88
model:
@@ -14,4 +14,6 @@
1414
name: {{ endpoint_name or "" }}
1515
server:
1616
instanceType: {{ instance_type }}
17+
tlsConfig:
18+
tlsCertificateOutputS3Uri: {{ tls_certificate_output_s3_uri or "" }}
1719
"""

hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_0/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
PersistentVolumeClaim
1515
)
1616
from sagemaker.hyperpod.training.hyperpod_pytorch_job import HyperPodPytorchJob
17-
17+
import yaml
1818

1919
class VolumeConfig(BaseModel):
2020
model_config = ConfigDict(extra="forbid")
@@ -39,7 +39,7 @@ class VolumeConfig(BaseModel):
3939
description="PVC claim name (required for pvc volumes)",
4040
min_length=1
4141
)
42-
read_only: Optional[Literal['true', 'false']] = Field(None, description="Read-only flag for pvc volumes")
42+
read_only: Optional[bool] = Field(None, description="Read-only flag for pvc volumes")
4343

4444
@field_validator('mount_path', 'path')
4545
@classmethod
@@ -260,7 +260,7 @@ def build_dict(**kwargs):
260260
elif vol.type == "pvc":
261261
volume_obj = Volumes(name=vol.name, persistent_volume_claim=PersistentVolumeClaim(
262262
claim_name=vol.claim_name,
263-
read_only=vol.read_only == "true" if vol.read_only else False
263+
read_only=vol.read_only if vol.read_only is not None else False
264264
))
265265
volumes.append(volume_obj)
266266

@@ -310,6 +310,7 @@ def build_dict(**kwargs):
310310
result = HyperPodPytorchJob(**job_kwargs)
311311
return result
312312

313+
313314
# Volume-specific type handlers - only override what's needed
314315
def volume_parse_strings(ctx_or_strings, param=None, value=None):
315316
"""Parse volume strings into VolumeConfig objects. Can be used as Click callback."""

hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_0/schema.json

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,7 @@
5454
"read_only": {
5555
"anyOf": [
5656
{
57-
"enum": [
58-
"true",
59-
"false"
60-
],
61-
"type": "string"
57+
"type": "boolean"
6258
},
6359
{
6460
"type": "null"
Lines changed: 84 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,96 @@
1-
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License"). You
4-
# may not use this file except in compliance with the License. A copy of
5-
# the License is located at
6-
#
7-
# http://aws.amazon.com/apache2.0/
8-
#
9-
# or in the "license" file accompanying this file. This file is
10-
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11-
# ANY KIND, either express or implied. See the License for the specific
12-
# language governing permissions and limitations under the License.
13-
TEMPLATE_CONTENT = """### Please keep template file unchanged ###
1+
TEMPLATE_CONTENT = """
142
apiVersion: sagemaker.amazonaws.com/v1
153
kind: HyperPodPyTorchJob
164
metadata:
17-
name: "{{ job_name }}"
18-
namespace: "{{ namespace }}"
19-
{% if queue_name or priority %} labels:
20-
{% if queue_name %} kueue.x-k8s.io/queue-name: "{{ queue_name }}"
21-
{% endif %}{% if priority %} kueue.x-k8s.io/priority-class: "{{ priority }}"
22-
{% endif %}{% endif %}spec:
23-
{% if tasks_per_node %} nprocPerNode: "{{ tasks_per_node }}"
24-
{% endif %} replicaSpecs:
25-
- name: "pod"
26-
{% if node_count %} replicas: {{ node_count }}
27-
{% endif %} template:
5+
name: {{ job_name }}
6+
namespace: {{ namespace }}
7+
{%- if queue_name or priority %}
8+
labels:
9+
kueue.x-k8s.io/queue-name: {{ queue_name or "" }}
10+
kueue.x-k8s.io/priority-class: {{ priority or "" }}
11+
{%- endif %}
12+
spec:
13+
{%- if tasks_per_node %}
14+
nprocPerNode: "{{ tasks_per_node }}"
15+
{%- endif %}
16+
replicaSpecs:
17+
- name: pod
18+
replicas: {{ node_count or 1 }}
19+
template:
2820
metadata:
29-
name: "{{ job_name }}"
30-
{% if namespace %} namespace: "{{ namespace }}"
31-
{% endif %}{% if queue_name or priority %} labels:
32-
{% if queue_name %} kueue.x-k8s.io/queue-name: "{{ queue_name }}"
33-
{% endif %}{% if priority %} kueue.x-k8s.io/priority-class: "{{ priority }}"
34-
{% endif %}{% endif %} spec:
21+
name: {{ job_name }}
22+
namespace: {{ namespace }}
23+
{%- if queue_name or priority %}
24+
labels:
25+
kueue.x-k8s.io/queue-name: {{ queue_name or "" }}
26+
kueue.x-k8s.io/priority-class: {{ priority or "" }}
27+
{%- endif %}
28+
spec:
3529
containers:
36-
- name: "container-name"
37-
image: "{{ image }}"
38-
{% if pull_policy %} imagePullPolicy: "{{ pull_policy }}"
39-
{% endif %}{% if command %} command: {{ command | tojson }}
40-
{% endif %}{% if args %} args: {{ args | tojson }}
41-
{% endif %}{% if environment %} env:
42-
{% for key, value in environment.items() %} - name: "{{ key }}"
30+
- name: container-name
31+
image: {{ image }}
32+
{%- if pull_policy %}
33+
imagePullPolicy: {{ pull_policy }}
34+
{%- endif %}
35+
{%- if command %}
36+
command: {{ command | tojson }}
37+
{%- endif %}
38+
{%- if args %}
39+
args: {{ args | tojson }}
40+
{%- endif %}
41+
{%- if environment %}
42+
env:
43+
{%- for key, value in environment.items() %}
44+
- name: {{ key }}
4345
value: "{{ value }}"
44-
{% endfor %}{% endif %}{% if volume %} volumeMounts:
45-
{% for vol in volume %} - name: "{{ vol.name }}"
46-
mountPath: "{{ vol.mount_path }}"
47-
{% if vol.read_only is not none and vol.read_only != "" %} readOnly: {{ vol.read_only | lower }}
48-
{% endif %}{% endfor %}{% endif %} resources:
46+
{%- endfor %}
47+
{%- endif %}
48+
{%- if volume %}
49+
volumeMounts:
50+
{%- for vol in volume %}
51+
- name: {{ vol.name }}
52+
mountPath: {{ vol.mount_path }}
53+
readOnly: {{ vol.read_only | lower if vol.read_only else false }}
54+
{%- endfor %}
55+
{%- endif %}
56+
resources:
4957
requests:
5058
nvidia.com/gpu: "0"
5159
limits:
5260
nvidia.com/gpu: "0"
53-
{% if instance_type or label_selector or deep_health_check_passed_nodes_only %} nodeSelector:
54-
{% if instance_type %} node.kubernetes.io/instance-type: "{{ instance_type }}"
55-
{% endif %}{% if label_selector %}{% for key, value in label_selector.items() %} {{ key }}: "{{ value }}"
56-
{% endfor %}{% endif %}{% if deep_health_check_passed_nodes_only %} deep-health-check-passed: "true"
57-
{% endif %}{% endif %}{% if service_account_name %} serviceAccountName: "{{ service_account_name }}"
58-
{% endif %}{% if scheduler_type %} schedulerName: "{{ scheduler_type }}"
59-
{% endif %}{% if volume %} volumes:
60-
{% for vol in volume %} - name: "{{ vol.name }}"
61-
{% if vol.type == "hostPath" %} hostPath:
62-
path: "{{ vol.path }}"
63-
{% elif vol.type == "pvc" %} persistentVolumeClaim:
64-
claimName: "{{ vol.claim_name }}"
65-
{% endif %}{% endfor %}{% endif %}{% if max_retry %} runPolicy:
61+
{%- if instance_type or label_selector or deep_health_check_passed_nodes_only %}
62+
nodeSelector:
63+
node.kubernetes.io/instance-type: {{ instance_type or "" }}
64+
{%- if label_selector %}
65+
{%- for key, value in label_selector.items() %}
66+
{{ key }}: {{ value }}
67+
{%- endfor %}
68+
{%- endif %}
69+
{%- if deep_health_check_passed_nodes_only %}
70+
deep-health-check-passed: "true"
71+
{%- endif %}
72+
{%- endif %}
73+
{%- if service_account_name %}
74+
serviceAccountName: {{ service_account_name }}
75+
{%- endif %}
76+
{%- if scheduler_type %}
77+
schedulerName: {{ scheduler_type }}
78+
{%- endif %}
79+
{%- if volume %}
80+
volumes:
81+
{%- for vol in volume %}
82+
- name: {{ vol.name }}
83+
{%- if vol.type == "hostPath" %}
84+
hostPath:
85+
path: {{ vol.path }}
86+
{%- elif vol.type == "pvc" %}
87+
persistentVolumeClaim:
88+
claimName: {{ vol.claim_name }}
89+
{%- endif %}
90+
{%- endfor %}
91+
{%- endif %}
92+
{%- if max_retry %}
93+
runPolicy:
6694
cleanPodPolicy: "None"
6795
jobMaxRetryCount: {{ max_retry }}
68-
{% endif %}"""
96+
{%- endif %}"""

0 commit comments

Comments
 (0)