Skip to content

Commit 79a1163

Browse files
committed
fix: Minor fixes
1 parent 4bdd822 commit 79a1163

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

src/sagemaker/jumpstart/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -863,7 +863,9 @@ def validate_model_id_and_get_type(
863863
model_version=model_version,
864864
sagemaker_session=sagemaker_session,
865865
)
866-
return model_types[0] # Currently this function only supports one model type
866+
return (
867+
model_types[0] if model_types else None
868+
) # Currently this function only supports one model type
867869

868870
s3_client = sagemaker_session.s3_client if sagemaker_session else None
869871
region = region or constants.JUMPSTART_DEFAULT_REGION_NAME
@@ -908,8 +910,8 @@ def _validate_hub_service_model_id_and_get_type(
908910
)
909911

910912
hub_content_model_types = []
911-
model_types_field = getattr(hub_model_specs, "model_types", [])
912-
model_types = model_types_field if model_types_field is not None else []
913+
model_types_field: Optional[List[str]] = getattr(hub_model_specs, "model_types", [])
914+
model_types = model_types_field if model_types_field else []
913915
for model_type in model_types:
914916
try:
915917
hub_content_model_types.append(enums.JumpStartModelType[model_type])

tests/unit/sagemaker/jumpstart/hub/test_marketplace_hub_content.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,14 @@ def _test_proprietary_model(input_version, expected_version, sagemaker_session):
9898

9999

100100
@pytest.mark.parametrize(
101-
"get_model_specs_response, expected, expected_exception, expected_message",
101+
"get_model_specs_attr, get_model_specs_response, expected, expected_exception, expected_message",
102102
[
103-
(None, [], None, None),
104-
([], [], None, None),
105-
(["OPEN_WEIGHTS"], [JumpStartModelType.OPEN_WEIGHTS], None, None),
103+
(False, None, [], None, None),
104+
(True, None, [], None, None),
105+
(True, [], [], None, None),
106+
(True, ["OPEN_WEIGHTS"], [JumpStartModelType.OPEN_WEIGHTS], None, None),
106107
(
108+
True,
107109
["OPEN_WEIGHTS", "PROPRIETARY"],
108110
[JumpStartModelType.OPEN_WEIGHTS, JumpStartModelType.PROPRIETARY],
109111
None,
@@ -113,10 +115,15 @@ def _test_proprietary_model(input_version, expected_version, sagemaker_session):
113115
)
114116
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
115117
def test_validate_hub_service_model_id_and_get_type(
116-
mock_get_model_specs, get_model_specs_response, expected, expected_exception, expected_message
118+
mock_get_model_specs,
119+
get_model_specs_attr,
120+
get_model_specs_response,
121+
expected,
122+
expected_exception,
123+
expected_message,
117124
):
118125
mock_object = MagicMock()
119-
if get_model_specs_response:
126+
if get_model_specs_attr:
120127
mock_object.model_types = get_model_specs_response
121128
mock_get_model_specs.return_value = mock_object
122129

0 commit comments

Comments
 (0)