@@ -98,12 +98,14 @@ def _test_proprietary_model(input_version, expected_version, sagemaker_session):
9898
9999
100100@pytest .mark .parametrize (
101- "get_model_specs_response, expected, expected_exception, expected_message" ,
101+ "get_model_specs_attr, get_model_specs_response, expected, expected_exception, expected_message" ,
102102 [
103- (None , [], None , None ),
104- ([], [], None , None ),
105- (["OPEN_WEIGHTS" ], [JumpStartModelType .OPEN_WEIGHTS ], None , None ),
103+ (False , None , [], None , None ),
104+ (True , None , [], None , None ),
105+ (True , [], [], None , None ),
106+ (True , ["OPEN_WEIGHTS" ], [JumpStartModelType .OPEN_WEIGHTS ], None , None ),
106107 (
108+ True ,
107109 ["OPEN_WEIGHTS" , "PROPRIETARY" ],
108110 [JumpStartModelType .OPEN_WEIGHTS , JumpStartModelType .PROPRIETARY ],
109111 None ,
@@ -113,10 +115,15 @@ def _test_proprietary_model(input_version, expected_version, sagemaker_session):
113115)
114116@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs" )
115117def test_validate_hub_service_model_id_and_get_type (
116- mock_get_model_specs , get_model_specs_response , expected , expected_exception , expected_message
118+ mock_get_model_specs ,
119+ get_model_specs_attr ,
120+ get_model_specs_response ,
121+ expected ,
122+ expected_exception ,
123+ expected_message ,
117124):
118125 mock_object = MagicMock ()
119- if get_model_specs_response :
126+ if get_model_specs_attr :
120127 mock_object .model_types = get_model_specs_response
121128 mock_get_model_specs .return_value = mock_object
122129
0 commit comments