1010# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
13- from pydantic import BaseModel , Field
13+ from pydantic import BaseModel , Field , model_validator , ConfigDict
1414from typing import Optional , List , Dict , Union , Literal
1515
1616from sagemaker .hyperpod .inference .config .hp_endpoint_config import (
3131from sagemaker .hyperpod .inference .hp_endpoint import HPEndpoint
3232
3333class FlatHPEndpoint (BaseModel ):
34+ model_config = ConfigDict (extra = "forbid" )
35+
36+ metadata_name : Optional [str ] = Field (
37+ None ,
38+ alias = "metadata_name" ,
39+ description = "Name of the jumpstart endpoint object" ,
40+ max_length = 63 ,
41+ pattern = r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$" ,
42+ )
43+
3444 # endpoint_name
3545 endpoint_name : Optional [str ] = Field (
36- "" ,
46+ None ,
3747 alias = "endpoint_name" ,
3848 description = "Name of SageMaker endpoint; empty string means no creation" ,
3949 max_length = 63 ,
@@ -130,7 +140,7 @@ class FlatHPEndpoint(BaseModel):
130140 description = "FSX File System DNS Name" ,
131141 )
132142 fsx_file_system_id : Optional [str ] = Field (
133- ... ,
143+ None ,
134144 alias = "fsx_file_system_id" ,
135145 description = "FSX File System ID" ,
136146 )
@@ -142,12 +152,12 @@ class FlatHPEndpoint(BaseModel):
142152
143153 # S3Storage
144154 s3_bucket_name : Optional [str ] = Field (
145- ... ,
155+ None ,
146156 alias = "s3_bucket_name" ,
147157 description = "S3 bucket location" ,
148158 )
149159 s3_region : Optional [str ] = Field (
150- ... ,
160+ None ,
151161 alias = "s3_region" ,
152162 description = "S3 bucket region" ,
153163 )
@@ -229,12 +239,22 @@ class FlatHPEndpoint(BaseModel):
229239 invocation_endpoint : Optional [str ] = Field (
230240 default = "invocations" ,
231241 description = (
232- "The invocation endpoint of the model server. "
233- "http://<host>:<port>/ would be pre-populated based on the other fields. "
242+ "The invocation endpoint of the model server. http://<host>:<port>/ would be pre-populated based on the other fields. "
234243 "Please fill in the path after http://<host>:<port>/ specific to your model server." ,
235244 )
236245 )
237-
246+
247+ @model_validator (mode = 'after' )
248+ def validate_model_source_config (self ):
249+ """Validate that required fields are provided based on model_source_type"""
250+ if self .model_source_type == "s3" :
251+ if not self .s3_bucket_name or not self .s3_region :
252+ raise ValueError ("s3_bucket_name and s3_region are required when model_source_type is 's3'" )
253+ elif self .model_source_type == "fsx" :
254+ if not self .fsx_dns_name or not self .fsx_file_system_id or not self .fsx_mount_name :
255+ raise ValueError ("fsx_dns_name, fsx_file_system_id and fsx_mount_name are required when model_source_type is 'fsx'" )
256+ return self
257+
238258 def to_domain (self ) -> HPEndpoint :
239259 env_vars = None
240260 if self .env :
0 commit comments