@@ -170,6 +170,93 @@ def to_json(self) -> Dict[str, Any]:
170170 return json_obj
171171
172172
173+ class JumpStartHyperparameter (JumpStartDataHolderType ):
174+ """Data class for JumpStart hyperparameter."""
175+
176+ __slots__ = {
177+ "name" ,
178+ "type" ,
179+ "options" ,
180+ "default" ,
181+ "scope" ,
182+ "min" ,
183+ "max" ,
184+ }
185+
186+ def __init__ (self , spec : Dict [str , Any ]):
187+ """Initializes a JumpStartHyperparameter object from its json representation.
188+
189+ Args:
190+ spec (Dict[str, Any]): Dictionary representation of hyperparameter.
191+ """
192+ self .from_json (spec )
193+
194+ def from_json (self , json_obj : Dict [str , Any ]) -> None :
195+ """Sets fields in object based on json.
196+
197+ Args:
198+ json_obj (Dict[str, Any]): Dictionary representation of hyperparameter.
199+ """
200+
201+ self .name = json_obj ["name" ]
202+ self .type = json_obj ["type" ]
203+ self .default = json_obj ["default" ]
204+ self .scope = json_obj ["scope" ]
205+
206+ options = json_obj .get ("options" )
207+ if options is not None :
208+ self .options = options
209+
210+ min_val = json_obj .get ("min" )
211+ if min_val is not None :
212+ self .min = min_val
213+
214+ max_val = json_obj .get ("max" )
215+ if max_val is not None :
216+ self .max = max_val
217+
218+ def to_json (self ) -> Dict [str , Any ]:
219+ """Returns json representation of JumpStartHyperparameter object."""
220+ json_obj = {att : getattr (self , att ) for att in self .__slots__ if hasattr (self , att )}
221+ return json_obj
222+
223+
224+ class JumpStartEnvironmentVariable (JumpStartDataHolderType ):
225+ """Data class for JumpStart environment variable."""
226+
227+ __slots__ = {
228+ "name" ,
229+ "type" ,
230+ "default" ,
231+ "scope" ,
232+ }
233+
234+ def __init__ (self , spec : Dict [str , Any ]):
235+ """Initializes a JumpStartEnvironmentVariable object from its json representation.
236+
237+ Args:
238+ spec (Dict[str, Any]): Dictionary representation of environment variable.
239+ """
240+ self .from_json (spec )
241+
242+ def from_json (self , json_obj : Dict [str , Any ]) -> None :
243+ """Sets fields in object based on json.
244+
245+ Args:
246+ json_obj (Dict[str, Any]): Dictionary representation of environment variable.
247+ """
248+
249+ self .name = json_obj ["name" ]
250+ self .type = json_obj ["type" ]
251+ self .default = json_obj ["default" ]
252+ self .scope = json_obj ["scope" ]
253+
254+ def to_json (self ) -> Dict [str , Any ]:
255+ """Returns json representation of JumpStartEnvironmentVariable object."""
256+ json_obj = {att : getattr (self , att ) for att in self .__slots__ if hasattr (self , att )}
257+ return json_obj
258+
259+
173260class JumpStartModelSpecs (JumpStartDataHolderType ):
174261 """Data class JumpStart model specs."""
175262
@@ -186,6 +273,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
186273 "training_artifact_key" ,
187274 "training_script_key" ,
188275 "hyperparameters" ,
276+ "inference_environment_variables" ,
189277 ]
190278
191279 def __init__ (self , spec : Dict [str , Any ]):
@@ -210,22 +298,37 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
210298 self .hosting_artifact_key : str = json_obj ["hosting_artifact_key" ]
211299 self .hosting_script_key : str = json_obj ["hosting_script_key" ]
212300 self .training_supported : bool = bool (json_obj ["training_supported" ])
301+ self .inference_environment_variables = [
302+ JumpStartEnvironmentVariable (env_variable )
303+ for env_variable in json_obj ["inference_environment_variables" ]
304+ ]
213305 if self .training_supported :
214306 self .training_ecr_specs : JumpStartECRSpecs = JumpStartECRSpecs (
215307 json_obj ["training_ecr_specs" ]
216308 )
217309 self .training_artifact_key : str = json_obj ["training_artifact_key" ]
218310 self .training_script_key : str = json_obj ["training_script_key" ]
219- self .hyperparameters : Dict [str , Any ] = json_obj .get ("hyperparameters" , {})
311+ hyperparameters = json_obj .get ("hyperparameters" )
312+ if hyperparameters is not None :
313+ self .hyperparameters = [
314+ JumpStartHyperparameter (hyperparameter ) for hyperparameter in hyperparameters
315+ ]
220316
221317 def to_json (self ) -> Dict [str , Any ]:
222318 """Returns json representation of JumpStartModelSpecs object."""
223319 json_obj = {}
224320 for att in self .__slots__ :
225321 if hasattr (self , att ):
226322 cur_val = getattr (self , att )
227- if isinstance ( cur_val , JumpStartECRSpecs ):
323+ if issubclass ( type ( cur_val ), JumpStartDataHolderType ):
228324 json_obj [att ] = cur_val .to_json ()
325+ elif isinstance (cur_val , list ):
326+ json_obj [att ] = []
327+ for obj in cur_val :
328+ if issubclass (type (obj ), JumpStartDataHolderType ):
329+ json_obj [att ].append (obj .to_json ())
330+ else :
331+ json_obj [att ].append (obj )
229332 else :
230333 json_obj [att ] = cur_val
231334 return json_obj
0 commit comments