Skip to content

Commit b063853

Browse files
Merge branch 'main' into 404error
2 parents 3bece5a + f747815 commit b063853

File tree

14 files changed

+512
-177
lines changed

14 files changed

+512
-177
lines changed

README.md

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ hyp create hyp-pytorch-job \
171171
--priority "high" \
172172
--max-retry 3 \
173173
--volume name=model-data,type=hostPath,mount_path=/data,path=/data \
174-
--volume name=training-output,type=pvc,mount_path=/data,claim_name=my-pvc,read_only=false
174+
--volume name=training-output,type=pvc,mount_path=/data2,claim_name=my-pvc,read_only=false
175175
```
176176
177177
Key required parameters explained:
@@ -192,7 +192,6 @@ hyp create hyp-jumpstart-endpoint \
192192
--model-id jumpstart-model-id \
193193
--instance-type ml.g5.8xlarge \
194194
--endpoint-name endpoint-jumpstart
195-
--tls-output-s3-uri s3://sample-bucket
196195
```
197196
198197
@@ -219,7 +218,8 @@ hyp create hyp-custom-endpoint \
219218
--endpoint-name my-custom-endpoint \
220219
--model-name my-pytorch-model \
221220
--model-source-type s3 \
222-
--model-location my-pytorch-training/model.tar.gz \
221+
--model-location my-pytorch-training \
222+
--model-volume-mount-name test-volume \
223223
--s3-bucket-name your-bucket \
224224
--s3-region us-east-1 \
225225
--instance-type ml.g5.8xlarge \
@@ -333,20 +333,17 @@ from sagemaker.hyperpod.inference.config.hp_jumpstart_endpoint_config import Mod
333333
from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint
334334

335335
model=Model(
336-
model_id='deepseek-llm-r1-distill-qwen-1-5b',
337-
model_version='2.0.4',
336+
model_id='deepseek-llm-r1-distill-qwen-1-5b'
338337
)
339338
server=Server(
340339
instance_type='ml.g5.8xlarge',
341340
)
342341
endpoint_name=SageMakerEndpoint(name='<my-endpoint-name>')
343-
tls_config=TlsConfig(tls_certificate_output_s3_uri='s3://<my-tls-bucket>')
344342

345343
js_endpoint=HPJumpStartEndpoint(
346344
model=model,
347345
server=server,
348-
sage_maker_endpoint=endpoint_name,
349-
tls_config=tls_config,
346+
sage_maker_endpoint=endpoint_name
350347
)
351348

352349
js_endpoint.create()

doc/inference.md

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ from sagemaker.hyperpod.inference.config.hp_jumpstart_endpoint_config import Mod
3737
from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint
3838
3939
model = Model(
40-
model_id="deepseek-llm-r1-distill-qwen-1-5b",
41-
model_version="2.0.4"
40+
model_id="deepseek-llm-r1-distill-qwen-1-5b"
4241
)
4342
4443
server = Server(
@@ -47,13 +46,10 @@ server = Server(
4746
4847
endpoint_name = SageMakerEndpoint(name="endpoint-jumpstart")
4948
50-
tls_config = TlsConfig(tls_certificate_output_s3_uri="s3://sample-bucket")
51-
5249
js_endpoint = HPJumpStartEndpoint(
5350
model=model,
5451
server=server,
55-
sage_maker_endpoint=endpoint_name,
56-
tls_config=tls_config
52+
sage_maker_endpoint=endpoint_name
5753
)
5854
5955
js_endpoint.create()
@@ -85,7 +81,7 @@ from sagemaker.hyperpod.inference.hp_endpoint import HPEndpoint
8581
8682
model = Model(
8783
model_source_type="s3",
88-
model_location="test-pytorch-job/model.tar.gz",
84+
model_location="test-pytorch-job",
8985
s3_bucket_name="my-bucket",
9086
s3_region="us-east-2",
9187
prefetch_enabled=True

examples/inference/SDK/inference-jumpstart-e2e.ipynb

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,21 +107,18 @@
107107
"source": [
108108
"# create configs\n",
109109
"model=Model(\n",
110-
" model_id='deepseek-llm-r1-distill-qwen-1-5b',\n",
111-
" model_version='2.0.4',\n",
110+
" model_id='deepseek-llm-r1-distill-qwen-1-5b'\n",
112111
")\n",
113112
"server=Server(\n",
114113
" instance_type='ml.g5.8xlarge',\n",
115114
")\n",
116115
"endpoint_name=SageMakerEndpoint(name='<my-endpoint-name>')\n",
117-
"tls_config=TlsConfig(tls_certificate_output_s3_uri='s3://<my-tls-bucket>')\n",
118116
"\n",
119117
"# create spec\n",
120118
"js_endpoint=HPJumpStartEndpoint(\n",
121119
" model=model,\n",
122120
" server=server,\n",
123-
" sage_maker_endpoint=endpoint_name,\n",
124-
" tls_config=tls_config,\n",
121+
" sage_maker_endpoint=endpoint_name\n",
125122
")"
126123
]
127124
},

hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_0/model.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
from pydantic import BaseModel, Field
13+
from pydantic import BaseModel, Field, model_validator, ConfigDict
1414
from typing import Optional, List, Dict, Union, Literal
1515

1616
from sagemaker.hyperpod.inference.config.hp_endpoint_config import (
@@ -31,9 +31,19 @@
3131
from sagemaker.hyperpod.inference.hp_endpoint import HPEndpoint
3232

3333
class FlatHPEndpoint(BaseModel):
34+
model_config = ConfigDict(extra="forbid")
35+
36+
metadata_name: Optional[str] = Field(
37+
None,
38+
alias="metadata_name",
39+
description="Name of the jumpstart endpoint object",
40+
max_length=63,
41+
pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$",
42+
)
43+
3444
# endpoint_name
3545
endpoint_name: Optional[str] = Field(
36-
"",
46+
None,
3747
alias="endpoint_name",
3848
description="Name of SageMaker endpoint; empty string means no creation",
3949
max_length=63,
@@ -130,7 +140,7 @@ class FlatHPEndpoint(BaseModel):
130140
description="FSX File System DNS Name",
131141
)
132142
fsx_file_system_id: Optional[str] = Field(
133-
...,
143+
None,
134144
alias="fsx_file_system_id",
135145
description="FSX File System ID",
136146
)
@@ -142,12 +152,12 @@ class FlatHPEndpoint(BaseModel):
142152

143153
# S3Storage
144154
s3_bucket_name: Optional[str] = Field(
145-
...,
155+
None,
146156
alias="s3_bucket_name",
147157
description="S3 bucket location",
148158
)
149159
s3_region: Optional[str] = Field(
150-
...,
160+
None,
151161
alias="s3_region",
152162
description="S3 bucket region",
153163
)
@@ -229,12 +239,22 @@ class FlatHPEndpoint(BaseModel):
229239
invocation_endpoint: Optional[str] = Field(
230240
default="invocations",
231241
description=(
232-
"The invocation endpoint of the model server. "
233-
"http://<host>:<port>/ would be pre-populated based on the other fields. "
242+
"The invocation endpoint of the model server. http://<host>:<port>/ would be pre-populated based on the other fields. "
234243
"Please fill in the path after http://<host>:<port>/ specific to your model server.",
235244
)
236245
)
237-
246+
247+
@model_validator(mode='after')
248+
def validate_model_source_config(self):
249+
"""Validate that required fields are provided based on model_source_type"""
250+
if self.model_source_type == "s3":
251+
if not self.s3_bucket_name or not self.s3_region:
252+
raise ValueError("s3_bucket_name and s3_region are required when model_source_type is 's3'")
253+
elif self.model_source_type == "fsx":
254+
if not self.fsx_file_system_id:
255+
raise ValueError("fsx_file_system_id is required when model_source_type is 'fsx'")
256+
return self
257+
238258
def to_domain(self) -> HPEndpoint:
239259
env_vars = None
240260
if self.env:

0 commit comments

Comments
 (0)