@@ -418,6 +418,61 @@ def test_select_container_for_mlflow_model_no_dlc_detected(
418418 )
419419
420420
421+ @patch ("sagemaker.image_uris.retrieve" )
422+ @patch ("sagemaker.serve.model_format.mlflow.utils._cast_to_compatible_version" )
423+ @patch ("sagemaker.serve.model_format.mlflow.utils._get_framework_version_from_requirements" )
424+ @patch (
425+ "sagemaker.serve.model_format.mlflow.utils._get_python_version_from_parsed_mlflow_model_file"
426+ )
427+ @patch ("sagemaker.serve.model_format.mlflow.utils._get_all_flavor_metadata" )
428+ @patch ("sagemaker.serve.model_format.mlflow.utils._generate_mlflow_artifact_path" )
429+ def test_select_container_for_mlflow_model_no_framework_version_detected (
430+ mock_generate_mlflow_artifact_path ,
431+ mock_get_all_flavor_metadata ,
432+ mock_get_python_version_from_parsed_mlflow_model_file ,
433+ mock_get_framework_version_from_requirements ,
434+ mock_cast_to_compatible_version ,
435+ mock_image_uris_retrieve ,
436+ ):
437+ mlflow_model_src_path = "/path/to/mlflow_model"
438+ deployment_flavor = "pytorch"
439+ region = "us-west-2"
440+ instance_type = "ml.m5.xlarge"
441+
442+ mock_requirements_path = "/path/to/requirements.txt"
443+ mock_metadata_path = "/path/to/mlmodel"
444+ mock_flavor_metadata = {"pytorch" : {"some_key" : "some_value" }}
445+ mock_python_version = "3.8.6"
446+
447+ mock_generate_mlflow_artifact_path .side_effect = lambda path , artifact : (
448+ mock_requirements_path if artifact == "requirements.txt" else mock_metadata_path
449+ )
450+ mock_get_all_flavor_metadata .return_value = mock_flavor_metadata
451+ mock_get_python_version_from_parsed_mlflow_model_file .return_value = mock_python_version
452+ mock_get_framework_version_from_requirements .return_value = None
453+
454+ with pytest .raises (
455+ ValueError ,
456+ match = "Unable to auto detect framework version. Please provide framework "
457+ "pytorch as part of the requirements.txt file for deployment flavor "
458+ "pytorch" ,
459+ ):
460+ _select_container_for_mlflow_model (
461+ mlflow_model_src_path , deployment_flavor , region , instance_type
462+ )
463+
464+ mock_generate_mlflow_artifact_path .assert_any_call (
465+ mlflow_model_src_path , "requirements.txt"
466+ )
467+ mock_generate_mlflow_artifact_path .assert_any_call (mlflow_model_src_path , "MLmodel" )
468+ mock_get_all_flavor_metadata .assert_called_once_with (mock_metadata_path )
469+ mock_get_framework_version_from_requirements .assert_called_once_with (
470+ deployment_flavor , mock_requirements_path
471+ )
472+ mock_cast_to_compatible_version .assert_not_called ()
473+ mock_image_uris_retrieve .assert_not_called ()
474+
475+
421476def test_validate_input_for_mlflow ():
422477 _validate_input_for_mlflow (ModelServer .TORCHSERVE , "pytorch" )
423478
0 commit comments