|
29 | 29 | )
|
30 | 30 | from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
|
31 | 31 | from sagemaker.jumpstart.constants import (
|
| 32 | + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, |
32 | 33 | INFERENCE_ENTRY_POINT_SCRIPT_NAME,
|
33 | 34 | JUMPSTART_DEFAULT_REGION_NAME,
|
34 | 35 | JUMPSTART_LOGGER,
|
|
54 | 55 | add_jumpstart_model_info_tags,
|
55 | 56 | get_default_jumpstart_session_with_user_agent_suffix,
|
56 | 57 | get_neo_content_bucket,
|
| 58 | + get_top_ranked_config_name, |
57 | 59 | update_dict_if_key_not_present,
|
58 | 60 | resolve_model_sagemaker_config_field,
|
59 | 61 | verify_model_region_and_return_specs,
|
@@ -155,15 +157,15 @@ def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelIni
|
155 | 157 | return kwargs
|
156 | 158 |
|
157 | 159 |
|
158 |
| -def _add_sagemaker_session_to_kwargs( |
| 160 | +def _add_sagemaker_session_with_custom_user_agent_to_kwargs( |
159 | 161 | kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs]
|
160 | 162 | ) -> JumpStartModelInitKwargs:
|
161 | 163 | """Sets session in kwargs based on default or override, returns full kwargs."""
|
162 | 164 |
|
163 | 165 | kwargs.sagemaker_session = (
|
164 | 166 | kwargs.sagemaker_session
|
165 | 167 | or get_default_jumpstart_session_with_user_agent_suffix(
|
166 |
| - kwargs.model_id, kwargs.model_version, kwargs.hub_arn |
| 168 | + kwargs.model_id, kwargs.model_version, kwargs.config_name, kwargs.hub_arn |
167 | 169 | )
|
168 | 170 | )
|
169 | 171 |
|
@@ -662,33 +664,47 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
|
662 | 664 | ValueError: If the instance_type is not supported with the current config.
|
663 | 665 | """
|
664 | 666 |
|
| 667 | + # we need to create a default JS session (without custom user agent) |
| 668 | + # in order to retrieve config name info |
| 669 | + temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION |
| 670 | + |
| 671 | + kwargs.config_name = kwargs.config_name or get_top_ranked_config_name( |
| 672 | + region=kwargs.region, |
| 673 | + model_id=kwargs.model_id, |
| 674 | + model_version=kwargs.model_version, |
| 675 | + sagemaker_session=temp_session, |
| 676 | + scope=JumpStartScriptScope.INFERENCE, |
| 677 | + model_type=kwargs.model_type, |
| 678 | + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
| 679 | + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
| 680 | + hub_arn=kwargs.hub_arn, |
| 681 | + ) |
| 682 | + |
| 683 | + if kwargs.config_name is None: |
| 684 | + return kwargs |
| 685 | + |
665 | 686 | specs = verify_model_region_and_return_specs(
|
666 | 687 | model_id=kwargs.model_id,
|
667 | 688 | version=kwargs.model_version,
|
668 | 689 | scope=JumpStartScriptScope.INFERENCE,
|
669 | 690 | region=kwargs.region,
|
670 | 691 | tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
|
671 | 692 | tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
|
672 |
| - sagemaker_session=kwargs.sagemaker_session, |
| 693 | + sagemaker_session=temp_session, |
673 | 694 | model_type=kwargs.model_type,
|
674 | 695 | config_name=kwargs.config_name,
|
675 | 696 | )
|
676 |
| - if specs.inference_configs: |
677 |
| - default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name |
678 |
| - kwargs.config_name = kwargs.config_name or default_config_name |
679 |
| - |
680 |
| - if not kwargs.config_name: |
681 |
| - return kwargs |
682 |
| - |
683 |
| - if kwargs.config_name not in set(specs.inference_configs.configs.keys()): |
684 |
| - raise ValueError( |
685 |
| - f"Config {kwargs.config_name} is not supported for model {kwargs.model_id}." |
686 |
| - ) |
687 | 697 |
|
688 |
| - resolved_config = specs.inference_configs.configs[kwargs.config_name].resolved_config |
689 |
| - supported_instance_types = resolved_config.get("supported_inference_instance_types", []) |
690 |
| - if kwargs.instance_type not in supported_instance_types: |
691 |
| - JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type) |
| 698 | + resolved_config = ( |
| 699 | + specs.inference_configs.configs[kwargs.config_name].resolved_config |
| 700 | + if specs.inference_configs |
| 701 | + else None |
| 702 | + ) |
| 703 | + if resolved_config is None: |
| 704 | + return kwargs |
| 705 | + supported_instance_types = resolved_config.get("supported_inference_instance_types", []) |
| 706 | + if kwargs.instance_type not in supported_instance_types: |
| 707 | + JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type) |
692 | 708 | return kwargs
|
693 | 709 |
|
694 | 710 |
|
@@ -740,27 +756,41 @@ def _add_config_name_to_deploy_kwargs(
|
740 | 756 | ValueError: If the instance_type is not supported with the current config.
|
741 | 757 | """
|
742 | 758 |
|
743 |
| - specs = verify_model_region_and_return_specs( |
744 |
| - model_id=kwargs.model_id, |
745 |
| - version=kwargs.model_version, |
746 |
| - scope=JumpStartScriptScope.INFERENCE, |
747 |
| - region=kwargs.region, |
748 |
| - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
749 |
| - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
750 |
| - sagemaker_session=kwargs.sagemaker_session, |
751 |
| - model_type=kwargs.model_type, |
752 |
| - config_name=kwargs.config_name, |
753 |
| - ) |
| 759 | + # we need to create a default JS session (without custom user agent) |
| 760 | + # in order to retrieve config name info |
| 761 | + temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION |
754 | 762 |
|
755 | 763 | if training_config_name:
|
756 |
| - kwargs.config_name = _select_inference_config_from_training_config( |
| 764 | + |
| 765 | + specs = verify_model_region_and_return_specs( |
| 766 | + model_id=kwargs.model_id, |
| 767 | + version=kwargs.model_version, |
| 768 | + scope=JumpStartScriptScope.INFERENCE, |
| 769 | + region=kwargs.region, |
| 770 | + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
| 771 | + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
| 772 | + sagemaker_session=temp_session, |
| 773 | + model_type=kwargs.model_type, |
| 774 | + config_name=kwargs.config_name, |
| 775 | + ) |
| 776 | + default_config_name = _select_inference_config_from_training_config( |
757 | 777 | specs=specs, training_config_name=training_config_name
|
758 | 778 | )
|
759 |
| - return kwargs |
760 | 779 |
|
761 |
| - if specs.inference_configs: |
762 |
| - default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name |
763 |
| - kwargs.config_name = kwargs.config_name or default_config_name |
| 780 | + else: |
| 781 | + default_config_name = get_top_ranked_config_name( |
| 782 | + region=kwargs.region, |
| 783 | + model_id=kwargs.model_id, |
| 784 | + model_version=kwargs.model_version, |
| 785 | + sagemaker_session=temp_session, |
| 786 | + scope=JumpStartScriptScope.INFERENCE, |
| 787 | + model_type=kwargs.model_type, |
| 788 | + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
| 789 | + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
| 790 | + hub_arn=kwargs.hub_arn, |
| 791 | + ) |
| 792 | + |
| 793 | + kwargs.config_name = kwargs.config_name or default_config_name |
764 | 794 |
|
765 | 795 | return kwargs
|
766 | 796 |
|
@@ -839,16 +869,16 @@ def get_deploy_kwargs(
|
839 | 869 | routing_config=routing_config,
|
840 | 870 | )
|
841 | 871 |
|
842 |
| - deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs) |
| 872 | + deploy_kwargs = _add_config_name_to_deploy_kwargs( |
| 873 | + kwargs=deploy_kwargs, training_config_name=training_config_name |
| 874 | + ) |
| 875 | + |
| 876 | + deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(kwargs=deploy_kwargs) |
843 | 877 |
|
844 | 878 | deploy_kwargs = _add_model_version_to_kwargs(kwargs=deploy_kwargs)
|
845 | 879 |
|
846 | 880 | deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)
|
847 | 881 |
|
848 |
| - deploy_kwargs = _add_config_name_to_deploy_kwargs( |
849 |
| - kwargs=deploy_kwargs, training_config_name=training_config_name |
850 |
| - ) |
851 |
| - |
852 | 882 | deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs)
|
853 | 883 |
|
854 | 884 | deploy_kwargs.initial_instance_count = initial_instance_count or 1
|
@@ -1030,11 +1060,14 @@ def get_init_kwargs(
|
1030 | 1060 | )
|
1031 | 1061 |
|
1032 | 1062 | model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs)
|
| 1063 | + model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) |
| 1064 | + model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) |
1033 | 1065 |
|
1034 |
| - model_init_kwargs = _add_sagemaker_session_to_kwargs(kwargs=model_init_kwargs) |
| 1066 | + model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs( |
| 1067 | + kwargs=model_init_kwargs |
| 1068 | + ) |
1035 | 1069 | model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs)
|
1036 | 1070 |
|
1037 |
| - model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) |
1038 | 1071 | model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs)
|
1039 | 1072 |
|
1040 | 1073 | model_init_kwargs = _add_instance_type_to_kwargs(
|
@@ -1062,8 +1095,6 @@ def get_init_kwargs(
|
1062 | 1095 |
|
1063 | 1096 | model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)
|
1064 | 1097 |
|
1065 |
| - model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) |
1066 |
| - |
1067 | 1098 | model_init_kwargs = _add_additional_model_data_sources_to_kwargs(kwargs=model_init_kwargs)
|
1068 | 1099 |
|
1069 | 1100 | return model_init_kwargs
|
0 commit comments