diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 0d156c415f..4c29c3f625 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -912,6 +912,7 @@ def _add_config_name_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, config_name=kwargs.config_name, + hub_arn=kwargs.hub_arn, ) if specs.training_configs and specs.training_configs.get_top_config_from_ranking(): diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index f4e13de6d7..3f0b627d5d 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -672,9 +672,14 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, config_name=kwargs.config_name, + hub_arn=kwargs.hub_arn, ) if specs.inference_configs: - default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name + default_config_name = ( + specs.inference_configs.get_top_config_from_ranking().config_name + if specs.inference_configs.get_top_config_from_ranking() + else None + ) kwargs.config_name = kwargs.config_name or default_config_name if not kwargs.config_name: @@ -707,6 +712,7 @@ def _add_additional_model_data_sources_to_kwargs( sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, config_name=kwargs.config_name, + hub_arn=kwargs.hub_arn, ) # Append speculative decoding data source from metadata speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources() @@ -750,6 +756,7 @@ def _add_config_name_to_deploy_kwargs( sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, config_name=kwargs.config_name, + hub_arn=kwargs.hub_arn, ) if training_config_name: @@ -759,7 +766,11 @@ def _add_config_name_to_deploy_kwargs( return kwargs if specs.inference_configs: - default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name + default_config_name = ( + specs.inference_configs.get_top_config_from_ranking().config_name + if specs.inference_configs.get_top_config_from_ranking() + else None + ) kwargs.config_name = kwargs.config_name or default_config_name return kwargs diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 0cb8bbd55a..c9354e020b 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -393,6 +393,7 @@ def _validate_model_id_and_type(): model_version=self.model_version, sagemaker_session=self.sagemaker_session, model_type=self.model_type, + hub_arn=self.hub_arn, ) def log_subscription_warning(self) -> None: diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index ae54bc72b8..de1b6656b4 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1277,6 +1277,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of spec. """ + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.model_id: str = json_obj.get("model_id") self.url: str = json_obj.get("url") self.version: str = json_obj.get("version") @@ -1722,6 +1724,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of spec. """ super().from_json(json_obj) + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.inference_config_components: Optional[Dict[str, JumpStartConfigComponent]] = ( { component_name: JumpStartConfigComponent(component_name, component) @@ -1732,32 +1736,50 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: ) self.inference_config_rankings: Optional[Dict[str, JumpStartConfigRanking]] = ( { - alias: JumpStartConfigRanking(ranking) + alias: JumpStartConfigRanking(ranking, is_hub_content=self._is_hub_content) for alias, ranking in json_obj["inference_config_rankings"].items() } if json_obj.get("inference_config_rankings") else None ) - inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( - { - alias: JumpStartMetadataConfig( - alias, - config, - json_obj, - ( - { - component_name: self.inference_config_components.get(component_name) - for component_name in config.get("component_names") - } - if config and config.get("component_names") - else None - ), - ) - for alias, config in json_obj["inference_configs"].items() - } - if json_obj.get("inference_configs") - else None - ) + + if self._is_hub_content: + inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( + { + alias: JumpStartMetadataConfig( + alias, + config, + json_obj, + config.config_components, + is_hub_content=self._is_hub_content, + ) + for alias, config in json_obj["inference_configs"]["configs"].items() + } + if json_obj.get("inference_configs") + else None + ) + else: + inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( + { + alias: JumpStartMetadataConfig( + alias, + config, + json_obj, + ( + { + component_name: self.inference_config_components.get(component_name) + for component_name in config.get("component_names") + } + if config and config.get("component_names") + else None + ), + ) + for alias, config in json_obj["inference_configs"].items() + } + if json_obj.get("inference_configs") + else None + ) + self.inference_configs: Optional[JumpStartMetadataConfigs] = ( JumpStartMetadataConfigs( inference_configs_dict, @@ -1784,26 +1806,45 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if json_obj.get("training_config_rankings") else None ) - training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( - { - alias: JumpStartMetadataConfig( - alias, - config, - json_obj, - ( - { - component_name: self.training_config_components.get(component_name) - for component_name in config.get("component_names") - } - if config and config.get("component_names") - else None - ), - ) - for alias, config in json_obj["training_configs"].items() - } - if json_obj.get("training_configs") - else None - ) + + if self._is_hub_content: + training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( + { + alias: JumpStartMetadataConfig( + alias, + config, + json_obj, + config.config_components, + is_hub_content=self._is_hub_content, + ) + for alias, config in json_obj["training_configs"]["configs"].items() + } + if json_obj.get("training_configs") + else None + ) + else: + training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( + { + alias: JumpStartMetadataConfig( + alias, + config, + json_obj, + ( + { + component_name: self.training_config_components.get( + component_name + ) + for component_name in config.get("component_names") + } + if config and config.get("component_names") + else None + ), + ) + for alias, config in json_obj["training_configs"].items() + } + if json_obj.get("training_configs") + else None + ) self.training_configs: Optional[JumpStartMetadataConfigs] = ( JumpStartMetadataConfigs( diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index f521dbcc5a..014f60ae8a 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1074,6 +1074,7 @@ def get_jumpstart_configs( sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, + hub_arn: Optional[str] = None, ) -> Dict[str, JumpStartMetadataConfig]: """Returns metadata configs for the given model ID and region. @@ -1087,6 +1088,7 @@ def get_jumpstart_configs( sagemaker_session=sagemaker_session, scope=scope, model_type=model_type, + hub_arn=hub_arn, ) if scope == enums.JumpStartScriptScope.INFERENCE: