Skip to content

Commit 72c14f7

Browse files
mollyheamazonrsareddy0329
authored andcommitted
Jumpstart and custom inference template agnostic change (#244)
* return SDK class in pytorch model.py for v1_0 and v1_1, update pytorch_create function, update unit test * remove name and namespace from create for inference SDK to match with training SDK, functionality remains the same * fix unit test, add metadata class usage to example notebook, remove skip test * fix unit test again * update integ tests * update create call
1 parent 757d4ec commit 72c14f7

File tree

22 files changed

+239
-98
lines changed

22 files changed

+239
-98
lines changed

doc/cli/inference/cli_inference.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,14 @@ hyp create hyp-jumpstart-endpoint [OPTIONS]
4545
|-----------|------|----------|-------------|
4646
| `--model-id` | TEXT | Yes | JumpStart model identifier (1-63 characters, alphanumeric with hyphens) |
4747
| `--instance-type` | TEXT | Yes | EC2 instance type for inference (must start with "ml.") |
48+
| `--namespace` | TEXT | No | Kubernetes namespace |
49+
| `--metadata-name` | TEXT | No | Name of the jumpstart endpoint object |
4850
| `--accept-eula` | BOOLEAN | No | Whether model terms of use have been accepted (default: false) |
4951
| `--model-version` | TEXT | No | Semantic version of the model (e.g., "1.0.0", 5-14 characters) |
5052
| `--endpoint-name` | TEXT | No | Name of SageMaker endpoint (1-63 characters, alphanumeric with hyphens) |
5153
| `--tls-certificate-output-s3-uri` | TEXT | No | S3 URI to write the TLS certificate (optional) |
54+
| `--debug` | FLAG | No | Enable debug mode (default: false) |
55+
5256

5357
### hyp create hyp-custom-endpoint
5458

@@ -70,6 +74,8 @@ hyp create hyp-custom-endpoint [OPTIONS]
7074
| `--image-uri` | TEXT | Yes | Docker image URI for inference |
7175
| `--container-port` | INTEGER | Yes | Port on which model server listens (1-65535) |
7276
| `--model-volume-mount-name` | TEXT | Yes | Name of the model volume mount |
77+
| `--namespace` | TEXT | No | Kubernetes namespace |
78+
| `--metadata-name` | TEXT | No | Name of the custom endpoint object |
7379
| `--endpoint-name` | TEXT | No | Name of SageMaker endpoint (1-63 characters, alphanumeric with hyphens) |
7480
| `--env` | OBJECT | No | Environment variables as key-value pairs |
7581
| `--metrics-enabled` | BOOLEAN | No | Enable metrics collection (default: false) |
@@ -97,6 +103,8 @@ hyp create hyp-custom-endpoint [OPTIONS]
97103
| `--target-value` | NUMBER | No | Target value for the CloudWatch metric |
98104
| `--use-cached-metrics` | BOOLEAN | No | Enable caching of metric values (default: true) |
99105
| `--invocation-endpoint` | TEXT | No | Invocation endpoint path (default: "invocations") |
106+
| `--debug` | FLAG | No | Enable debug mode (default: false) |
107+
100108

101109
## Inference Endpoint Management Commands
102110

examples/inference/SDK/inference-fsx-model-e2e.ipynb

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"source": [
3232
"from sagemaker.hyperpod.inference.config.hp_endpoint_config import FsxStorage, ModelSourceConfig, TlsConfig, EnvironmentVariables, ModelInvocationPort, ModelVolumeMount, Resources, Worker\n",
3333
"from sagemaker.hyperpod.inference.hp_endpoint import HPEndpoint\n",
34+
"from sagemaker.hyperpod.common.config.metadata import Metadata\n",
3435
"import yaml\n",
3536
"import time"
3637
]
@@ -42,6 +43,10 @@
4243
"metadata": {},
4344
"outputs": [],
4445
"source": [
46+
"# If you don't set metadata name, it will be default to endpoint name\n",
47+
"# If you don't set namespace, it will be default to \"default\"\n",
48+
"metadata=Metadata(name='<metadata_name>', namespace='<namespace>')\n",
49+
"\n",
4550
"tls_config=TlsConfig(tls_certificate_output_s3_uri='s3://<my-tls-bucket-name>')\n",
4651
"\n",
4752
"model_source_config = ModelSourceConfig(\n",
@@ -82,6 +87,7 @@
8287
"outputs": [],
8388
"source": [
8489
"fsx_endpoint = HPEndpoint(\n",
90+
" metadata=metadata,\n",
8591
" endpoint_name='<my-endpoint-name>',\n",
8692
" instance_type='ml.g5.8xlarge',\n",
8793
" model_name='deepseek15b-fsx-test-pysdk',\n",

examples/inference/SDK/inference-jumpstart-e2e.ipynb

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
"source": [
8787
"from sagemaker.hyperpod.inference.config.hp_jumpstart_endpoint_config import Model, Server,SageMakerEndpoint, TlsConfig, EnvironmentVariables\n",
8888
"from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint\n",
89+
"from sagemaker.hyperpod.common.config.metadata import Metadata\n",
8990
"import yaml\n",
9091
"import time"
9192
]
@@ -105,6 +106,10 @@
105106
"metadata": {},
106107
"outputs": [],
107108
"source": [
109+
"# If you don't set metadata name, it will be default to endpoint name\n",
110+
"# If you don't set namespace, it will be default to \"default\"\n",
111+
"metadata=Metadata(name='<metadata_name>', namespace='<namespace>')\n",
112+
"\n",
108113
"# create configs\n",
109114
"model=Model(\n",
110115
" model_id='deepseek-llm-r1-distill-qwen-1-5b'\n",
@@ -116,6 +121,7 @@
116121
"\n",
117122
"# create spec\n",
118123
"js_endpoint=HPJumpStartEndpoint(\n",
124+
" metadata=metadata,\n",
119125
" model=model,\n",
120126
" server=server,\n",
121127
" sage_maker_endpoint=endpoint_name\n",

examples/inference/SDK/inference-s3-model-e2e.ipynb

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"source": [
3232
"from sagemaker.hyperpod.inference.config.hp_endpoint_config import CloudWatchTrigger, Dimensions, AutoScalingSpec, Metrics, S3Storage, ModelSourceConfig, TlsConfig, EnvironmentVariables, ModelInvocationPort, ModelVolumeMount, Resources, Worker\n",
3333
"from sagemaker.hyperpod.inference.hp_endpoint import HPEndpoint\n",
34+
"from sagemaker.hyperpod.common.config.metadata import Metadata \n",
3435
"import yaml\n",
3536
"import time"
3637
]
@@ -42,6 +43,10 @@
4243
"metadata": {},
4344
"outputs": [],
4445
"source": [
46+
"# If you don't set metadata name, it will be default to endpoint name\n",
47+
"# If you don't set namespace, it will be default to \"default\"\n",
48+
"metadata=Metadata(name='<metadata_name>', namespace='<namespace>')\n",
49+
"\n",
4550
"tls_config=TlsConfig(tls_certificate_output_s3_uri='s3://<my-tls-bucket-name>')\n",
4651
"\n",
4752
"model_source_config = ModelSourceConfig(\n",
@@ -83,6 +88,7 @@
8388
"outputs": [],
8489
"source": [
8590
"s3_endpoint = HPEndpoint(\n",
91+
" metadata=metadata,\n",
8692
" endpoint_name='<my-endpoint-name>',\n",
8793
" instance_type='ml.g5.8xlarge',\n",
8894
" model_name='deepseek15b-test-model-name', \n",

hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_0/model.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,22 @@
2929
CloudWatchTrigger
3030
)
3131
from sagemaker.hyperpod.inference.hp_endpoint import HPEndpoint
32+
from sagemaker.hyperpod.common.config.metadata import Metadata
33+
3234

3335
class FlatHPEndpoint(BaseModel):
3436
model_config = ConfigDict(extra="forbid")
3537

38+
namespace: Optional[str] = Field(
39+
default="default",
40+
description="Kubernetes namespace",
41+
min_length=1
42+
)
43+
3644
metadata_name: Optional[str] = Field(
3745
None,
3846
alias="metadata_name",
39-
description="Name of the jumpstart endpoint object",
47+
description="Name of the custom endpoint object",
4048
max_length=63,
4149
pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$",
4250
)
@@ -255,7 +263,18 @@ def validate_model_source_config(self):
255263
raise ValueError("fsx_file_system_id is required when model_source_type is 'fsx'")
256264
return self
257265

266+
@model_validator(mode='after')
267+
def validate_name(self):
268+
if not self.metadata_name and not self.endpoint_name:
269+
raise ValueError("Either metadata_name or endpoint_name must be provided")
270+
return self
271+
258272
def to_domain(self) -> HPEndpoint:
273+
if self.endpoint_name and not self.metadata_name:
274+
self.metadata_name = self.endpoint_name
275+
276+
metadata = Metadata(name=self.metadata_name, namespace=self.namespace)
277+
259278
env_vars = None
260279
if self.env:
261280
env_vars = [
@@ -337,6 +356,7 @@ def to_domain(self) -> HPEndpoint:
337356
resources=resources,
338357
)
339358
return HPEndpoint(
359+
metadata=metadata,
340360
endpoint_name=self.endpoint_name,
341361
instance_type=self.instance_type,
342362
metrics=metrics,

hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_0/schema.json

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,20 @@
11
{
22
"additionalProperties": false,
33
"properties": {
4+
"namespace": {
5+
"anyOf": [
6+
{
7+
"minLength": 1,
8+
"type": "string"
9+
},
10+
{
11+
"type": "null"
12+
}
13+
],
14+
"default": "default",
15+
"description": "Kubernetes namespace",
16+
"title": "Namespace"
17+
},
418
"metadata_name": {
519
"anyOf": [
620
{
@@ -13,7 +27,7 @@
1327
}
1428
],
1529
"default": null,
16-
"description": "Name of the jumpstart endpoint object",
30+
"description": "Name of the custom endpoint object",
1731
"title": "Metadata Name"
1832
},
1933
"endpoint_name": {

hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_0/model.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,17 @@
2121
TlsConfig
2222
)
2323
from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint
24+
from sagemaker.hyperpod.common.config.metadata import Metadata
2425

2526
class FlatHPJumpStartEndpoint(BaseModel):
2627
model_config = ConfigDict(extra="forbid")
2728

29+
namespace: Optional[str] = Field(
30+
default="default",
31+
description="Kubernetes namespace",
32+
min_length=1
33+
)
34+
2835
accept_eula: bool = Field(
2936
False, alias="accept_eula", description="Whether model terms of use have been accepted"
3037
)
@@ -76,8 +83,18 @@ class FlatHPJumpStartEndpoint(BaseModel):
7683
pattern=r"^s3://([^/]+)/?(.*)$",
7784
)
7885

86+
@model_validator(mode='after')
87+
def validate_name(self):
88+
if not self.metadata_name and not self.endpoint_name:
89+
raise ValueError("Either metadata_name or endpoint_name must be provided")
90+
91+
7992
def to_domain(self) -> HPJumpStartEndpoint:
80-
# Build nested domain (pydantic) objects
93+
if self.endpoint_name and not self.metadata_name:
94+
self.metadata_name = self.endpoint_name
95+
96+
metadata = Metadata(name=self.metadata_name, namespace=self.namespace)
97+
8198
model = Model(
8299
accept_eula=self.accept_eula,
83100
model_id=self.model_id,
@@ -91,6 +108,7 @@ def to_domain(self) -> HPJumpStartEndpoint:
91108
TlsConfig(tls_certificate_output_s3_uri=self.tls_certificate_output_s3_uri)
92109
)
93110
return HPJumpStartEndpoint(
111+
metadata=metadata,
94112
model=model,
95113
server=server,
96114
sage_maker_endpoint=sage_ep,

hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_0/schema.json

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,20 @@
11
{
22
"additionalProperties": false,
33
"properties": {
4+
"namespace": {
5+
"anyOf": [
6+
{
7+
"minLength": 1,
8+
"type": "string"
9+
},
10+
{
11+
"type": "null"
12+
}
13+
],
14+
"default": "default",
15+
"description": "Kubernetes namespace",
16+
"title": "Namespace"
17+
},
418
"accept_eula": {
519
"default": false,
620
"description": "Whether model terms of use have been accepted",

src/sagemaker/hyperpod/cli/commands/inference.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,50 +20,38 @@
2020

2121
# CREATE
2222
@click.command("hyp-jumpstart-endpoint")
23-
@click.option(
24-
"--namespace",
25-
type=click.STRING,
26-
required=False,
27-
default="default",
28-
help="Optional. The namespace of the jumpstart model endpoint to create. Default set to 'default'",
29-
)
3023
@click.option("--version", default="1.0", help="Schema version to use")
24+
@click.option("--debug", default=False, help="Enable debug mode")
3125
@generate_click_command(
3226
schema_pkg="hyperpod_jumpstart_inference_template",
3327
registry=JS_REG,
3428
)
3529
@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "create_js_endpoint_cli")
3630
@handle_cli_exceptions()
37-
def js_create(name, namespace, version, js_endpoint):
31+
def js_create(version, debug, js_endpoint):
3832
"""
3933
Create a jumpstart model endpoint.
4034
"""
41-
42-
js_endpoint.create(name=name, namespace=namespace)
35+
click.echo(f"Using version: {version}")
36+
js_endpoint.create(debug=debug)
4337

4438

4539
@click.command("hyp-custom-endpoint")
46-
@click.option(
47-
"--namespace",
48-
type=click.STRING,
49-
required=False,
50-
default="default",
51-
help="Optional. The namespace of the jumpstart model endpoint to create. Default set to 'default'",
52-
)
5340
@click.option("--version", default="1.0", help="Schema version to use")
41+
@click.option("--debug", default=False, help="Enable debug mode")
5442
@generate_click_command(
5543
schema_pkg="hyperpod_custom_inference_template",
5644
registry=C_REG,
5745
)
5846
@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "create_custom_endpoint_cli")
5947
@handle_cli_exceptions()
60-
def custom_create(name, namespace, version, custom_endpoint):
48+
def custom_create(version, debug, custom_endpoint):
6149
"""
6250
Create a custom model endpoint.
6351
"""
64-
65-
custom_endpoint.create(name=name, namespace=namespace)
66-
52+
click.echo(f"Using version: {version}")
53+
custom_endpoint.create(debug=debug)
54+
6755

6856
# INVOKE
6957
@click.command("hyp-custom-endpoint")

src/sagemaker/hyperpod/cli/commands/init.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -393,11 +393,7 @@ def _default_create(region):
393393
filtered_config = _filter_cli_metadata_fields(data)
394394
flat = model(**filtered_config)
395395
domain = flat.to_domain()
396-
# TODO: update inference SDK to include name and namespace in the call
397-
if template == "hyp-custom-endpoint" or template == "hyp-jumpstart-endpoint":
398-
domain.create(namespace=namespace)
399-
elif template == "hyp-pytorch-job":
400-
domain.create()
396+
domain.create()
401397

402398

403399
except Exception as e:

0 commit comments

Comments
 (0)