@@ -98,12 +98,14 @@ def _test_proprietary_model(input_version, expected_version, sagemaker_session):
98
98
99
99
100
100
@pytest .mark .parametrize (
101
- "get_model_specs_response, expected, expected_exception, expected_message" ,
101
+ "get_model_specs_attr, get_model_specs_response, expected, expected_exception, expected_message" ,
102
102
[
103
- (None , [], None , None ),
104
- ([], [], None , None ),
105
- (["OPEN_WEIGHTS" ], [JumpStartModelType .OPEN_WEIGHTS ], None , None ),
103
+ (False , None , [], None , None ),
104
+ (True , None , [], None , None ),
105
+ (True , [], [], None , None ),
106
+ (True , ["OPEN_WEIGHTS" ], [JumpStartModelType .OPEN_WEIGHTS ], None , None ),
106
107
(
108
+ True ,
107
109
["OPEN_WEIGHTS" , "PROPRIETARY" ],
108
110
[JumpStartModelType .OPEN_WEIGHTS , JumpStartModelType .PROPRIETARY ],
109
111
None ,
@@ -113,10 +115,15 @@ def _test_proprietary_model(input_version, expected_version, sagemaker_session):
113
115
)
114
116
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs" )
115
117
def test_validate_hub_service_model_id_and_get_type (
116
- mock_get_model_specs , get_model_specs_response , expected , expected_exception , expected_message
118
+ mock_get_model_specs ,
119
+ get_model_specs_attr ,
120
+ get_model_specs_response ,
121
+ expected ,
122
+ expected_exception ,
123
+ expected_message ,
117
124
):
118
125
mock_object = MagicMock ()
119
- if get_model_specs_response :
126
+ if get_model_specs_attr :
120
127
mock_object .model_types = get_model_specs_response
121
128
mock_get_model_specs .return_value = mock_object
122
129
0 commit comments