@@ -887,6 +887,49 @@ def make_deprecated_spec(*largs, **kwargs):
887887 "*" ,
888888 )
889889
890+ deprecated_message = "this model is deprecated"
891+
892+ def make_deprecated_message_spec (* largs , ** kwargs ):
893+ spec = get_spec_from_base_spec (* largs , ** kwargs )
894+ spec .deprecated_message = deprecated_message
895+ spec .deprecated = True
896+ return spec
897+
898+ patched_get_model_specs .side_effect = make_deprecated_message_spec
899+
900+ with pytest .raises (DeprecatedJumpStartModelError ) as e :
901+ utils .verify_model_region_and_return_specs (
902+ model_id = "pytorch-eqa-bert-base-cased" ,
903+ version = "*" ,
904+ scope = JumpStartScriptScope .INFERENCE .value ,
905+ region = "us-west-2" ,
906+ )
907+ assert deprecated_message == str (e .value .message )
908+
909+ deprecate_warn_message = "warn-msg"
910+
911+ def make_deprecated_warning_message_spec (* largs , ** kwargs ):
912+ spec = get_spec_from_base_spec (* largs , ** kwargs )
913+ spec .deprecate_warn_message = deprecate_warn_message
914+ return spec
915+
916+ patched_get_model_specs .side_effect = make_deprecated_warning_message_spec
917+
918+ with patch ("logging.Logger.warning" ) as mocked_warning_log :
919+ assert (
920+ utils .verify_model_region_and_return_specs (
921+ model_id = "pytorch-eqa-bert-base-cased" ,
922+ version = "*" ,
923+ scope = JumpStartScriptScope .INFERENCE .value ,
924+ region = "us-west-2" ,
925+ tolerate_deprecated_model = True ,
926+ )
927+ is not None
928+ )
929+ mocked_warning_log .assert_called_once_with (
930+ deprecate_warn_message ,
931+ )
932+
890933
891934def test_get_jumpstart_base_name_if_jumpstart_model ():
892935 uris = [random_jumpstart_s3_uri ("random_key" ) for _ in range (random .randint (1 , 10 ))]
0 commit comments