Skip to content

Commit 52f2ac7

Browse files
fix: update inference template.py to reflect init experience template submission changes
1 parent 906ac34 commit 52f2ac7

File tree

1 file changed

+75
-55
lines changed
  • hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_1

1 file changed

+75
-55
lines changed
Lines changed: 75 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,90 @@
1-
TEMPLATE_CONTENT = """### Please keep template file unchanged ###
1+
TEMPLATE_CONTENT = """
22
apiVersion: hyperpod.sagemaker.aws/v1
3-
kind: HPEndpoint
3+
kind: InferenceEndpointConfig
44
metadata:
5-
name: "{{ endpoint_name }}"
6-
namespace: "{{ namespace }}"
5+
name: {{ metadata_name or endpoint_name }}
6+
namespace: {{ namespace }}
77
spec:
8-
instanceType: "{{ instance_type }}"
9-
modelName: "{{ model_name }}"
10-
{% if model_version is not none %} modelVersion: "{{ model_version }}"
11-
{% endif %}
12-
env:
13-
{% if env %}
14-
{% for key, val in env.items() %} - name: "{{ key }}"
15-
value: "{{ val }}"
16-
{% endfor %}{% else %} []
17-
{% endif %}
8+
endpointName: {{ endpoint_name }}
9+
instanceType: {{ instance_type }}
10+
modelName: {{ model_name }}
11+
modelVersion: {{ model_version or "" }}
12+
1813
metrics:
19-
enabled: {{ metrics_enabled }}
14+
enabled: {{ metrics_enabled or False }}
15+
2016
modelSourceConfig:
21-
modelSourceType: "{{ model_source_type }}"
22-
{% if model_location is not none %} modelLocation: "{{ model_location }}"
23-
{% endif %} prefetchEnabled: {{ prefetch_enabled }}
24-
{% if model_source_type == "s3" %} s3Storage:
25-
bucketName: "{{ s3_bucket_name }}"
26-
region: "{{ s3_region }}"
27-
{% elif model_source_type == "fsx" %} fsxStorage:
28-
dnsName: "{{ fsx_dns_name }}"
29-
fileSystemId: "{{ fsx_file_system_id }}"
30-
{% if fsx_mount_name is not none %} mountName: "{{ fsx_mount_name }}"
31-
{% endif %}{% endif %}
17+
modelSourceType: {{ model_source_type }}
18+
modelLocation: {{ model_location or "" }}
19+
prefetchEnabled: {{ prefetch_enabled or False }}
20+
{%- if model_source_type == "s3" %}
21+
s3Storage:
22+
bucketName: {{ s3_bucket_name }}
23+
region: {{ s3_region }}
24+
{%- elif model_source_type == "fsx" %}
25+
fsxStorage:
26+
dnsName: {{ fsx_dns_name }}
27+
fileSystemId: {{ fsx_file_system_id }}
28+
mountName: {{ fsx_mount_name or "" }}
29+
{%- endif %}
30+
3231
tlsConfig:
33-
{% if tls_certificate_output_s3_uri is not none %} certificateOutputS3Uri: "{{ tls_certificate_output_s3_uri }}"
34-
{% else %} {}
35-
{% endif %}
32+
tlsCertificateOutputS3Uri: {{ tls_certificate_output_s3_uri or "" }}
33+
3634
worker:
37-
image: "{{ image_uri }}"
38-
containerPort: {{ container_port }}
39-
volumeMount:
40-
name: "{{ model_volume_mount_name }}"
41-
mountPath: "{{ model_volume_mount_path }}"
35+
environmentVariables:
36+
{%- if env %}
37+
{%- for key, val in env.items() %}
38+
- name: {{ key }}
39+
value: "{{ val }}"
40+
{%- endfor %}
41+
{%- else %}
42+
[]
43+
{%- endif %}
44+
image: {{ image_uri }}
45+
modelInvocationPort:
46+
containerPort: {{ container_port }}
47+
modelVolumeMount:
48+
name: {{ model_volume_mount_name }}
49+
mountPath: {{ model_volume_mount_path }}
4250
resources:
43-
{% if resources_limits %} limits:
44-
{% for key, val in resources_limits.items() %} {{ key }}: "{{ val }}"
45-
{% endfor %}{% else %} {}
46-
{% endif %}{% if resources_requests %}
51+
{%- if resources_limits %}
52+
limits:
53+
{%- for key, val in resources_limits.items() %}
54+
{{ key }}: {{ val }}
55+
{%- endfor %}
56+
{%- else %}
57+
{}
58+
{%- endif %}
59+
{%- if resources_requests %}
4760
requests:
48-
{% for key, val in resources_requests.items() %} {{ key }}: "{{ val }}"
49-
{% endfor %}{% endif %}
61+
{%- for key, val in resources_requests.items() %}
62+
{{ key }}: {{ val }}
63+
{%- endfor %}
64+
{%- endif %}
65+
5066
autoScalingSpec:
5167
cloudWatchTrigger:
52-
{% if dimensions %} dimensions:
53-
{% for dim_key, dim_val in dimensions.items() %} - name: "{{ dim_key }}"
54-
value: "{{ dim_val }}"
55-
{% endfor %}{% else %} []
56-
{% endif %} metricCollectionPeriod: {{ metric_collection_period }}
68+
{%- if dimensions %}
69+
dimensions:
70+
{%- for dim_key, dim_val in dimensions.items() %}
71+
- name: {{ dim_key }}
72+
value: {{ dim_val }}
73+
{%- endfor %}
74+
{%- endif %}
75+
metricCollectionPeriod: {{ metric_collection_period }}
5776
metricCollectionStartTime: {{ metric_collection_start_time }}
58-
metricName: "{{ metric_name }}"
59-
metricStat: "{{ metric_stat }}"
60-
type: "{{ metric_type }}"
61-
minValue: {{ min_value }}
62-
name: "{{ cloud_watch_trigger_name }}"
63-
namespace: "{{ cloud_watch_trigger_namespace }}"
64-
targetValue: {{ target_value }}
65-
useCachedMetrics: {{ use_cached_metrics }}
77+
metricName: {{ metric_name or "" }}
78+
metricStat: {{ metric_stat }}
79+
metricType: {{ metric_type }}
80+
minValue: {{ min_value }}
81+
name: {{ cloud_watch_trigger_name or "" }}
82+
namespace: {{ cloud_watch_trigger_namespace or "" }}
83+
targetValue: {{ target_value or "" }}
84+
useCachedMetrics: {{ use_cached_metrics or False }}
85+
6686
invocationEndpoint: "{{ invocation_endpoint }}"
87+
6788
{% if intelligent_routing_enabled is not none %} intelligentRoutingSpec:
6889
enabled: {{ intelligent_routing_enabled }}
6990
{% if routing_strategy is not none %} routingStrategy: "{{ routing_strategy }}"{% endif %}{% endif %}
@@ -76,5 +97,4 @@
7697
{% endif %}
7798
{% if cache_config_file is not none %} cacheConfigFile: "{{ cache_config_file }}"{% endif %}
7899
{% endif %}
79-
80-
"""
100+
"""

0 commit comments

Comments
 (0)