Skip to content

Commit f3f039f

Browse files
committed
handle max input/output artifacts on TC
1 parent 4ff9f09 commit f3f039f

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

src/smexperiments/tracker.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,8 @@ def log_input(self, name, value, media_type=None):
252252
value (str): The value.
253253
media_type (str, optional): The MediaType (MIME type) of the value
254254
"""
255+
if len(self.trial_component.input_artifacts) >= 30:
256+
raise ValueError("Cannot add more than 30 input_artifacts under tracker trial_component.")
255257
self.trial_component.input_artifacts[name] = api_types.TrialComponentArtifact(value, media_type=media_type)
256258

257259
def log_output(self, name, value, media_type=None):
@@ -270,6 +272,8 @@ def log_output(self, name, value, media_type=None):
270272
value (str): The value.
271273
media_type (str, optional): The MediaType (MIME type) of the value.
272274
"""
275+
if len(self.trial_component.output_artifacts) >= 30:
276+
raise ValueError("Cannot add more than 30 output_artifacts under tracker trial_component")
273277
self.trial_component.output_artifacts[name] = api_types.TrialComponentArtifact(value, media_type=media_type)
274278

275279
def log_artifact(self, file_path, name=None, media_type=None):
@@ -303,6 +307,8 @@ def log_output_artifact(self, file_path, name=None, media_type=None):
303307
media_type (str, optional): The MediaType (MIME type) of the file. If not specified, this library
304308
will attempt to infer the media type from the file extension of ``file_path``.
305309
"""
310+
if len(self.trial_component.output_artifacts) >= 30:
311+
raise ValueError("Cannot add more than 30 output_artifacts under tracker trial_component")
306312
media_type = media_type or _guess_media_type(file_path)
307313
name = name or _resolve_artifact_name(file_path)
308314
s3_uri, etag = self._artifact_uploader.upload_artifact(file_path)
@@ -326,6 +332,8 @@ def log_input_artifact(self, file_path, name=None, media_type=None):
326332
media_type (str, optional): The MediaType (MIME type) of the file. If not specified, this library
327333
will attempt to infer the media type from the file extension of ``file_path``.
328334
"""
335+
if len(self.trial_component.input_artifacts) >= 30:
336+
raise ValueError("Cannot add more than 30 input_artifacts under tracker trial_component.")
329337
media_type = media_type or _guess_media_type(file_path)
330338
name = name or _resolve_artifact_name(file_path)
331339
s3_uri, etag = self._artifact_uploader.upload_artifact(file_path)

tests/unit/test_tracker.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,56 @@ def test_log_input_artifact(under_test):
292292
assert "text/plain" == under_test.trial_component.input_artifacts["foo.txt"].media_type
293293

294294

295+
def test_log_inputs_error(under_test):
296+
for index in range(0, 30):
297+
file_path = "foo" + str(index) + ".txt"
298+
under_test.trial_component.input_artifacts[file_path] = {
299+
"foo": api_types.TrialComponentArtifact(value="baz" + str(index), media_type="text/text")
300+
}
301+
with pytest.raises(ValueError):
302+
under_test.log_input("foo.txt", "name", "whizz/bang")
303+
304+
305+
def test_log_outputs(under_test):
306+
for index in range(0, 30):
307+
file_path = "foo" + str(index) + ".txt"
308+
under_test.trial_component.output_artifacts[file_path] = {
309+
"foo": api_types.TrialComponentArtifact(value="baz" + str(index), media_type="text/text")
310+
}
311+
with pytest.raises(ValueError):
312+
under_test.log_output("foo.txt", "name", "whizz/bang")
313+
314+
315+
def test_log_multiple_input_artifact(under_test):
316+
for index in range(0, 30):
317+
file_path = "foo" + str(index) + ".txt"
318+
under_test._artifact_uploader.upload_artifact.return_value = (
319+
"s3uri_value" + str(index),
320+
"etag_value" + str(index),
321+
)
322+
under_test.log_input_artifact(file_path, "name" + str(index), "whizz/bang" + str(index))
323+
under_test._artifact_uploader.upload_artifact.assert_called_with(file_path)
324+
325+
under_test._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value")
326+
with pytest.raises(ValueError):
327+
under_test.log_input_artifact("foo.txt", "name", "whizz/bang")
328+
329+
330+
def test_log_multiple_output_artifact(under_test):
331+
for index in range(0, 30):
332+
file_path = "foo" + str(index) + ".txt"
333+
under_test._artifact_uploader.upload_artifact.return_value = (
334+
"s3uri_value" + str(index),
335+
"etag_value" + str(index),
336+
)
337+
under_test.log_output_artifact(file_path, "name" + str(index), "whizz/bang" + str(index))
338+
under_test._artifact_uploader.upload_artifact.assert_called_with(file_path)
339+
340+
under_test._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value")
341+
with pytest.raises(ValueError):
342+
under_test.log_output_artifact("foo.txt", "name", "whizz/bang")
343+
344+
295345
def test_log_pr_curve(under_test):
296346

297347
y_true = [0, 0, 1, 1]

0 commit comments

Comments
 (0)