Skip to content
14 changes: 13 additions & 1 deletion src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from packaging.version import Version

from sagemaker import utils
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.jumpstart.utils import is_jumpstart_model_input
from sagemaker.spark import defaults
Expand Down Expand Up @@ -154,6 +154,18 @@ def retrieve(
)

if is_jumpstart_model_input(model_id, model_version):
if non_none_fields := {
key: value
for key, value in args.items()
if key in {"version", "framework", "container_version", "py_version"}
and value is not None
}:
JUMPSTART_LOGGER.info(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks - I would vote for a warning but your call :)

technically, this is still backward incompatible, but the risk of breaking any customers is really low (it would require cx to have written code than calls the retrieve method w/ both a model_id and a framework, container_version and/or py_version which is unlikely.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i was thinking of warning but thought it might spam the logs.

i don't see how we can make this not backwards. we can no longer rely on ecr specs for raising exceptions.

"Ignoring the following arguments when retrieving image uri "
"for JumpStart model id '%s': %s",
model_id,
str(non_none_fields),
)
return artifacts._retrieve_image_uri(
model_id=model_id,
model_version=model_version,
Expand Down
51 changes: 51 additions & 0 deletions tests/unit/sagemaker/image_uris/jumpstart/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,54 @@ def test_jumpstart_common_image_uri(
model_id="pytorch-ic-mobilenet-v2",
instance_type="ml.m5.xlarge",
)


@patch("sagemaker.image_uris.JUMPSTART_LOGGER.info")
@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type")
@patch("sagemaker.jumpstart.artifacts.image_uris.verify_model_region_and_return_specs")
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_image_uri_logging_extra_fields(
patched_get_model_specs,
patched_verify_model_region_and_return_specs,
patched_validate_model_id_and_get_type,
patched_info_log,
):

patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs
patched_get_model_specs.side_effect = get_spec_from_base_spec
patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS

region = "us-west-2"
mock_client = boto3.client("s3")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: that's not a mock, is it?

mock_session = Mock(s3_client=mock_client, boto_region_name=region)

image_uris.retrieve(
framework=None,
region="us-west-2",
image_scope="training",
model_id="pytorch-ic-mobilenet-v2",
model_version="*",
instance_type="ml.m5.xlarge",
sagemaker_session=mock_session,
)

patched_info_log.assert_not_called()

image_uris.retrieve(
framework="framework",
container_version="1.2.3",
region="us-west-2",
image_scope="training",
model_id="pytorch-ic-mobilenet-v2",
model_version="*",
instance_type="ml.m5.xlarge",
sagemaker_session=mock_session,
)

patched_info_log.assert_called_once_with(
"Ignoring the following arguments "
"when retrieving image uri for "
"JumpStart model id '%s': %s",
"pytorch-ic-mobilenet-v2",
"{'framework': 'framework', 'container_version': '1.2.3'}",
)
Loading