Skip to content

Commit 1d52141

Browse files
committed
Log all files in a directory
1 parent f3f039f commit 1d52141

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

src/smexperiments/tracker.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import botocore
2121
import json
2222
from smexperiments._utils import get_module
23+
from os.path import join
2324

2425
import dateutil
2526

@@ -264,7 +265,7 @@ def log_output(self, name, value, media_type=None):
264265
Examples
265266
.. code-block:: python
266267
267-
# log input dataset s3 location
268+
# log output dataset s3 location
268269
my_tracker.log_output(name='prediction', value='s3://outputs/path')
269270
270271
Args:
@@ -276,6 +277,26 @@ def log_output(self, name, value, media_type=None):
276277
raise ValueError("Cannot add more than 30 output_artifacts under tracker trial_component")
277278
self.trial_component.output_artifacts[name] = api_types.TrialComponentArtifact(value, media_type=media_type)
278279

280+
def log_artifacts(self, directory, media_type=None):
281+
"""Upload all the files under the directory to s3 and store them as artifacts in this trial component. The file
282+
name, without its extension, is used as the artifact name.
283+
284+
Examples
285+
.. code-block:: python
286+
287+
# log local artifact
288+
my_tracker.log_artifacts(directory='/local/path')
289+
290+
Args:
291+
directory (str): The directory of the local files to upload.
292+
media_type (str, optional): The MediaType (MIME type) of the file. If not specified, this library
293+
will attempt to infer the media type from the file extension of each file in ``directory``.
294+
"""
295+
for dir_file in os.listdir(directory):
296+
file_path = join(directory, dir_file)
297+
artifact_name = os.path.splitext(dir_file)[0]
298+
self.log_artifact(file_path=file_path, name=artifact_name, media_type=media_type)
299+
279300
def log_artifact(self, file_path, name=None, media_type=None):
280301
"""Legacy overload method to prevent breaking existing code.
281302

tests/integ/test_tracker.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,32 @@ def test_log_artifact(trial_component_obj, bucket, tempdir, sagemaker_boto_clien
9595
assert prefix in loaded.output_artifacts[artifact_name].value
9696

9797

98+
def test_log_artifacts(trial_component_obj, bucket, tempdir, sagemaker_boto_client):
99+
prefix = name()
100+
file_contents = "happy monkey monkey"
101+
file_path = os.path.join(tempdir, "foo.txt")
102+
file_path1 = os.path.join(tempdir, "bar.txt")
103+
with open(file_path, "w") as foo_file:
104+
foo_file.write(file_contents)
105+
with open(file_path1, "w") as bar_file:
106+
bar_file.write(file_contents)
107+
108+
with tracker.Tracker.load(
109+
trial_component_obj.trial_component_name,
110+
artifact_bucket=bucket,
111+
artifact_prefix=prefix,
112+
sagemaker_boto_client=sagemaker_boto_client,
113+
) as tracker_obj:
114+
tracker_obj.log_artifacts(tempdir)
115+
loaded = trial_component.TrialComponent.load(
116+
trial_component_name=trial_component_obj.trial_component_name, sagemaker_boto_client=sagemaker_boto_client
117+
)
118+
assert "text/plain" == loaded.output_artifacts["foo"].media_type
119+
assert prefix in loaded.output_artifacts["foo"].value
120+
assert "text/plain" == loaded.output_artifacts["bar"].media_type
121+
assert prefix in loaded.output_artifacts["bar"].value
122+
123+
98124
def test_create_default_bucket(boto3_session):
99125
bucket_name_prefix = _utils.name("sm-test")
100126
bucket = _utils.get_or_create_default_bucket(boto3_session, default_bucket_prefix=bucket_name_prefix)

0 commit comments

Comments
 (0)