     Worker,
     Dimensions,
     AutoScalingSpec,
-    CloudWatchTrigger
+    CloudWatchTrigger,
 )
 from sagemaker.hyperpod.inference.hp_endpoint import HPEndpoint
 from sagemaker.hyperpod.common.config.metadata import Metadata
@@ -37,12 +37,10 @@ class FlatHPEndpoint(BaseModel):
     model_config = ConfigDict(extra="forbid")

     namespace: Optional[str] = Field(
-        default=None,
-        description="Kubernetes namespace",
-        min_length=1
+        default=None, description="Kubernetes namespace", min_length=1
     )

-    metadata_name: Optional[str] = Field(
+    metadata_name: Optional[str] = Field(
         None,
         alias="metadata_name",
         description="Name of the custom endpoint object",
@@ -75,14 +73,15 @@ class FlatHPEndpoint(BaseModel):

     # metrics.*
     metrics_enabled: Optional[bool] = Field(
-        False, alias="metrics_enabled",
+        False,
+        alias="metrics_enabled",
         description="Enable metrics collection",
     )

     # model_name and version
     model_name: str = Field(
-        ...,
-        alias="model_name",
+        ...,
+        alias="model_name",
         description="Name of model to create on SageMaker",
         min_length=1,
         max_length=63,
@@ -100,15 +99,18 @@ class FlatHPEndpoint(BaseModel):

     # model_source_config.*
     model_source_type: Literal["fsx", "s3"] = Field(
-        ..., alias="model_source_type",
+        ...,
+        alias="model_source_type",
         description="Source type: fsx or s3",
     )
     model_location: Optional[str] = Field(
-        None, alias="model_location",
+        None,
+        alias="model_location",
         description="Specific model data location",
     )
     prefetch_enabled: Optional[bool] = Field(
-        False, alias="prefetch_enabled",
+        False,
+        alias="prefetch_enabled",
         description="Whether to pre-fetch model data",
     )

@@ -122,11 +124,12 @@ class FlatHPEndpoint(BaseModel):

     # worker.*
     image_uri: str = Field(
-        ..., alias="image_uri",
+        ...,
+        alias="image_uri",
         description="Inference server image name",
     )
     container_port: int = Field(
-        ...,
+        ...,
         alias="container_port",
         description="Port on which the model server listens",
         ge=1,
@@ -138,7 +141,8 @@ class FlatHPEndpoint(BaseModel):
         description="Path inside container for model volume",
     )
     model_volume_mount_name: str = Field(
-        ..., alias="model_volume_mount_name",
+        ...,
+        alias="model_volume_mount_name",
         description="Name of the model volume mount",
     )

@@ -149,7 +153,7 @@ class FlatHPEndpoint(BaseModel):
         description="FSX File System DNS Name",
     )
     fsx_file_system_id: Optional[str] = Field(
-        None,
+        None,
         alias="fsx_file_system_id",
         description="FSX File System ID",
     )
@@ -161,23 +165,23 @@ class FlatHPEndpoint(BaseModel):

     # S3Storage
     s3_bucket_name: Optional[str] = Field(
-        None,
+        None,
         alias="s3_bucket_name",
         description="S3 bucket location",
     )
     s3_region: Optional[str] = Field(
-        None,
+        None,
         alias="s3_region",
         description="S3 bucket region",
     )

     # Resources
-    resources_limits: Optional[Dict[str, Union[int,str]]] = Field(
+    resources_limits: Optional[Dict[str, Union[int, str]]] = Field(
         None,
         alias="resources_limits",
         description="Resource limits for the worker",
     )
-    resources_requests: Optional[Dict[str, Union[int,str]]] = Field(
+    resources_requests: Optional[Dict[str, Union[int, str]]] = Field(
         None,
         alias="resources_requests",
         description="Resource requests for the worker",
@@ -187,84 +191,82 @@ class FlatHPEndpoint(BaseModel):
     dimensions: Optional[Dict[str, str]] = Field(
         None,
         alias="dimensions",
-        description="CloudWatch Metric dimensions as key–value pairs"
+        description="CloudWatch Metric dimensions as key–value pairs",
     )

     # CloudWatch Trigger
     metric_collection_period: Optional[int] = Field(
-        300,
-        description="Defines the Period for CloudWatch query"
+        300, description="Defines the Period for CloudWatch query"
     )
     metric_collection_start_time: Optional[int] = Field(
-        300,
-        description="Defines the StartTime for CloudWatch query"
+        300, description="Defines the StartTime for CloudWatch query"
     )
     metric_name: Optional[str] = Field(
-        None,
-        description="Metric name to query for CloudWatch trigger"
+        None, description="Metric name to query for CloudWatch trigger"
     )
     metric_stat: Optional[str] = Field(
         "Average",
         description=(
             "Statistics metric to be used by Trigger. "
             "Defines the Stat for the CloudWatch query. Default is Average."
-        )
+        ),
     )
     metric_type: Optional[Literal["Value", "Average"]] = Field(
         "Average",
         description=(
             "The type of metric to be used by HPA. "
             "`Average` – Uses average value per pod; "
             "`Value` – Uses absolute metric value."
-        )
+        ),
     )
     min_value: Optional[float] = Field(
         0,
         description=(
             "Minimum metric value used in case of empty response "
             "from CloudWatch. Default is 0."
-        )
+        ),
     )
     cloud_watch_trigger_name: Optional[str] = Field(
-        None,
-        description="Name for the CloudWatch trigger"
+        None, description="Name for the CloudWatch trigger"
     )
     cloud_watch_trigger_namespace: Optional[str] = Field(
-        None,
-        description="AWS CloudWatch namespace for the metric"
+        None, description="AWS CloudWatch namespace for the metric"
     )
     target_value: Optional[float] = Field(
-        None,
-        description="Target value for the CloudWatch metric"
+        None, description="Target value for the CloudWatch metric"
     )
     use_cached_metrics: Optional[bool] = Field(
         True,
         description=(
             "Enable caching of metric values during polling interval. "
             "Default is true."
-        )
+        ),
     )

     invocation_endpoint: Optional[str] = Field(
         default="invocations",
         description=(
             "The invocation endpoint of the model server. http://<host>:<port>/ would be pre-populated based on the other fields. "
             "Please fill in the path after http://<host>:<port>/ specific to your model server.",
-        )
+        ),
     )

-    @model_validator(mode='after')
+    @model_validator(mode="after")
     def validate_model_source_config(self):
         """Validate that required fields are provided based on model_source_type"""
         if self.model_source_type == "s3":
             if not self.s3_bucket_name or not self.s3_region:
-                raise ValueError("s3_bucket_name and s3_region are required when model_source_type is 's3'")
+                raise ValueError(
+                    "s3_bucket_name and s3_region are required when model_source_type is 's3'"
+                )
         elif self.model_source_type == "fsx":
             if not self.fsx_file_system_id:
-                raise ValueError("fsx_file_system_id is required when model_source_type is 'fsx'")
+                raise ValueError(
+                    "fsx_file_system_id is required when model_source_type is 'fsx'"
+                )
         return self

-    @model_validator(mode='after')
+    @model_validator(mode="after")
     def validate_name(self):
         if not self.metadata_name and not self.endpoint_name:
             raise ValueError("Either metadata_name or endpoint_name must be provided")
@@ -273,21 +275,20 @@ def validate_name(self):
     def to_domain(self) -> HPEndpoint:
         if self.endpoint_name and not self.metadata_name:
             self.metadata_name = self.endpoint_name
-
+
         metadata = Metadata(name=self.metadata_name, namespace=self.namespace)

         env_vars = None
         if self.env:
             env_vars = [
-                EnvironmentVariables(name=k, value=v)
-                for k, v in self.env.items()
+                EnvironmentVariables(name=k, value=v) for k, v in self.env.items()
             ]

         dim_vars: list[Dimensions] = []
         if self.dimensions:
             for name, value in self.dimensions.items():
                 dim_vars.append(Dimensions(name=name, value=value))
-
+
         cloud_watch_trigger = CloudWatchTrigger(
             dimensions=dim_vars,
             metric_collection_period=self.metric_collection_period,
@@ -300,12 +301,10 @@ def to_domain(self) -> HPEndpoint:
             namespace=self.cloud_watch_trigger_namespace,
             target_value=self.target_value,
             use_cached_metrics=self.use_cached_metrics,
-        )
-
-        auto_scaling_spec = AutoScalingSpec(
-            cloud_watch_trigger=cloud_watch_trigger
         )

+        auto_scaling_spec = AutoScalingSpec(cloud_watch_trigger=cloud_watch_trigger)
+
         # nested metrics
         metrics = Metrics(
             enabled=self.metrics_enabled,
@@ -336,7 +335,9 @@ def to_domain(self) -> HPEndpoint:
             fsx_storage=fsx,
         )

-        tls = TlsConfig(tls_certificate_output_s3_uri=self.tls_certificate_output_s3_uri)
+        tls = TlsConfig(
+            tls_certificate_output_s3_uri=self.tls_certificate_output_s3_uri
+        )

         invocation_port = ModelInvocationPort(
             container_port=self.container_port,
@@ -368,4 +369,4 @@ def to_domain(self) -> HPEndpoint:
             worker=worker,
             invocation_endpoint=self.invocation_endpoint,
             auto_scaling_spec=auto_scaling_spec
-        )
+        )
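
For reference, a minimal usage sketch (not part of this commit) of how the flattened model is meant to be consumed: the flat, CLI-friendly fields are checked by the two model validators above and then re-nested into the HPEndpoint domain object by to_domain(). The values below are illustrative assumptions, and the class may require additional fields that are not visible in this diff.

# Hypothetical usage sketch; field values are illustrative, not taken from this commit.
flat = FlatHPEndpoint(
    endpoint_name="demo-endpoint",        # validate_name: metadata_name or endpoint_name must be set
    model_name="demo-model",
    model_source_type="s3",
    s3_bucket_name="demo-bucket",         # required together with s3_region when model_source_type is "s3"
    s3_region="us-east-1",
    image_uri="demo-inference-image:latest",
    container_port=8080,
    model_volume_mount_name="model-volume",
)

# to_domain() rebuilds the nested HPEndpoint (Metadata, Metrics, ModelSourceConfig,
# Worker, AutoScalingSpec, ...) from the flat fields.
endpoint = flat.to_domain()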