diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index c6a2014ae5..b2398e03d1 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4347,11 +4347,59 @@ def submit(request): if model_package_group_name is not None and not model_package_group_name.startswith( "arn:" ): - _create_resource( - lambda: self.sagemaker_client.create_model_package_group( - ModelPackageGroupName=request["ModelPackageGroupName"] + is_model_package_group_present = False + try: + model_package_groups_response = self.search( + resource="ModelPackageGroup", + search_expression={ + "Filters": [ + { + "Name": "ModelPackageGroupName", + "Value": request["ModelPackageGroupName"], + "Operator": "Equals", + } + ], + }, + ) + if len(model_package_groups_response.get("Results")) > 0: + is_model_package_group_present = True + except Exception: # pylint: disable=W0703 + model_package_groups = [] + model_package_groups_response = self.sagemaker_client.list_model_package_groups( + NameContains=request["ModelPackageGroupName"], + ) + model_package_groups = ( + model_package_groups + + model_package_groups_response["ModelPackageGroupSummaryList"] + ) + next_token = model_package_groups_response.get("NextToken") + + while next_token is not None and next_token != "": + model_package_groups_response = ( + self.sagemaker_client.list_model_package_groups( + NameContains=request["ModelPackageGroupName"], NextToken=next_token + ) + ) + model_package_groups = ( + model_package_groups + + model_package_groups_response["ModelPackageGroupSummaryList"] + ) + next_token = model_package_groups_response.get("NextToken") + + filtered_model_package_group = list( + filter( + lambda mpg: mpg.get("ModelPackageGroupName") + == request["ModelPackageGroupName"], + model_package_groups, + ) + ) + is_model_package_group_present = len(filtered_model_package_group) > 0 + if not is_model_package_group_present: + _create_resource( + lambda: self.sagemaker_client.create_model_package_group( + ModelPackageGroupName=request["ModelPackageGroupName"] + ) ) - ) if "SourceUri" in request and request["SourceUri"] is not None: # Remove inference spec from request if the # given source uri can lead to auto-population of it diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index d2d2c3bcfb..f873e9b14c 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -5006,6 +5006,7 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session) domain = "COMPUTER_VISION" task = "IMAGE_CLASSIFICATION" sample_payload_url = "s3://test-bucket/model" + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -5094,6 +5095,8 @@ def test_create_model_package_from_containers_with_source_uri_and_inference_spec skip_model_validation = "All" source_uri = "dummy-source-uri" + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} + created_versioned_mp_arn = ( "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" ) @@ -5149,6 +5152,7 @@ def test_create_model_package_from_containers_with_source_uri_for_unversioned_mp approval_status = ("Approved",) skip_model_validation = "All" source_uri = "dummy-source-uri" + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} with pytest.raises( ValueError, @@ -5221,6 +5225,8 @@ def test_create_model_package_from_containers_with_source_uri_set_to_mp(sagemake return_value={"ModelPackageArn": created_versioned_mp_arn} ) + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} + sagemaker_session.create_model_package_from_containers( model_package_group_name=model_package_group_name, containers=containers, @@ -5443,6 +5449,7 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s approval_status = ("Approved",) description = "description" customer_metadata_properties = {"key1": "value1"} + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -5510,6 +5517,7 @@ def test_create_model_package_from_containers_with_one_instance_types( approval_status = ("Approved",) description = "description" customer_metadata_properties = {"key1": "value1"} + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -7183,3 +7191,65 @@ def test_delete_hub_content_reference(sagemaker_session): } sagemaker_session.sagemaker_client.delete_hub_content_reference.assert_called_with(**request) + + +def test_create_model_package_from_containers_to_create_mpg_if_not_present_without_search( + sagemaker_session, +): + sagemaker_session.sagemaker_client.search.side_effect = Exception() + sagemaker_session.sagemaker_client.search.return_value = {} + sagemaker_session.sagemaker_client.list_model_package_groups.side_effect = [ + { + "ModelPackageGroupSummaryList": [{"ModelPackageGroupName": "mock-mpg"}], + "NextToken": "NextToken", + }, + {"ModelPackageGroupSummaryList": [{"ModelPackageGroupName": "mock-mpg-test"}]}, + ] + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", + model_package_group_name="arn:aws:sagemaker:us-east-1:215995503607:model-package-group/mock-mpg", + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.sagemaker_client.list_model_package_groups.side_effect = [ + {"ModelPackageGroupSummaryList": []} + ] + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with( + ModelPackageGroupName="mock-mpg" + ) + + +def test_create_model_package_from_containers_to_create_mpg_if_not_present(sagemaker_session): + # with search api + sagemaker_session.sagemaker_client.search.return_value = { + "Results": [ + { + "ModelPackageGroup": { + "ModelPackageGroupName": "mock-mpg", + "ModelPackageGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:model-package-group/mock-mpg", + } + } + ] + } + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", + model_package_group_name="arn:aws:sagemaker:us-east-1:215995503607:model-package-group/mock-mpg", + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with( + ModelPackageGroupName="mock-mpg" + )