Skip to content

Commit 312759c

Browse files
committed
Add metadata_name argument to the JumpStart (js) and custom endpoint templates to match the SDK
1 parent 7fda684 commit 312759c

File tree

10 files changed

+500
-154
lines changed

10 files changed

+500
-154
lines changed

hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_0/model.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
from pydantic import BaseModel, Field
13+
from pydantic import BaseModel, Field, model_validator, ConfigDict
1414
from typing import Optional, List, Dict, Union, Literal
1515

1616
from sagemaker.hyperpod.inference.config.hp_endpoint_config import (
@@ -31,9 +31,19 @@
3131
from sagemaker.hyperpod.inference.hp_endpoint import HPEndpoint
3232

3333
class FlatHPEndpoint(BaseModel):
34+
model_config = ConfigDict(extra="forbid")
35+
36+
metadata_name: Optional[str] = Field(
37+
None,
38+
alias="metadata_name",
39+
description="Name of the jumpstart endpoint object",
40+
max_length=63,
41+
pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$",
42+
)
43+
3444
# endpoint_name
3545
endpoint_name: Optional[str] = Field(
36-
"",
46+
None,
3747
alias="endpoint_name",
3848
description="Name of SageMaker endpoint; empty string means no creation",
3949
max_length=63,
@@ -130,7 +140,7 @@ class FlatHPEndpoint(BaseModel):
130140
description="FSX File System DNS Name",
131141
)
132142
fsx_file_system_id: Optional[str] = Field(
133-
...,
143+
None,
134144
alias="fsx_file_system_id",
135145
description="FSX File System ID",
136146
)
@@ -142,12 +152,12 @@ class FlatHPEndpoint(BaseModel):
142152

143153
# S3Storage
144154
s3_bucket_name: Optional[str] = Field(
145-
...,
155+
None,
146156
alias="s3_bucket_name",
147157
description="S3 bucket location",
148158
)
149159
s3_region: Optional[str] = Field(
150-
...,
160+
None,
151161
alias="s3_region",
152162
description="S3 bucket region",
153163
)
@@ -229,12 +239,22 @@ class FlatHPEndpoint(BaseModel):
229239
invocation_endpoint: Optional[str] = Field(
230240
default="invocations",
231241
description=(
232-
"The invocation endpoint of the model server. "
233-
"http://<host>:<port>/ would be pre-populated based on the other fields. "
242+
"The invocation endpoint of the model server. http://<host>:<port>/ would be pre-populated based on the other fields. "
234243
"Please fill in the path after http://<host>:<port>/ specific to your model server.",
235244
)
236245
)
237-
246+
247+
@model_validator(mode='after')
def validate_model_source_config(self):
    """Check that the fields needed by the selected model source are set.

    Runs as an ``after`` model validator, i.e. on the constructed instance.
    When ``model_source_type`` is ``"s3"``, both ``s3_bucket_name`` and
    ``s3_region`` must be truthy. When it is ``"fsx"``, ``fsx_dns_name``,
    ``fsx_file_system_id`` and ``fsx_mount_name`` must all be truthy.
    Any other source type passes without additional checks.

    Returns:
        The model instance itself, unchanged.

    Raises:
        ValueError: If a field required by the selected source type is
            missing or empty.
    """
    source_type = self.model_source_type

    # Guard clauses: each source type carries its own set of mandatory fields.
    if source_type == "s3" and not (self.s3_bucket_name and self.s3_region):
        raise ValueError("s3_bucket_name and s3_region are required when model_source_type is 's3'")

    if source_type == "fsx" and not (
        self.fsx_dns_name and self.fsx_file_system_id and self.fsx_mount_name
    ):
        raise ValueError("fsx_dns_name, fsx_file_system_id and fsx_mount_name are required when model_source_type is 'fsx'")

    return self
257+
238258
def to_domain(self) -> HPEndpoint:
239259
env_vars = None
240260
if self.env:

0 commit comments

Comments
 (0)