|
29 | 29 | ) |
30 | 30 | from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base |
31 | 31 | from sagemaker.jumpstart.constants import ( |
| 32 | + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, |
32 | 33 | INFERENCE_ENTRY_POINT_SCRIPT_NAME, |
33 | 34 | JUMPSTART_DEFAULT_REGION_NAME, |
34 | 35 | JUMPSTART_LOGGER, |
|
54 | 55 | add_jumpstart_model_info_tags, |
55 | 56 | get_default_jumpstart_session_with_user_agent_suffix, |
56 | 57 | get_neo_content_bucket, |
| 58 | + get_top_ranked_config_name, |
57 | 59 | update_dict_if_key_not_present, |
58 | 60 | resolve_model_sagemaker_config_field, |
59 | 61 | verify_model_region_and_return_specs, |
@@ -155,15 +157,18 @@ def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelIni |
155 | 157 | return kwargs |
156 | 158 |
|
157 | 159 |
|
158 | | -def _add_sagemaker_session_to_kwargs( |
| 160 | +def _add_sagemaker_session_with_custom_user_agent_to_kwargs( |
159 | 161 | kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs] |
160 | 162 | ) -> JumpStartModelInitKwargs: |
161 | 163 | """Sets session in kwargs based on default or override, returns full kwargs.""" |
162 | 164 |
|
163 | 165 | kwargs.sagemaker_session = ( |
164 | 166 | kwargs.sagemaker_session |
165 | 167 | or get_default_jumpstart_session_with_user_agent_suffix( |
166 | | - kwargs.model_id, kwargs.model_version, kwargs.hub_arn |
| 168 | + model_id=kwargs.model_id, |
| 169 | + model_version=kwargs.model_version, |
| 170 | + config_name=kwargs.config_name, |
| 171 | + is_hub_content=kwargs.hub_arn is not None, |
167 | 172 | ) |
168 | 173 | ) |
169 | 174 |
|
@@ -244,6 +249,32 @@ def _add_instance_type_to_kwargs( |
244 | 249 | kwargs.instance_type, |
245 | 250 | ) |
246 | 251 |
|
| 252 | + specs = verify_model_region_and_return_specs( |
| 253 | + model_id=kwargs.model_id, |
| 254 | + version=kwargs.model_version, |
| 255 | + scope=JumpStartScriptScope.INFERENCE, |
| 256 | + region=kwargs.region, |
| 257 | + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
| 258 | + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
| 259 | + sagemaker_session=kwargs.sagemaker_session, |
| 260 | + model_type=kwargs.model_type, |
| 261 | + config_name=kwargs.config_name, |
| 262 | + ) |
| 263 | + |
| 264 | + if specs.inference_configs and kwargs.config_name not in specs.inference_configs.configs: |
| 265 | + return kwargs |
| 266 | + |
| 267 | + resolved_config = ( |
| 268 | + specs.inference_configs.configs[kwargs.config_name].resolved_config |
| 269 | + if specs.inference_configs |
| 270 | + else None |
| 271 | + ) |
| 272 | + if resolved_config is None: |
| 273 | + return kwargs |
| 274 | + supported_instance_types = resolved_config.get("supported_inference_instance_types", []) |
| 275 | + if kwargs.instance_type not in supported_instance_types: |
| 276 | + JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type) |
| 277 | + |
247 | 278 | return kwargs |
248 | 279 |
|
249 | 280 |
|
@@ -662,38 +693,25 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta |
662 | 693 | ValueError: If the instance_type is not supported with the current config. |
663 | 694 | """ |
664 | 695 |
|
665 | | - specs = verify_model_region_and_return_specs( |
| 696 | + # reuse the caller's session when provided, otherwise fall back to the default JS
| 697 | + # session (without custom user agent), in order to retrieve config name info
| 698 | + temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION |
| 699 | + |
| 700 | + kwargs.config_name = kwargs.config_name or get_top_ranked_config_name( |
| 701 | + region=kwargs.region, |
666 | 702 | model_id=kwargs.model_id, |
667 | | - version=kwargs.model_version, |
| 703 | + model_version=kwargs.model_version, |
| 704 | + sagemaker_session=temp_session, |
668 | 705 | scope=JumpStartScriptScope.INFERENCE, |
669 | | - region=kwargs.region, |
670 | | - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
671 | | - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
672 | | - sagemaker_session=kwargs.sagemaker_session, |
673 | 706 | model_type=kwargs.model_type, |
674 | | - config_name=kwargs.config_name, |
| 707 | + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
| 708 | + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
675 | 709 | hub_arn=kwargs.hub_arn, |
676 | 710 | ) |
677 | | - if specs.inference_configs: |
678 | | - default_config_name = ( |
679 | | - specs.inference_configs.get_top_config_from_ranking().config_name |
680 | | - if specs.inference_configs.get_top_config_from_ranking() |
681 | | - else None |
682 | | - ) |
683 | | - kwargs.config_name = kwargs.config_name or default_config_name |
684 | | - |
685 | | - if not kwargs.config_name: |
686 | | - return kwargs |
687 | 711 |
|
688 | | - if kwargs.config_name not in set(specs.inference_configs.configs.keys()): |
689 | | - raise ValueError( |
690 | | - f"Config {kwargs.config_name} is not supported for model {kwargs.model_id}." |
691 | | - ) |
| 712 | + if kwargs.config_name is None: |
| 713 | + return kwargs |
692 | 714 |
|
693 | | - resolved_config = specs.inference_configs.configs[kwargs.config_name].resolved_config |
694 | | - supported_instance_types = resolved_config.get("supported_inference_instance_types", []) |
695 | | - if kwargs.instance_type not in supported_instance_types: |
696 | | - JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type) |
697 | 715 | return kwargs |
698 | 716 |
|
699 | 717 |
|
@@ -746,32 +764,41 @@ def _add_config_name_to_deploy_kwargs( |
746 | 764 | ValueError: If the instance_type is not supported with the current config. |
747 | 765 | """ |
748 | 766 |
|
749 | | - specs = verify_model_region_and_return_specs( |
750 | | - model_id=kwargs.model_id, |
751 | | - version=kwargs.model_version, |
752 | | - scope=JumpStartScriptScope.INFERENCE, |
753 | | - region=kwargs.region, |
754 | | - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
755 | | - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
756 | | - sagemaker_session=kwargs.sagemaker_session, |
757 | | - model_type=kwargs.model_type, |
758 | | - config_name=kwargs.config_name, |
759 | | - hub_arn=kwargs.hub_arn, |
760 | | - ) |
| 767 | + # reuse the caller's session when provided, otherwise fall back to the default JS
| 768 | + # session (without custom user agent), in order to retrieve config name info
| 769 | + temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION |
761 | 770 |
|
762 | 771 | if training_config_name: |
763 | | - kwargs.config_name = _select_inference_config_from_training_config( |
| 772 | + |
| 773 | + specs = verify_model_region_and_return_specs( |
| 774 | + model_id=kwargs.model_id, |
| 775 | + version=kwargs.model_version, |
| 776 | + scope=JumpStartScriptScope.INFERENCE, |
| 777 | + region=kwargs.region, |
| 778 | + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
| 779 | + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
| 780 | + sagemaker_session=temp_session, |
| 781 | + model_type=kwargs.model_type, |
| 782 | + config_name=kwargs.config_name, |
| 783 | + ) |
| 784 | + default_config_name = _select_inference_config_from_training_config( |
764 | 785 | specs=specs, training_config_name=training_config_name |
765 | 786 | ) |
766 | | - return kwargs |
767 | 787 |
|
768 | | - if specs.inference_configs: |
769 | | - default_config_name = ( |
770 | | - specs.inference_configs.get_top_config_from_ranking().config_name |
771 | | - if specs.inference_configs.get_top_config_from_ranking() |
772 | | - else None |
| 788 | + else: |
| 789 | + default_config_name = get_top_ranked_config_name( |
| 790 | + region=kwargs.region, |
| 791 | + model_id=kwargs.model_id, |
| 792 | + model_version=kwargs.model_version, |
| 793 | + sagemaker_session=temp_session, |
| 794 | + scope=JumpStartScriptScope.INFERENCE, |
| 795 | + model_type=kwargs.model_type, |
| 796 | + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
| 797 | + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
| 798 | + hub_arn=kwargs.hub_arn, |
773 | 799 | ) |
774 | | - kwargs.config_name = kwargs.config_name or default_config_name |
| 800 | + |
| 801 | + kwargs.config_name = kwargs.config_name or default_config_name |
775 | 802 |
|
776 | 803 | return kwargs |
777 | 804 |
|
@@ -850,15 +877,15 @@ def get_deploy_kwargs( |
850 | 877 | routing_config=routing_config, |
851 | 878 | ) |
852 | 879 |
|
853 | | - deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs) |
| 880 | + deploy_kwargs = _add_config_name_to_deploy_kwargs( |
| 881 | + kwargs=deploy_kwargs, training_config_name=training_config_name |
| 882 | + ) |
854 | 883 |
|
855 | 884 | deploy_kwargs = _add_model_version_to_kwargs(kwargs=deploy_kwargs) |
856 | 885 |
|
857 | | - deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) |
| 886 | + deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(kwargs=deploy_kwargs) |
858 | 887 |
|
859 | | - deploy_kwargs = _add_config_name_to_deploy_kwargs( |
860 | | - kwargs=deploy_kwargs, training_config_name=training_config_name |
861 | | - ) |
| 888 | + deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) |
862 | 889 |
|
863 | 890 | deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs) |
864 | 891 |
|
@@ -1041,11 +1068,14 @@ def get_init_kwargs( |
1041 | 1068 | ) |
1042 | 1069 |
|
1043 | 1070 | model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs) |
| 1071 | + model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) |
| 1072 | + model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) |
1044 | 1073 |
|
1045 | | - model_init_kwargs = _add_sagemaker_session_to_kwargs(kwargs=model_init_kwargs) |
| 1074 | + model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs( |
| 1075 | + kwargs=model_init_kwargs |
| 1076 | + ) |
1046 | 1077 | model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs) |
1047 | 1078 |
|
1048 | | - model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) |
1049 | 1079 | model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs) |
1050 | 1080 |
|
1051 | 1081 | model_init_kwargs = _add_instance_type_to_kwargs( |
@@ -1073,8 +1103,6 @@ def get_init_kwargs( |
1073 | 1103 |
|
1074 | 1104 | model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs) |
1075 | 1105 |
|
1076 | | - model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) |
1077 | | - |
1078 | 1106 | model_init_kwargs = _add_additional_model_data_sources_to_kwargs(kwargs=model_init_kwargs) |
1079 | 1107 |
|
1080 | 1108 | return model_init_kwargs |
0 commit comments