1717 get_prototype_manifest ,
1818 get_prototype_model_spec ,
1919)
20+ from tests .unit .sagemaker .jumpstart .constants import BASE_PROPRIETARY_MANIFEST
2021from sagemaker .jumpstart .enums import JumpStartModelType
2122from sagemaker .jumpstart .notebook_utils import (
2223 _generate_jumpstart_model_versions ,
@@ -40,8 +41,8 @@ def test_list_jumpstart_scripts(
4041 patched_read_s3_file : Mock ,
4142):
4243 patched_get_model_specs .side_effect = get_prototype_model_spec
43- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
44- region
44+ patched_get_manifest .side_effect = (
45+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
4546 )
4647 patched_generate_jumpstart_models .side_effect = _generate_jumpstart_model_versions
4748 patched_read_s3_file .side_effect = lambda * args , ** kwargs : json .dumps (
@@ -63,7 +64,9 @@ def test_list_jumpstart_scripts(
6364 }
6465 assert list_jumpstart_scripts (** kwargs ) == sorted (["inference" , "training" ])
6566 patched_generate_jumpstart_models .assert_called_once_with (
66- ** kwargs , sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
67+ ** kwargs ,
68+ model_type = JumpStartModelType .OPEN_WEIGHTS ,
69+ sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
6770 )
6871 assert patched_get_manifest .call_count == 2
6972 assert patched_get_model_specs .call_count == 1
@@ -76,12 +79,15 @@ def test_list_jumpstart_scripts(
7679 "filter" : "training_supported is False" ,
7780 "region" : "sa-east-1" ,
7881 }
82+ num_specs = len (PROTOTYPICAL_MODEL_SPECS_DICT )
7983 assert list_jumpstart_scripts (** kwargs ) == []
8084 patched_generate_jumpstart_models .assert_called_once_with (
81- ** kwargs , sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
85+ ** kwargs ,
86+ model_type = JumpStartModelType .OPEN_WEIGHTS ,
87+ sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
8288 )
8389 assert patched_get_manifest .call_count == 2
84- assert patched_read_s3_file .call_count == 2 * len ( PROTOTYPICAL_MODEL_SPECS_DICT )
90+ assert patched_read_s3_file .call_count == num_specs
8591
8692
8793@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest" )
@@ -93,8 +99,8 @@ def test_list_jumpstart_tasks(
9399 patched_get_manifest : Mock ,
94100):
95101 patched_get_model_specs .side_effect = get_prototype_model_spec
96- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
97- region
102+ patched_get_manifest .side_effect = (
103+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
98104 )
99105 patched_generate_jumpstart_models .side_effect = _generate_jumpstart_model_versions
100106
@@ -122,7 +128,9 @@ def test_list_jumpstart_tasks(
122128 }
123129 assert list_jumpstart_tasks (** kwargs ) == ["ic" ]
124130 patched_generate_jumpstart_models .assert_called_once_with (
125- ** kwargs , sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
131+ ** kwargs ,
132+ model_type = JumpStartModelType .OPEN_WEIGHTS ,
133+ sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
126134 )
127135 assert patched_get_manifest .call_count == 2
128136 patched_get_model_specs .assert_not_called ()
@@ -137,8 +145,8 @@ def test_list_jumpstart_frameworks(
137145 patched_get_manifest : Mock ,
138146):
139147 patched_get_model_specs .side_effect = get_prototype_model_spec
140- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
141- region
148+ patched_get_manifest .side_effect = (
149+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
142150 )
143151 patched_generate_jumpstart_models .side_effect = _generate_jumpstart_model_versions
144152
@@ -180,7 +188,9 @@ def test_list_jumpstart_frameworks(
180188 )
181189
182190 patched_generate_jumpstart_models .assert_called_once_with (
183- ** kwargs , sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
191+ ** kwargs ,
192+ model_type = JumpStartModelType .OPEN_WEIGHTS ,
193+ sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
184194 )
185195 assert patched_get_manifest .call_count == 4
186196 patched_get_model_specs .assert_not_called ()
@@ -229,8 +239,8 @@ def test_list_jumpstart_models_script_filter(
229239 patched_read_s3_file .side_effect = lambda * args , ** kwargs : json .dumps (
230240 get_prototype_model_spec (None , "pytorch-eqa-bert-base-cased" ).to_json ()
231241 )
232- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
233- region
242+ patched_get_manifest .side_effect = (
243+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
234244 )
235245
236246 manifest_length = len (get_prototype_manifest ())
@@ -516,8 +526,8 @@ def test_list_jumpstart_models_vulnerable_models(
516526 patched_get_manifest : Mock ,
517527 ):
518528
519- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
520- region
529+ patched_get_manifest .side_effect = (
530+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
521531 )
522532
523533 def vulnerable_inference_model_spec (bucket , key , * args , ** kwargs ) -> str :
@@ -533,11 +543,12 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
533543 patched_read_s3_file .side_effect = vulnerable_inference_model_spec
534544
535545 num_specs = len (PROTOTYPICAL_MODEL_SPECS_DICT )
546+ num_prop_specs = len (BASE_PROPRIETARY_MANIFEST )
536547 assert [] == list_jumpstart_models (
537548 And ("inference_vulnerable is false" , "training_vulnerable is false" )
538549 )
539550
540- assert patched_read_s3_file .call_count == 2 * num_specs
551+ assert patched_read_s3_file .call_count == num_specs + num_prop_specs
541552 assert patched_get_manifest .call_count == 2
542553
543554 patched_get_manifest .reset_mock ()
@@ -549,7 +560,7 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
549560 And ("inference_vulnerable is false" , "training_vulnerable is false" )
550561 )
551562
552- assert patched_read_s3_file .call_count == 2 * num_specs
563+ assert patched_read_s3_file .call_count == num_specs + num_prop_specs
553564 assert patched_get_manifest .call_count == 2
554565
555566 patched_get_manifest .reset_mock ()
@@ -567,8 +578,8 @@ def test_list_jumpstart_models_deprecated_models(
567578 patched_get_manifest : Mock ,
568579 ):
569580
570- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
571- region
581+ patched_get_manifest .side_effect = (
582+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
572583 )
573584
574585 def deprecated_model_spec (bucket , key , * args , ** kwargs ) -> str :
@@ -579,9 +590,10 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str:
579590 patched_read_s3_file .side_effect = deprecated_model_spec
580591
581592 num_specs = len (PROTOTYPICAL_MODEL_SPECS_DICT )
593+ num_prop_specs = len (BASE_PROPRIETARY_MANIFEST )
582594 assert [] == list_jumpstart_models ("deprecated equals false" )
583595
584- assert patched_read_s3_file .call_count == 2 * num_specs
596+ assert patched_read_s3_file .call_count == num_specs + num_prop_specs
585597 assert patched_get_manifest .call_count == 2
586598
587599 patched_get_manifest .reset_mock ()
@@ -666,8 +678,8 @@ def test_list_jumpstart_models_complex_queries(
666678 patched_read_s3_file .side_effect = lambda * args , ** kwargs : json .dumps (
667679 get_prototype_model_spec (None , "pytorch-eqa-bert-base-cased" ).to_json ()
668680 )
669- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
670- region
681+ patched_get_manifest .side_effect = (
682+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
671683 )
672684
673685 assert list_jumpstart_models (
@@ -711,8 +723,8 @@ def test_list_jumpstart_models_multiple_level_index(
711723 patched_get_manifest : Mock ,
712724 ):
713725 patched_get_model_specs .side_effect = get_prototype_model_spec
714- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
715- region
726+ patched_get_manifest .side_effect = (
727+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
716728 )
717729
718730 with pytest .raises (NotImplementedError ):
@@ -730,8 +742,8 @@ def test_get_model_url(
730742
731743 patched_get_model_specs .side_effect = get_prototype_model_spec
732744 patched_validate_model_id_and_get_type .return_value = JumpStartModelType .OPEN_WEIGHTS
733- patched_get_manifest .side_effect = lambda region , * args , ** kwargs : get_prototype_manifest (
734- region
745+ patched_get_manifest .side_effect = (
746+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest ( region , model_type )
735747 )
736748
737749 model_id , version = "xgboost-classification-model" , "1.0.0"
0 commit comments