@@ -744,12 +744,12 @@ def _get_regional_property(
744744
745745
746746class JumpStartBenchmarkStat (JumpStartDataHolderType ):
747- """Data class JumpStart benchmark stats ."""
747+ """Data class JumpStart benchmark stat ."""
748748
749749 __slots__ = ["name" , "value" , "unit" ]
750750
751751 def __init__ (self , spec : Dict [str , Any ]):
752- """Initializes a JumpStartBenchmarkStat object
752+ """Initializes a JumpStartBenchmarkStat object.
753753
754754 Args:
755755 spec (Dict[str, Any]): Dictionary representation of benchmark stat.
@@ -858,7 +858,7 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
858858 "model_subscription_link" ,
859859 ]
860860
861- def __init__ (self , fields : Optional [ Dict [str , Any ] ]):
861+ def __init__ (self , fields : Dict [str , Any ]):
862862 """Initializes a JumpStartMetadataFields object.
863863
864864 Args:
@@ -877,7 +877,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
877877 self .version : str = json_obj .get ("version" )
878878 self .min_sdk_version : str = json_obj .get ("min_sdk_version" )
879879 self .incremental_training_supported : bool = bool (
880- json_obj .get ("incremental_training_supported" )
880+ json_obj .get ("incremental_training_supported" , False )
881881 )
882882 self .hosting_ecr_specs : Optional [JumpStartECRSpecs ] = (
883883 JumpStartECRSpecs (json_obj ["hosting_ecr_specs" ])
@@ -1038,7 +1038,7 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields):
10381038
10391039 __slots__ = slots + JumpStartMetadataBaseFields .__slots__
10401040
1041- def __init__ ( # pylint: disable=super-init-not-called
1041+ def __init__ (
10421042 self ,
10431043 component_name : str ,
10441044 component : Optional [Dict [str , Any ]],
@@ -1049,7 +1049,10 @@ def __init__( # pylint: disable=super-init-not-called
10491049 component_name (str): Name of the component.
10501050 component (Dict[str, Any]):
10511051 Dictionary representation of the config component.
1052+ Raises:
1053+ ValueError: If the component field is invalid.
10521054 """
1055+ super ().__init__ (component )
10531056 self .component_name = component_name
10541057 self .from_json (component )
10551058
@@ -1080,7 +1083,7 @@ def __init__(
10801083 self ,
10811084 base_fields : Dict [str , Any ],
10821085 config_components : Dict [str , JumpStartConfigComponent ],
1083- benchmark_metrics : Dict [str , JumpStartBenchmarkStat ],
1086+ benchmark_metrics : Dict [str , List [ JumpStartBenchmarkStat ] ],
10841087 ):
10851088 """Initializes a JumpStartMetadataConfig object from its json representation.
10861089
@@ -1089,12 +1092,12 @@ def __init__(
10891092 The default base fields that are used to construct the final resolved config.
10901093 config_components (Dict[str, JumpStartConfigComponent]):
10911094 The list of components that are used to construct the resolved config.
1092- benchmark_metrics (Dict[str, JumpStartBenchmarkStat]):
1095+ benchmark_metrics (Dict[str, List[ JumpStartBenchmarkStat] ]):
10931096 The dictionary of benchmark metrics with name being the key.
10941097 """
10951098 self .base_fields = base_fields
10961099 self .config_components : Dict [str , JumpStartConfigComponent ] = config_components
1097- self .benchmark_metrics : Dict [str , JumpStartBenchmarkStat ] = benchmark_metrics
1100+ self .benchmark_metrics : Dict [str , List [ JumpStartBenchmarkStat ] ] = benchmark_metrics
10981101 self .resolved_metadata_config : Optional [Dict [str , Any ]] = None
10991102
11001103 def to_json (self ) -> Dict [str , Any ]:
@@ -1104,7 +1107,7 @@ def to_json(self) -> Dict[str, Any]:
11041107
11051108 @property
11061109 def resolved_config (self ) -> Dict [str , Any ]:
1107- """Returns the final config that is resolved from the list of components.
1110+ """Returns the final config that is resolved from the components map .
11081111
11091112 Construct the final config by applying the list of configs from list index,
11101113 and apply to the base default fields in the current model specs.
@@ -1139,7 +1142,7 @@ def __init__(
11391142
11401143 Args:
11411144 configs (Dict[str, JumpStartMetadataConfig]):
1142- List of configs that the current model has .
1145+ The map of JumpStartMetadataConfig object, with config name being the key .
11431146 config_rankings (JumpStartConfigRanking):
11441147 Config ranking class represents the ranking of the configs in the model.
11451148 scope (JumpStartScriptScope):
@@ -1158,19 +1161,30 @@ def get_top_config_from_ranking(
11581161 self ,
11591162 ranking_name : str = JumpStartConfigRankingName .DEFAULT ,
11601163 instance_type : Optional [str ] = None ,
1161- ) -> JumpStartMetadataConfig :
1162- """Gets the best the config based on config ranking."""
1164+ ) -> Optional [JumpStartMetadataConfig ]:
1165+ """Gets the best the config based on config ranking.
1166+
1167+ Args:
1168+ ranking_name (str):
1169+ The ranking name that config priority is based on.
1170+ instance_type (Optional[str]):
1171+ The instance type which the config selection is based on.
1172+
1173+ Raises:
1174+ ValueError: If the config exists but missing config ranking.
1175+ NotImplementedError: If the scope is unrecognized.
1176+ """
11631177 if self .configs and (
11641178 not self .config_rankings or not self .config_rankings .get (ranking_name )
11651179 ):
1166- raise ValueError ("Config exists but missing config ranking." )
1180+ raise ValueError (f "Config exists but missing config ranking { ranking_name } ." )
11671181
11681182 if self .scope == JumpStartScriptScope .INFERENCE :
11691183 instance_type_attribute = "supported_inference_instance_types"
11701184 elif self .scope == JumpStartScriptScope .TRAINING :
11711185 instance_type_attribute = "supported_training_instance_types"
11721186 else :
1173- raise ValueError (f"Unknown script scope { self .scope } " )
1187+ raise NotImplementedError (f"Unknown script scope { self .scope } " )
11741188
11751189 rankings = self .config_rankings .get (ranking_name )
11761190 for config_name in rankings .rankings :
@@ -1198,12 +1212,13 @@ class JumpStartModelSpecs(JumpStartMetadataBaseFields):
11981212
11991213 __slots__ = JumpStartMetadataBaseFields .__slots__ + slots
12001214
1201- def __init__ (self , spec : Dict [str , Any ]): # pylint: disable=super-init-not-called
1215+ def __init__ (self , spec : Dict [str , Any ]):
12021216 """Initializes a JumpStartModelSpecs object from its json representation.
12031217
12041218 Args:
12051219 spec (Dict[str, Any]): Dictionary representation of spec.
12061220 """
1221+ super ().__init__ (spec )
12071222 self .from_json (spec )
12081223 if self .inference_configs and self .inference_configs .get_top_config_from_ranking ():
12091224 super ().from_json (self .inference_configs .get_top_config_from_ranking ().resolved_config )
@@ -1245,8 +1260,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12451260 ),
12461261 (
12471262 {
1248- stat_name : JumpStartBenchmarkStat (stat )
1249- for stat_name , stat in config .get ("benchmark_metrics" ).items ()
1263+ stat_name : [ JumpStartBenchmarkStat (stat ) for stat in stats ]
1264+ for stat_name , stats in config .get ("benchmark_metrics" ).items ()
12501265 }
12511266 if config and config .get ("benchmark_metrics" )
12521267 else None
@@ -1297,8 +1312,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12971312 ),
12981313 (
12991314 {
1300- stat_name : JumpStartBenchmarkStat (stat )
1301- for stat_name , stat in config .get ("benchmark_metrics" ).items ()
1315+ stat_name : [ JumpStartBenchmarkStat (stat ) for stat in stats ]
1316+ for stat_name , stats in config .get ("benchmark_metrics" ).items ()
13021317 }
13031318 if config and config .get ("benchmark_metrics" )
13041319 else None
@@ -1330,13 +1345,26 @@ def set_config(
13301345 config_name (str): Name of the config.
13311346 scope (JumpStartScriptScope, optional):
13321347 Scope of the config. Defaults to JumpStartScriptScope.INFERENCE.
1348+
1349+ Raises:
1350+ ValueError: If the scope is not supported, or cannot find config name.
13331351 """
13341352 if scope == JumpStartScriptScope .INFERENCE :
1335- super (). from_json ( self .inference_configs . configs [ config_name ]. resolved_config )
1353+ metadata_configs = self .inference_configs
13361354 elif scope == JumpStartScriptScope .TRAINING and self .training_supported :
1337- super (). from_json ( self .training_configs . configs [ config_name ]. resolved_config )
1355+ metadata_configs = self .training_configs
13381356 else :
1339- raise ValueError (f"Unknown Jumpstart Script scope { scope } ." )
1357+ raise ValueError (f"Unknown Jumpstart script scope { scope } ." )
1358+
1359+ config_object = metadata_configs .configs .get (config_name )
1360+ if not config_object :
1361+ error_msg = f"Cannot find Jumpstart config name { config_name } . "
1362+ config_names = list (metadata_configs .configs .keys ())
1363+ if config_names :
1364+ error_msg += f"List of config names that is supported by the model: { config_names } "
1365+ raise ValueError (error_msg )
1366+
1367+ super ().from_json (config_object .resolved_config )
13401368
13411369 def supports_prepacked_inference (self ) -> bool :
13421370 """Returns True if the model has a prepacked inference artifact."""
0 commit comments