diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 04a7326557..7fb4c9b964 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -630,14 +630,13 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): s3 = self.s3_resource bucket = s3.Bucket(name=bucket_name) + if bucket.creation_date is None: self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True) - + self.expected_bucket_owner_id_bucket_check(bucket_name, s3, self.account_id()) elif self._default_bucket_set_by_sdk: self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False) - - expected_bucket_owner_id = self.account_id() - self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id) + self.expected_bucket_owner_id_bucket_check(bucket_name, s3, self.account_id()) def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id): """Checks if the bucket belongs to a particular owner and throws a Client Error if it is not diff --git a/tests/unit/test_default_bucket.py b/tests/unit/test_default_bucket.py index 6ce4b50c75..cdc1248367 100644 --- a/tests/unit/test_default_bucket.py +++ b/tests/unit/test_default_bucket.py @@ -173,3 +173,31 @@ def test_bucket_creation_other_error(sagemaker_session): sagemaker_session.default_bucket() assert sagemaker_session._default_bucket is None + + +def test_default_bucket_s3_create_call_creation_date(sagemaker_session): + error = ClientError( + error_response={"Error": {"Code": "403", "Message": "Forbidden"}}, + operation_name="foo", + ) + sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = Mock( + side_effect=error + ) + + with pytest.raises(ClientError): + sagemaker_session.default_bucket() + + +def test_default_bucket_s3_create_call_default_bucket_set_by_sdk(sagemaker_session): + sagemaker_session._default_bucket_set_by_sdk = True + sagemaker_session.boto_session.resource("s3").Bucket().creation_date = 1733509801 + error = ClientError( + error_response={"Error": {"Code": "403", "Message": "Forbidden"}}, + operation_name="foo", + ) + sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = Mock( + side_effect=error + ) + + with pytest.raises(ClientError): + sagemaker_session.default_bucket() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index d2d2c3bcfb..451a2bbc41 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -102,6 +102,7 @@ def boto_session(request): with patch("sagemaker.user_agent.get_user_agent_extra_suffix", return_value=user_agent_suffix): client_mock._client_config.user_agent = user_agent boto_mock.client.return_value = client_mock + boto_mock.client("sts").get_caller_identity.return_value = {"Account": "Account-001"} return boto_mock