Skip to content

Commit edfb9db

Browse files
committed
feature: add S3Downloader.list(s3_uri) functionality (#283)
Adds the ability to call S3Downloader.list(base_s3_uri) on a base S3 path, which returns a list of S3 URIs corresponding to valid S3 objects under that path.
1 parent 8f01c40 commit edfb9db

File tree

3 files changed

+110
-5
lines changed

3 files changed

+110
-5
lines changed

src/sagemaker/s3.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""This module contains Enums and helper methods related to S3."""
1414
from __future__ import print_function, absolute_import
1515

16+
import os
17+
1618
from six.moves.urllib.parse import urlparse
1719
from sagemaker.session import Session
1820

@@ -42,7 +44,7 @@ def upload(local_path, desired_s3_uri, kms_key=None, session=None):
4244
Args:
4345
local_path (str): A local path to a file or directory.
4446
desired_s3_uri (str): The desired S3 uri to upload to.
45-
kms_key (str): A KMS key to be provided as an extra argument.
47+
kms_key (str): The KMS key to use to encrypt the files.
4648
session (sagemaker.session.Session): Session object which
4749
manages interactions with Amazon SageMaker APIs and any other
4850
AWS services needed. If not specified, the estimator creates one
@@ -70,7 +72,7 @@ def upload_string_as_file_body(body, desired_s3_uri=None, kms_key=None, session=
7072
Args:
7173
body (str): String representing the body of the file.
7274
desired_s3_uri (str): The desired S3 uri to upload to.
73-
kms_key (str): A KMS key to be provided as an extra argument.
75+
kms_key (str): The KMS key to use to encrypt the files.
7476
session (sagemaker.session.Session): AWS session to use. Automatically
7577
generates one if not provided.
7678
@@ -98,7 +100,7 @@ def download(s3_uri, local_path, kms_key=None, session=None):
98100
Args:
99101
s3_uri (str): An S3 uri to download from.
100102
local_path (str): A local path to download the file(s) to.
101-
kms_key (str): A KMS key to be provided as an extra argument.
103+
kms_key (str): The KMS key to use to decrypt the files.
102104
session (sagemaker.session.Session): Session object which
103105
manages interactions with Amazon SageMaker APIs and any other
104106
AWS services needed. If not specified, the estimator creates one
@@ -133,3 +135,22 @@ def read_file(s3_uri, session=None):
133135
bucket, key_prefix = parse_s3_url(url=s3_uri)
134136

135137
return sagemaker_session.read_s3_file(bucket=bucket, key_prefix=key_prefix)
138+
139+
@staticmethod
def list(s3_uri, session=None):
    """Static method that lists the objects stored under an S3 uri.

    Args:
        s3_uri (str): The S3 base uri to list objects in.
        session (sagemaker.session.Session): AWS session to use. Automatically
            generates one if not provided.

    Returns:
        [str]: The S3 URIs of the objects found under the given S3 base uri.

    """
    sagemaker_session = session or Session()
    bucket, key_prefix = parse_s3_url(url=s3_uri)

    file_keys = sagemaker_session.list_s3_files(bucket=bucket, key_prefix=key_prefix)
    # S3 URIs always use "/" as the separator. os.path.join would insert the
    # local platform separator (a backslash on Windows) and would discard the
    # scheme and bucket for any key that begins with "/", so the URI is
    # formatted explicitly instead.
    return ["s3://{}/{}".format(bucket, file_key) for file_key in file_keys]

src/sagemaker/session.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def upload_string_as_file_body(self, body, bucket, key, kms_key=None):
203203
default bucket of the ``Session`` is used (if default bucket does not exist, the
204204
``Session`` creates it).
205205
key (str): S3 object key. This is the s3 path to the file.
206-
kms_key (str): The KMS key to use for decrypting the file.
206+
kms_key (str): The KMS key to use for encrypting the file.
207207
208208
Returns:
209209
str: The S3 URI of the uploaded file(s). If a file is specified in the path argument,
@@ -291,7 +291,24 @@ def read_s3_file(self, bucket, key_prefix):
291291
# Explicitly passing a None kms_key to boto3 throws a validation error.
292292
s3_object = s3.get_object(Bucket=bucket, Key=key_prefix)
293293

294-
return s3_object["Body"].read()
294+
return s3_object["Body"].read().decode("utf-8")
295+
296+
def list_s3_files(self, bucket, key_prefix):
    """Collect the keys of the S3 objects in a bucket that match a key prefix.

    Args:
        bucket (str): Name of the S3 bucket to list objects from.
        key_prefix (str): S3 object key name prefix used to filter the listing.

    Returns:
        [str]: The keys of the objects yielded by the filtered listing.

    """
    s3_resource = self.boto_session.resource("s3")
    matching_objects = s3_resource.Bucket(name=bucket).objects.filter(Prefix=key_prefix).all()

    keys = []
    for matching_object in matching_objects:
        keys.append(matching_object.key)
    return keys
295312

296313
def default_bucket(self):
297314
"""Return the name of the default bucket to use in relevant Amazon SageMaker interactions.

tests/integ/test_s3.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
import uuid
17+
18+
import pytest
19+
20+
from sagemaker.s3 import S3Uploader
21+
from sagemaker.s3 import S3Downloader
22+
23+
from tests.integ.kms_utils import get_or_create_kms_key
24+
25+
26+
@pytest.fixture(scope="module")
def s3_files_kms_key(sagemaker_session):
    """Module-scoped fixture returning ``get_or_create_kms_key`` for the given session."""
    kms_key = get_or_create_kms_key(sagemaker_session=sagemaker_session)
    return kms_key
29+
30+
31+
def test_statistics_object_creation_from_s3_uri_with_customizations(
    sagemaker_session, s3_files_kms_key
):
    """Upload two files under a unique prefix, list them, and read each one back."""
    file_1_body = "First File Body."
    file_1_name = "first_file.txt"
    file_2_body = "Second File Body."
    file_2_name = "second_file.txt"

    # A fresh UUID keeps this run's objects under their own prefix.
    my_uuid = str(uuid.uuid4())

    # S3 URIs are always "/"-separated; os.path.join would use the local
    # platform separator and produce invalid URIs on Windows.
    base_s3_uri = "s3://{}/{}/{}".format(
        sagemaker_session.default_bucket(), "integ-test-test-s3-list", my_uuid
    )
    file_1_s3_uri = "{}/{}".format(base_s3_uri, file_1_name)
    file_2_s3_uri = "{}/{}".format(base_s3_uri, file_2_name)

    S3Uploader.upload_string_as_file_body(
        body=file_1_body,
        desired_s3_uri=file_1_s3_uri,
        kms_key=s3_files_kms_key,
        session=sagemaker_session,
    )

    S3Uploader.upload_string_as_file_body(
        body=file_2_body,
        desired_s3_uri=file_2_s3_uri,
        kms_key=s3_files_kms_key,
        session=sagemaker_session,
    )

    s3_uris = S3Downloader.list(s3_uri=base_s3_uri, session=sagemaker_session)

    # Fail with a clear assertion instead of an IndexError if the listing
    # does not contain exactly the two uploaded objects.
    assert len(s3_uris) == 2

    assert file_1_name in s3_uris[0]
    assert file_2_name in s3_uris[1]

    assert file_1_body == S3Downloader.read_file(s3_uri=s3_uris[0], session=sagemaker_session)
    assert file_2_body == S3Downloader.read_file(s3_uri=s3_uris[1], session=sagemaker_session)

0 commit comments

Comments
 (0)