File tree Expand file tree Collapse file tree 2 files changed +9
-1
lines changed
src/sagemaker/jumpstart/factory
tests/unit/sagemaker/jumpstart/estimator Expand file tree Collapse file tree 2 files changed +9
-1
lines changed Original file line number Diff line number Diff line change @@ -322,7 +322,12 @@ def get_deploy_kwargs(
322322 model_id = model_id ,
323323 model_from_estimator = True ,
324324 model_version = model_version ,
325- instance_type = model_deploy_kwargs .instance_type if training_instance_type is None else None ,
325+ instance_type = (
326+ model_deploy_kwargs .instance_type
327+ if training_instance_type is None
328+ or instance_type is not None # always use supplied inference instance type
329+ else None
330+ ),
326331 region = region ,
327332 image_uri = image_uri ,
328333 source_dir = source_dir ,
Original file line number Diff line number Diff line change @@ -1532,6 +1532,9 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta
15321532 estimator .deploy (image_uri = "blah" )
15331533 assert mock_estimator_deploy .call_args [1 ]["instance_type" ] == "ml.p4de.24xlarge"
15341534
1535+ estimator .deploy (image_uri = "blah" , instance_type = "ml.quantum.large" )
1536+ assert mock_estimator_deploy .call_args [1 ]["instance_type" ] == "ml.quantum.large"
1537+
15351538 @mock .patch ("sagemaker.utils.sagemaker_timestamp" )
15361539 @mock .patch ("sagemaker.jumpstart.estimator.validate_model_id_and_get_type" )
15371540 @mock .patch (
You can’t perform that action at this time.
0 commit comments