|
18 | 18 | from inspect import signature |
19 | 19 |
|
20 | 20 | import pytest |
| 21 | +from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig |
21 | 22 |
|
22 | 23 | from sagemaker.debugger.profiler_config import ProfilerConfig |
23 | 24 | from sagemaker.estimator import Estimator |
@@ -640,6 +641,56 @@ def test_no_predictor_returns_default_predictor( |
640 | 641 | self.assertEqual(type(predictor), Predictor) |
641 | 642 | self.assertEqual(predictor, default_predictor_with_presets) |
642 | 643 |
|
| 644 | + @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") |
| 645 | + @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") |
| 646 | + @mock.patch("sagemaker.jumpstart.factory.model.Session") |
| 647 | + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") |
| 648 | + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") |
| 649 | + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") |
| 650 | + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") |
| 651 | + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") |
| 652 | + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) |
| 653 | + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) |
| 654 | + def test_no_predictor_yes_async_inference_config( |
| 655 | + self, |
| 656 | + mock_estimator_deploy: mock.Mock, |
| 657 | + mock_estimator_fit: mock.Mock, |
| 658 | + mock_estimator_init: mock.Mock, |
| 659 | + mock_get_model_specs: mock.Mock, |
| 660 | + mock_session_estimator: mock.Mock, |
| 661 | + mock_session_model: mock.Mock, |
| 662 | + mock_is_valid_model_id: mock.Mock, |
| 663 | + mock_get_default_predictor: mock.Mock, |
| 664 | + ): |
| 665 | + mock_estimator_deploy.return_value = default_predictor |
| 666 | + |
| 667 | + mock_get_default_predictor.return_value = default_predictor_with_presets |
| 668 | + |
| 669 | + mock_is_valid_model_id.return_value = True |
| 670 | + |
| 671 | + model_id, _ = "js-trainable-model-prepacked", "*" |
| 672 | + |
| 673 | + mock_get_model_specs.side_effect = get_special_model_spec |
| 674 | + |
| 675 | + mock_session_estimator.return_value = sagemaker_session |
| 676 | + mock_session_model.return_value = sagemaker_session |
| 677 | + |
| 678 | + estimator = JumpStartEstimator( |
| 679 | + model_id=model_id, |
| 680 | + ) |
| 681 | + |
| 682 | + channels = { |
| 683 | + "training": f"s3://{get_jumpstart_content_bucket(region)}/" |
| 684 | + f"some-training-dataset-doesn't-matter", |
| 685 | + } |
| 686 | + |
| 687 | + estimator.fit(channels) |
| 688 | + |
| 689 | + predictor = estimator.deploy(async_inference_config=AsyncInferenceConfig()) |
| 690 | + |
| 691 | + mock_get_default_predictor.assert_not_called() |
| 692 | + self.assertEqual(type(predictor), Predictor) |
| 693 | + |
643 | 694 | @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") |
644 | 695 | @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") |
645 | 696 | @mock.patch("sagemaker.jumpstart.factory.model.Session") |
|
0 commit comments