Skip to content

Commit a1574a2

Browse files
committed
chore: emit log when legacy fields used to get jumpstart image uri
1 parent 4b24784 commit a1574a2

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

src/sagemaker/image_uris.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from packaging.version import Version
2222

2323
from sagemaker import utils
24-
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
24+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER
2525
from sagemaker.jumpstart.enums import JumpStartModelType
2626
from sagemaker.jumpstart.utils import is_jumpstart_model_input
2727
from sagemaker.spark import defaults
@@ -154,6 +154,18 @@ def retrieve(
154154
)
155155

156156
if is_jumpstart_model_input(model_id, model_version):
157+
if non_none_fields := {
158+
key: value
159+
for key, value in args.items()
160+
if key in {"version", "framework", "container_version", "py_version"}
161+
and value is not None
162+
}:
163+
JUMPSTART_LOGGER.info(
164+
"Ignoring the following fields when retriving image uri "
165+
"for JumpStart model id '%s': %s",
166+
model_id,
167+
str(non_none_fields),
168+
)
157169
return artifacts._retrieve_image_uri(
158170
model_id=model_id,
159171
model_version=model_version,

tests/unit/sagemaker/image_uris/jumpstart/test_common.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,54 @@ def test_jumpstart_common_image_uri(
186186
model_id="pytorch-ic-mobilenet-v2",
187187
instance_type="ml.m5.xlarge",
188188
)
189+
190+
191+
@patch("sagemaker.image_uris.JUMPSTART_LOGGER.info")
192+
@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type")
193+
@patch("sagemaker.jumpstart.artifacts.image_uris.verify_model_region_and_return_specs")
194+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
195+
def test_jumpstart_image_uri_logging_extra_fields(
196+
patched_get_model_specs,
197+
patched_verify_model_region_and_return_specs,
198+
patched_validate_model_id_and_get_type,
199+
patched_info_log,
200+
):
201+
202+
patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs
203+
patched_get_model_specs.side_effect = get_spec_from_base_spec
204+
patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
205+
206+
region = "us-west-2"
207+
mock_client = boto3.client("s3")
208+
mock_session = Mock(s3_client=mock_client, boto_region_name=region)
209+
210+
image_uris.retrieve(
211+
framework=None,
212+
region="us-west-2",
213+
image_scope="training",
214+
model_id="pytorch-ic-mobilenet-v2",
215+
model_version="*",
216+
instance_type="ml.m5.xlarge",
217+
sagemaker_session=mock_session,
218+
)
219+
220+
patched_info_log.assert_not_called()
221+
222+
image_uris.retrieve(
223+
framework="framework",
224+
container_version="1.2.3",
225+
region="us-west-2",
226+
image_scope="training",
227+
model_id="pytorch-ic-mobilenet-v2",
228+
model_version="*",
229+
instance_type="ml.m5.xlarge",
230+
sagemaker_session=mock_session,
231+
)
232+
233+
patched_info_log.assert_called_once_with(
234+
"Ignoring the following fields "
235+
"when retriving image uri for "
236+
"JumpStart model id '%s': %s",
237+
"pytorch-ic-mobilenet-v2",
238+
"{'framework': 'framework', 'container_version': '1.2.3'}",
239+
)

0 commit comments

Comments
 (0)