2828 verify_model_region_and_return_specs ,
2929)
3030from sagemaker .session import Session
31+ from sagemaker .jumpstart .types import JumpStartModelSpecs
32+
33+
34+ def _retrieve_hosting_prepacked_artifact_key (
35+ model_specs : JumpStartModelSpecs , instance_type : str
36+ ) -> str :
37+ """Returns instance specific hosting prepacked artifact key or default one as fallback."""
38+ instance_specific_prepacked_hosting_artifact_key : Optional [str ] = (
39+ model_specs .hosting_instance_type_variants .get_instance_specific_prepacked_artifact_key (
40+ instance_type = instance_type
41+ )
42+ if instance_type
43+ and getattr (model_specs , "hosting_instance_type_variants" , None ) is not None
44+ else None
45+ )
46+
47+ default_prepacked_hosting_artifact_key : Optional [str ] = getattr (
48+ model_specs , "hosting_prepacked_artifact_key"
49+ )
50+
51+ return (
52+ instance_specific_prepacked_hosting_artifact_key or default_prepacked_hosting_artifact_key
53+ )
54+
55+
56+ def _retrieve_hosting_artifact_key (model_specs : JumpStartModelSpecs , instance_type : str ) -> str :
57+ """Returns instance specific hosting artifact key or default one as fallback."""
58+ instance_specific_hosting_artifact_key : Optional [str ] = (
59+ model_specs .hosting_instance_type_variants .get_instance_specific_artifact_key (
60+ instance_type = instance_type
61+ )
62+ if instance_type
63+ and getattr (model_specs , "hosting_instance_type_variants" , None ) is not None
64+ else None
65+ )
66+
67+ default_hosting_artifact_key : str = model_specs .hosting_artifact_key
68+
69+ return instance_specific_hosting_artifact_key or default_hosting_artifact_key
70+
71+
72+ def _retrieve_training_artifact_key (model_specs : JumpStartModelSpecs , instance_type : str ) -> str :
73+ """Returns instance specific training artifact key or default one as fallback."""
74+ instance_specific_training_artifact_key : Optional [str ] = (
75+ model_specs .training_instance_type_variants .get_instance_specific_artifact_key (
76+ instance_type = instance_type
77+ )
78+ if instance_type
79+ and getattr (model_specs , "training_instance_type_variants" , None ) is not None
80+ else None
81+ )
82+
83+ default_training_artifact_key : str = model_specs .training_artifact_key
84+
85+ return instance_specific_training_artifact_key or default_training_artifact_key
3186
3287
3388def _retrieve_model_uri (
3489 model_id : str ,
3590 model_version : str ,
3691 model_scope : Optional [str ] = None ,
92+ instance_type : Optional [str ] = None ,
3793 region : Optional [str ] = None ,
3894 tolerate_vulnerable_model : bool = False ,
3995 tolerate_deprecated_model : bool = False ,
@@ -50,6 +106,7 @@ def _retrieve_model_uri(
50106 artifact S3 URI.
51107 model_scope (str): The model type, i.e. what it is used for.
52108 Valid values: "training" and "inference".
109+ instance_type (str): The ML compute instance type for the specified scope. (Default: None).
53110 region (str): Region for which to retrieve model S3 URI. (Default: None).
54111 tolerate_vulnerable_model (bool): True if vulnerable versions of model
55112 specifications should be tolerated (exception not raised). If False, raises an
@@ -84,14 +141,21 @@ def _retrieve_model_uri(
84141 sagemaker_session = sagemaker_session ,
85142 )
86143
144+ model_artifact_key : str
145+
87146 if model_scope == JumpStartScriptScope .INFERENCE :
147+
148+ is_prepacked = not model_specs .use_inference_script_uri ()
149+
88150 model_artifact_key = (
89- getattr (model_specs , "hosting_prepacked_artifact_key" , None )
90- or model_specs .hosting_artifact_key
151+ _retrieve_hosting_prepacked_artifact_key (model_specs , instance_type )
152+ if is_prepacked
153+ else _retrieve_hosting_artifact_key (model_specs , instance_type )
91154 )
92155
93156 elif model_scope == JumpStartScriptScope .TRAINING :
94- model_artifact_key = model_specs .training_artifact_key
157+
158+ model_artifact_key = _retrieve_training_artifact_key (model_specs , instance_type )
95159
96160 bucket = os .environ .get (
97161 ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE
0 commit comments