Skip to content
Merged
33 changes: 24 additions & 9 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,13 +630,16 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
s3 = self.s3_resource

bucket = s3.Bucket(name=bucket_name)
expected_bucket_owner_id = self.account_id()
if bucket.creation_date is None:
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True)
self.general_bucket_check_if_user_has_permission(
bucket_name, s3, bucket, region, True, expected_bucket_owner_id
)

elif self._default_bucket_set_by_sdk:
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False)

expected_bucket_owner_id = self.account_id()
self.general_bucket_check_if_user_has_permission(
bucket_name, s3, bucket, region, False, expected_bucket_owner_id
)
self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id)

def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id):
Expand All @@ -649,9 +652,16 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket

"""
try:
s3.meta.client.head_bucket(
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
)
if self.default_bucket_prefix:
s3.meta.client.list_objects_v2(
Bucket=bucket_name,
Prefix=self.default_bucket_prefix,
ExpectedBucketOwner=expected_bucket_owner_id,
)
else:
s3.meta.client.head_bucket(
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
Expand All @@ -668,7 +678,7 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket
raise

def general_bucket_check_if_user_has_permission(
self, bucket_name, s3, bucket, region, bucket_creation_date_none
self, bucket_name, s3, bucket, region, bucket_creation_date_none, expected_bucket_owner_id
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this extra parameter is unused

):
"""Checks if the person running has the permissions to the bucket

Expand All @@ -682,7 +692,12 @@ def general_bucket_check_if_user_has_permission(
bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not
"""
try:
s3.meta.client.head_bucket(Bucket=bucket_name)
if self.default_bucket_prefix:
s3.meta.client.list_objects_v2(
Bucket=bucket_name, Prefix=self.default_bucket_prefix
)
else:
s3.meta.client.head_bucket(Bucket=bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
Expand Down
37 changes: 37 additions & 0 deletions tests/unit/test_default_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ def sagemaker_session():
return sagemaker_session


@pytest.fixture()
def sagemaker_session_with_bucket_name_and_prefix():
boto_mock = MagicMock(name="boto_session", region_name=REGION)
boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID}
sagemaker_session = sagemaker.Session(
boto_session=boto_mock,
default_bucket="XXXXXXXXXXXXX",
default_bucket_prefix="sample-prefix",
)
sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
return sagemaker_session


def test_default_bucket_s3_create_call(sagemaker_session):
error = ClientError(
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
Expand Down Expand Up @@ -96,6 +109,30 @@ def test_default_bucket_s3_needs_bucket_owner_access(sagemaker_session, datetime
assert sagemaker_session._default_bucket is None


def test_default_bucket_with_prefix_s3_needs_bucket_owner_access(
sagemaker_session_with_bucket_name_and_prefix, datetime_obj, caplog
):
with pytest.raises(ClientError):
error = ClientError(
error_response={"Error": {"Code": "403", "Message": "Forbidden"}},
operation_name="foo",
)
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource(
"s3"
).meta.client.list_objects_v2.side_effect = error
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").Bucket(
name=DEFAULT_BUCKET_NAME
).creation_date = None
sagemaker_session_with_bucket_name_and_prefix.default_bucket()

error_message = "Please try again after adding appropriate access."
assert error_message in caplog.text
assert sagemaker_session_with_bucket_name_and_prefix._default_bucket is None
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource(
"s3"
).meta.client.list_objects_v2.assert_called_once()


def test_default_bucket_s3_custom_bucket_input(sagemaker_session, datetime_obj, caplog):
sagemaker_session._default_bucket_name_override = "custom-bucket-override"
error = ClientError(
Expand Down
Loading