diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 797d559348..2cc18f6989 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -635,7 +635,6 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): elif self._default_bucket_set_by_sdk: self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False) - expected_bucket_owner_id = self.account_id() self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id) @@ -649,9 +648,16 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket """ try: - s3.meta.client.head_bucket( - Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id - ) + if self.default_bucket_prefix: + s3.meta.client.list_objects_v2( + Bucket=bucket_name, + Prefix=self.default_bucket_prefix, + ExpectedBucketOwner=expected_bucket_owner_id, + ) + else: + s3.meta.client.head_bucket( + Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id + ) except ClientError as e: error_code = e.response["Error"]["Code"] message = e.response["Error"]["Message"] @@ -682,7 +688,12 @@ def general_bucket_check_if_user_has_permission( bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not """ try: - s3.meta.client.head_bucket(Bucket=bucket_name) + if self.default_bucket_prefix: + s3.meta.client.list_objects_v2( + Bucket=bucket_name, Prefix=self.default_bucket_prefix + ) + else: + s3.meta.client.head_bucket(Bucket=bucket_name) except ClientError as e: error_code = e.response["Error"]["Code"] message = e.response["Error"]["Message"] diff --git a/tests/unit/test_default_bucket.py b/tests/unit/test_default_bucket.py index 6ce4b50c75..dca1d3dc85 100644 --- a/tests/unit/test_default_bucket.py +++ b/tests/unit/test_default_bucket.py @@ -39,6 +39,19 @@ def sagemaker_session(): return sagemaker_session +@pytest.fixture() +def sagemaker_session_with_bucket_name_and_prefix(): + boto_mock = MagicMock(name="boto_session", region_name=REGION) + boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID} + sagemaker_session = sagemaker.Session( + boto_session=boto_mock, + default_bucket="XXXXXXXXXXXXX", + default_bucket_prefix="sample-prefix", + ) + sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None + return sagemaker_session + + def test_default_bucket_s3_create_call(sagemaker_session): error = ClientError( error_response={"Error": {"Code": "404", "Message": "Not Found"}}, @@ -96,6 +109,30 @@ def test_default_bucket_s3_needs_bucket_owner_access(sagemaker_session, datetime assert sagemaker_session._default_bucket is None +def test_default_bucket_with_prefix_s3_needs_bucket_owner_access( + sagemaker_session_with_bucket_name_and_prefix, datetime_obj, caplog +): + with pytest.raises(ClientError): + error = ClientError( + error_response={"Error": {"Code": "403", "Message": "Forbidden"}}, + operation_name="foo", + ) + sagemaker_session_with_bucket_name_and_prefix.boto_session.resource( + "s3" + ).meta.client.list_objects_v2.side_effect = error + sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").Bucket( + name=DEFAULT_BUCKET_NAME + ).creation_date = None + sagemaker_session_with_bucket_name_and_prefix.default_bucket() + + error_message = "Please try again after adding appropriate access." + assert error_message in caplog.text + assert sagemaker_session_with_bucket_name_and_prefix._default_bucket is None + sagemaker_session_with_bucket_name_and_prefix.boto_session.resource( + "s3" + ).meta.client.list_objects_v2.assert_called_once() + + def test_default_bucket_s3_custom_bucket_input(sagemaker_session, datetime_obj, caplog): sagemaker_session._default_bucket_name_override = "custom-bucket-override" error = ClientError(