Skip to content
Merged
31 changes: 22 additions & 9 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,13 +630,12 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
s3 = self.s3_resource

bucket = s3.Bucket(name=bucket_name)
expected_bucket_owner_id = self.account_id()
if bucket.creation_date is None:
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True)
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True, expected_bucket_owner_id)

elif self._default_bucket_set_by_sdk:
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False)

expected_bucket_owner_id = self.account_id()
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False, expected_bucket_owner_id)
self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id)

def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id):
Expand All @@ -649,9 +648,16 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket

"""
try:
s3.meta.client.head_bucket(
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
)
if self.default_bucket_prefix:
s3.meta.client.list_objects_v2(
Bucket=bucket_name,
Prefix=self.default_bucket_prefix,
ExpectedBucketOwner=expected_bucket_owner_id
)
else:
s3.meta.client.head_bucket(
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
Expand All @@ -668,7 +674,7 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket
raise

def general_bucket_check_if_user_has_permission(
self, bucket_name, s3, bucket, region, bucket_creation_date_none
self, bucket_name, s3, bucket, region, bucket_creation_date_none, expected_bucket_owner_id
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this extra parameter is unused

):
"""Checks if the person running has the permissions to the bucket

Expand All @@ -682,7 +688,14 @@ def general_bucket_check_if_user_has_permission(
bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not
"""
try:
s3.meta.client.head_bucket(Bucket=bucket_name)
if self.default_bucket_prefix:
s3.meta.client.list_objects_v2(
Bucket=bucket_name,
Prefix=self.default_bucket_prefix,
ExpectedBucketOwner=expected_bucket_owner_id
Copy link
Contributor

@benieric benieric Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this check is a bit different in this method. Looks like it just checks for general permission to the bucket rather than checking for ExpectedBucketOwner. For this case, probably need to remove the ExpectedBucketOwner=expected_bucket_owner_id to match the behavior of the second block which just does a regular head bucket check - s3.meta.client.head_bucket(Bucket=bucket_name)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good callout. Updated the PR.

)
else:
s3.meta.client.head_bucket(Bucket=bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/test_default_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ def sagemaker_session():
return sagemaker_session


@pytest.fixture()
def sagemaker_session_with_bucket_name_and_prefix():
boto_mock = MagicMock(name="boto_session", region_name=REGION)
boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID}
sagemaker_session = sagemaker.Session(boto_session=boto_mock,
default_bucket="XXXXXXXXXXXXX",
default_bucket_prefix="sample-prefix")
sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
return sagemaker_session


def test_default_bucket_s3_create_call(sagemaker_session):
error = ClientError(
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
Expand Down Expand Up @@ -95,6 +106,24 @@ def test_default_bucket_s3_needs_bucket_owner_access(sagemaker_session, datetime
assert error_message in caplog.text
assert sagemaker_session._default_bucket is None

def test_default_bucket_with_prefix_s3_needs_bucket_owner_access(sagemaker_session_with_bucket_name_and_prefix,
datetime_obj,
caplog):
with pytest.raises(ClientError):
error = ClientError(
error_response={"Error": {"Code": "403", "Message": "Forbidden"}},
operation_name="foo",
)
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").meta.client.list_objects_v2.side_effect = error
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").Bucket(
name=DEFAULT_BUCKET_NAME
).creation_date = None
sagemaker_session_with_bucket_name_and_prefix.default_bucket()

error_message = "Please try again after adding appropriate access."
assert error_message in caplog.text
assert sagemaker_session_with_bucket_name_and_prefix._default_bucket is None
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").meta.client.list_objects_v2.assert_called_once()

def test_default_bucket_s3_custom_bucket_input(sagemaker_session, datetime_obj, caplog):
sagemaker_session._default_bucket_name_override = "custom-bucket-override"
Expand Down
Loading