Skip to content

Commit e3e51b9

Browse files
author
Keshav Chandak
committed
bugix: Added check for the presence of model package group before creating one
1 parent e68105c commit e3e51b9

File tree

3 files changed

+61
-2
lines changed

3 files changed

+61
-2
lines changed

src/sagemaker/session.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4347,7 +4347,29 @@ def submit(request):
43474347
if model_package_group_name is not None and not model_package_group_name.startswith(
43484348
"arn:"
43494349
):
4350-
_create_resource(
4350+
model_package_groups = []
4351+
model_package_groups_response = self.sagemaker_client.list_model_package_groups(
4352+
NameContains=request["ModelPackageGroupName"],
4353+
)
4354+
model_package_groups = model_package_groups + model_package_groups_response["ModelPackageGroupSummaryList"]
4355+
next_token = model_package_groups_response.get("NextToken")
4356+
4357+
while next_token is not None and next_token != "":
4358+
model_package_groups_response = self.sagemaker_client.list_model_package_groups(
4359+
NameContains=request["ModelPackageGroupName"],
4360+
NextToken=next_token
4361+
)
4362+
model_package_groups = model_package_groups + model_package_groups_response["ModelPackageGroupSummaryList"]
4363+
next_token = model_package_groups_response.get("NextToken")
4364+
4365+
filtered_model_package_group = list(
4366+
filter(
4367+
lambda mpg: mpg.get("ModelPackageGroupName") == request["ModelPackageGroupName"],
4368+
model_package_groups
4369+
)
4370+
)
4371+
if not filtered_model_package_group:
4372+
_create_resource(
43514373
lambda: self.sagemaker_client.create_model_package_group(
43524374
ModelPackageGroupName=request["ModelPackageGroupName"]
43534375
)

tests/unit/test_chainer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050

5151
LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]}
5252

53-
5453
@pytest.fixture()
5554
def sagemaker_session():
5655
boto_mock = Mock(name="boto_session", region_name=REGION)

tests/unit/test_session.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5006,6 +5006,9 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session)
50065006
domain = "COMPUTER_VISION"
50075007
task = "IMAGE_CLASSIFICATION"
50085008
sample_payload_url = "s3://test-bucket/model"
5009+
sagemaker_session.sagemaker_client.list_model_package_groups.return_value = {
5010+
"ModelPackageGroupSummaryList": []
5011+
}
50095012
sagemaker_session.create_model_package_from_containers(
50105013
containers=containers,
50115014
content_types=content_types,
@@ -5093,6 +5096,10 @@ def test_create_model_package_from_containers_with_source_uri_and_inference_spec
50935096
approval_status = ("Approved",)
50945097
skip_model_validation = "All"
50955098
source_uri = "dummy-source-uri"
5099+
5100+
sagemaker_session.sagemaker_client.list_model_package_groups.return_value = {
5101+
"ModelPackageGroupSummaryList": []
5102+
}
50965103

50975104
created_versioned_mp_arn = (
50985105
"arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1"
@@ -5149,6 +5156,9 @@ def test_create_model_package_from_containers_with_source_uri_for_unversioned_mp
51495156
approval_status = ("Approved",)
51505157
skip_model_validation = "All"
51515158
source_uri = "dummy-source-uri"
5159+
sagemaker_session.sagemaker_client.list_model_package_groups.return_value = {
5160+
"ModelPackageGroupSummaryList": []
5161+
}
51525162

51535163
with pytest.raises(
51545164
ValueError,
@@ -5220,6 +5230,10 @@ def test_create_model_package_from_containers_with_source_uri_set_to_mp(sagemake
52205230
sagemaker_session.sagemaker_client.create_model_package = Mock(
52215231
return_value={"ModelPackageArn": created_versioned_mp_arn}
52225232
)
5233+
5234+
sagemaker_session.sagemaker_client.list_model_package_groups.return_value = {
5235+
"ModelPackageGroupSummaryList": []
5236+
}
52235237

52245238
sagemaker_session.create_model_package_from_containers(
52255239
model_package_group_name=model_package_group_name,
@@ -5443,6 +5457,9 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s
54435457
approval_status = ("Approved",)
54445458
description = "description"
54455459
customer_metadata_properties = {"key1": "value1"}
5460+
sagemaker_session.sagemaker_client.list_model_package_groups.return_value = {
5461+
"ModelPackageGroupSummaryList": []
5462+
}
54465463
sagemaker_session.create_model_package_from_containers(
54475464
containers=containers,
54485465
content_types=content_types,
@@ -5510,6 +5527,9 @@ def test_create_model_package_from_containers_with_one_instance_types(
55105527
approval_status = ("Approved",)
55115528
description = "description"
55125529
customer_metadata_properties = {"key1": "value1"}
5530+
sagemaker_session.sagemaker_client.list_model_package_groups.return_value = {
5531+
"ModelPackageGroupSummaryList": []
5532+
}
55135533
sagemaker_session.create_model_package_from_containers(
55145534
containers=containers,
55155535
content_types=content_types,
@@ -7183,3 +7203,21 @@ def test_delete_hub_content_reference(sagemaker_session):
71837203
}
71847204

71857205
sagemaker_session.sagemaker_client.delete_hub_content_reference.assert_called_with(**request)
7206+
7207+
def test_create_model_package_from_containers_to_create_mpg_if_not_present(sagemaker_session):
7208+
sagemaker_session.sagemaker_client.list_model_package_groups.return_value = {
7209+
"ModelPackageGroupSummaryList": [{
7210+
"ModelPackageGroupName": "mock-mpg"
7211+
}]
7212+
}
7213+
sagemaker_session.create_model_package_from_containers(source_uri="mock-source-uri", model_package_group_name="mock-mpg")
7214+
sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called()
7215+
sagemaker_session.create_model_package_from_containers(source_uri="mock-source-uri", model_package_group_name="arn:aws:sagemaker:us-east-1:215995503607:model-package-group/mock-mpg")
7216+
sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called()
7217+
sagemaker_session.sagemaker_client.list_model_package_groups.return_value = {
7218+
"ModelPackageGroupSummaryList": []
7219+
}
7220+
sagemaker_session.create_model_package_from_containers(source_uri="mock-source-uri", model_package_group_name="mock-mpg")
7221+
sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with(ModelPackageGroupName="mock-mpg")
7222+
7223+

0 commit comments

Comments
 (0)