From 4fb0bd303bbf759b804c0ed2a9266b0fda0933de Mon Sep 17 00:00:00 2001 From: Gokul A <166456257+nargokul@users.noreply.github.com> Date: Fri, 6 Dec 2024 10:34:45 -0800 Subject: [PATCH 1/4] Sniping fix for S3 bucket --- src/sagemaker/session.py | 6 +++--- tests/unit/test_default_bucket.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 04a7326557..06c13a4bf7 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -630,13 +630,13 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): s3 = self.s3_resource bucket = s3.Bucket(name=bucket_name) + expected_bucket_owner_id = self.account_id() + if bucket.creation_date is None: self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True) - + self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id) elif self._default_bucket_set_by_sdk: self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False) - - expected_bucket_owner_id = self.account_id() self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id) def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id): diff --git a/tests/unit/test_default_bucket.py b/tests/unit/test_default_bucket.py index 6ce4b50c75..4a00341e7d 100644 --- a/tests/unit/test_default_bucket.py +++ b/tests/unit/test_default_bucket.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import datetime +import unittest from unittest.mock import Mock import pytest @@ -173,3 +174,31 @@ def test_bucket_creation_other_error(sagemaker_session): sagemaker_session.default_bucket() assert sagemaker_session._default_bucket is None + + +def test_default_bucket_s3_create_call_creation_date(sagemaker_session): + error = ClientError( + error_response={"Error": {"Code": "403", "Message": "Forbidden"}}, + operation_name="foo", + ) + sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = Mock( + side_effect=error + ) + + with pytest.raises(ClientError): + sagemaker_session.default_bucket() + + +def test_default_bucket_s3_create_call_default_bucket_set_by_sdk(sagemaker_session): + sagemaker_session._default_bucket_set_by_sdk = True + sagemaker_session.boto_session.resource("s3").Bucket().creation_date = 1733509801 + error = ClientError( + error_response={"Error": {"Code": "403", "Message": "Forbidden"}}, + operation_name="foo", + ) + sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = Mock( + side_effect=error + ) + + with pytest.raises(ClientError): + sagemaker_session.default_bucket() From 4d6ba29e086eebe84a1aeeaa7aec5a06577bf8c3 Mon Sep 17 00:00:00 2001 From: Gokul A <166456257+nargokul@users.noreply.github.com> Date: Fri, 6 Dec 2024 12:05:18 -0800 Subject: [PATCH 2/4] Test fix --- tests/unit/test_session.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index d2d2c3bcfb..451a2bbc41 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -102,6 +102,7 @@ def boto_session(request): with patch("sagemaker.user_agent.get_user_agent_extra_suffix", return_value=user_agent_suffix): client_mock._client_config.user_agent = user_agent boto_mock.client.return_value = client_mock + boto_mock.client("sts").get_caller_identity.return_value = {"Account": "Account-001"} return boto_mock From 1bd79270be9e41010648ca68b87741955adcec26 Mon Sep 17 00:00:00 2001 From: Gokul A <166456257+nargokul@users.noreply.github.com> Date: Fri, 6 Dec 2024 12:44:20 -0800 Subject: [PATCH 3/4] Codestyle --- tests/unit/test_default_bucket.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/test_default_bucket.py b/tests/unit/test_default_bucket.py index 4a00341e7d..cdc1248367 100644 --- a/tests/unit/test_default_bucket.py +++ b/tests/unit/test_default_bucket.py @@ -13,7 +13,6 @@ from __future__ import absolute_import import datetime -import unittest from unittest.mock import Mock import pytest From 13699c3b7808df5f37f51e8ea5e807b47e6db89c Mon Sep 17 00:00:00 2001 From: Gokul A <166456257+nargokul@users.noreply.github.com> Date: Fri, 6 Dec 2024 19:03:32 -0800 Subject: [PATCH 4/4] Fix --- src/sagemaker/session.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 06c13a4bf7..7fb4c9b964 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -630,14 +630,13 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): s3 = self.s3_resource bucket = s3.Bucket(name=bucket_name) - expected_bucket_owner_id = self.account_id() if bucket.creation_date is None: self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True) - self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id) + self.expected_bucket_owner_id_bucket_check(bucket_name, s3, self.account_id()) elif self._default_bucket_set_by_sdk: self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False) - self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id) + self.expected_bucket_owner_id_bucket_check(bucket_name, s3, self.account_id()) def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id): """Checks if the bucket belongs to a particular owner and throws a Client Error if it is not