Skip to content

Commit 7a78b7f

Browse files
committed
Allow Passing Tags in E/T/TC Create methods
1 parent 9af182d commit 7a78b7f

File tree

10 files changed

+91
-7
lines changed

10 files changed

+91
-7
lines changed

src/smexperiments/experiment.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,12 @@ class Experiment(_base_types.Record):
4343
Attributes:
4444
experiment_name (str): The name of the experiment. The name must be unique within an account.
4545
description (str): A description of the experiment.
46+
tags (List[dict[str, str]]): A list of tags to associate with the experiment.
4647
"""
4748

4849
experiment_name = None
4950
description = None
51+
tags = None
5052

5153
_boto_create_method = "create_experiment"
5254
_boto_load_method = "describe_experiment"
@@ -92,7 +94,7 @@ def load(cls, experiment_name, sagemaker_boto_client=None):
9294
)
9395

9496
@classmethod
95-
def create(cls, experiment_name=None, description=None, sagemaker_boto_client=None):
97+
def create(cls, experiment_name=None, description=None, tags=None, sagemaker_boto_client=None):
9698
"""
9799
Create a new experiment in SageMaker and return an ``Experiment`` object.
98100
@@ -101,6 +103,7 @@ def create(cls, experiment_name=None, description=None, sagemaker_boto_client=No
101103
experiment_description: (str, optional): Description of the experiment
102104
sagemaker_boto_client (SageMaker.Client, optional): Boto3 client for SageMaker. If not
103105
supplied, a default boto3 client will be created and used.
106+
tags (List[dict[str, str]]): A list of tags to associate with the experiment.
104107
105108
Returns:
106109
sagemaker.experiments.experiment.Experiment: A SageMaker ``Experiment`` object
@@ -109,6 +112,7 @@ def create(cls, experiment_name=None, description=None, sagemaker_boto_client=No
109112
cls._boto_create_method,
110113
experiment_name=experiment_name,
111114
description=description,
115+
tags=tags,
112116
sagemaker_boto_client=sagemaker_boto_client,
113117
)
114118

src/smexperiments/trial.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,12 @@ class Trial(_base_types.Record):
4343
Attributes:
4444
trial_name (str): The name of the trial.
4545
experiment_name (str): The name of the trial's experiment.
46+
tags (List[dict[str, str]]): A list of tags to associate with the trial.
4647
"""
4748

4849
trial_name = None
4950
experiment_name = None
51+
tags = None
5052

5153
_boto_create_method = "create_trial"
5254
_boto_load_method = "describe_trial"
@@ -96,15 +98,16 @@ def load(cls, trial_name, sagemaker_boto_client=None):
9698
)
9799

98100
@classmethod
99-
def create(cls, experiment_name, trial_name=None, sagemaker_boto_client=None, trial_components=None):
101+
def create(cls, experiment_name, trial_name=None, sagemaker_boto_client=None, trial_components=None, tags=None):
100102
"""Create a new trial and return a ``Trial`` object.
101103
102104
Args:
103105
experiment_name: (str): Name of the experiment to create this trial in.
104106
trial_name: (str, optional): Name of the Trial. If not specified, an auto-generated name will be used.
105107
sagemaker_boto_client (SageMaker.Client, optional): Boto3 client for SageMaker.
106108
If not supplied, a default boto3 client will be created and used.
107-
trial_components (list): A list of trial component names, trial components, or trial component trackers
109+
trial_components (list): A list of trial component names, trial components, or trial component trackers.
110+
tags (List[dict[str, str]]): A list of tags to associate with the trial.
108111
109112
Returns:
110113
smexperiments.trial.Trial: A SageMaker ``Trial`` object
@@ -114,6 +117,7 @@ def create(cls, experiment_name, trial_name=None, sagemaker_boto_client=None, tr
114117
cls._boto_create_method,
115118
trial_name=trial_name,
116119
experiment_name=experiment_name,
120+
tags=tags,
117121
sagemaker_boto_client=sagemaker_boto_client,
118122
)
119123
if trial_components:

src/smexperiments/trial_component.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class TrialComponent(_base_types.Record):
4343
parameters_to_remove (list): The hyperparameters to remove from the component.
4444
input_artifacts_to_remove (list): The input artifacts to remove from the component.
4545
output_artifacts_to_remove (list): The output artifacts to remove from the component.
46+
tags (List[dict[str, str]]): A list of tags to associate with the trial component.
4647
"""
4748

4849
trial_component_name = None
@@ -63,6 +64,7 @@ class TrialComponent(_base_types.Record):
6364
parameters_to_remove = None
6465
input_artifacts_to_remove = None
6566
output_artifacts_to_remove = None
67+
tags = None
6668

6769
_boto_load_method = "describe_trial_component"
6870
_boto_create_method = "create_trial_component"
@@ -125,7 +127,7 @@ def load(cls, trial_component_name, sagemaker_boto_client=None):
125127
return trial_component
126128

127129
@classmethod
128-
def create(cls, trial_component_name, display_name=None, sagemaker_boto_client=None):
130+
def create(cls, trial_component_name, display_name=None, tags=None, sagemaker_boto_client=None):
129131
"""Create a trial component and return a ``TrialComponent`` object representing it.
130132
131133
Returns:
@@ -136,6 +138,7 @@ def create(cls, trial_component_name, display_name=None, sagemaker_boto_client=N
136138
cls._boto_create_method,
137139
trial_component_name=trial_component_name,
138140
display_name=display_name,
141+
tags=tags,
139142
sagemaker_boto_client=sagemaker_boto_client,
140143
)
141144

tests/conftest.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
from smexperiments import experiment, trial, trial_component
3232
from tests.helpers import name, names
3333

34+
TAGS = [{"Key": "some-key", "Value": "some-value"}]
35+
3436

3537
def pytest_addoption(parser):
3638
parser.addoption("--boto-model-file", action="store", default=None)
@@ -93,7 +95,7 @@ def experiment_obj(sagemaker_boto_client):
9395
description = "{}-{}".format("description", str(uuid.uuid4()))
9496
boto3.set_stream_logger("", logging.INFO)
9597
experiment_obj = experiment.Experiment.create(
96-
experiment_name=name(), description=description, sagemaker_boto_client=sagemaker_boto_client
98+
experiment_name=name(), description=description, sagemaker_boto_client=sagemaker_boto_client, tags=TAGS
9799
)
98100
yield experiment_obj
99101
time.sleep(0.5)
@@ -103,7 +105,10 @@ def experiment_obj(sagemaker_boto_client):
103105
@pytest.fixture
104106
def trial_obj(sagemaker_boto_client, experiment_obj):
105107
trial_obj = trial.Trial.create(
106-
trial_name=name(), experiment_name=experiment_obj.experiment_name, sagemaker_boto_client=sagemaker_boto_client
108+
trial_name=name(),
109+
experiment_name=experiment_obj.experiment_name,
110+
tags=TAGS,
111+
sagemaker_boto_client=sagemaker_boto_client,
107112
)
108113
yield trial_obj
109114
time.sleep(0.5)
@@ -113,7 +118,7 @@ def trial_obj(sagemaker_boto_client, experiment_obj):
113118
@pytest.fixture
114119
def trial_component_obj(sagemaker_boto_client):
115120
trial_component_obj = trial_component.TrialComponent.create(
116-
trial_component_name=name(), sagemaker_boto_client=sagemaker_boto_client
121+
trial_component_name=name(), sagemaker_boto_client=sagemaker_boto_client, tags=TAGS,
117122
)
118123
yield trial_component_obj
119124
time.sleep(0.5)

tests/integ/test_experiment.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ def test_create_delete(experiment_obj):
2222
assert experiment_obj.experiment_name
2323

2424

25+
def test_create_tags(experiment_obj, sagemaker_boto_client):
26+
while True:
27+
actual_tags = sagemaker_boto_client.list_tags(ResourceArn=experiment_obj.experiment_arn)["Tags"]
28+
if actual_tags:
29+
break
30+
assert actual_tags == experiment_obj.tags
31+
32+
2533
def test_save(experiment_obj):
2634
description = name()
2735
experiment_obj.description = description

tests/integ/test_trial.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ def test_create_delete(trial_obj):
2222
assert trial_obj.trial_name
2323

2424

25+
def test_create_tags(trial_obj, sagemaker_boto_client):
26+
while True:
27+
actual_tags = sagemaker_boto_client.list_tags(ResourceArn=trial_obj.trial_arn)["Tags"]
28+
if actual_tags:
29+
break
30+
assert actual_tags == trial_obj.tags
31+
32+
2533
def test_list(trials, sagemaker_boto_client):
2634
slack = datetime.timedelta(minutes=1)
2735
now = datetime.datetime.now(datetime.timezone.utc)

tests/integ/test_trial_component.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ def test_create_delete(trial_component_obj):
2323
assert trial_component_obj.trial_component_name
2424

2525

26+
def test_create_tags(trial_component_obj, sagemaker_boto_client):
27+
while True:
28+
actual_tags = sagemaker_boto_client.list_tags(ResourceArn=trial_component_obj.trial_component_arn)["Tags"]
29+
if actual_tags:
30+
break
31+
assert actual_tags == trial_component_obj.tags
32+
33+
2634
def test_save(trial_component_obj, sagemaker_boto_client):
2735
trial_component_obj.display_name = str(uuid.uuid4())
2836
trial_component_obj.status = api_types.TrialComponentStatus(primary_status="InProgress", message="Message")

tests/unit/test_experiment.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,26 @@ def test_load(sagemaker_boto_client):
4040

4141
def test_create(sagemaker_boto_client):
4242
sagemaker_boto_client.create_experiment.return_value = {"Arn": "arn:aws:1234"}
43+
tags = {"Key": "foo", "Value": "bar"}
4344
experiment_obj = experiment.Experiment.create(
4445
experiment_name="name-value", sagemaker_boto_client=sagemaker_boto_client
4546
)
4647
assert experiment_obj.experiment_name == "name-value"
4748
sagemaker_boto_client.create_experiment.assert_called_with(ExperimentName="name-value")
4849

4950

51+
def test_create_with_tags(sagemaker_boto_client):
52+
sagemaker_boto_client.create_experiment.return_value = {"Arn": "arn:aws:1234"}
53+
tags = [{"Key": "foo", "Value": "bar"}]
54+
experiment_obj = experiment.Experiment.create(
55+
experiment_name="name-value", sagemaker_boto_client=sagemaker_boto_client, tags=tags
56+
)
57+
assert experiment_obj.experiment_name == "name-value"
58+
sagemaker_boto_client.create_experiment.assert_called_with(
59+
ExperimentName="name-value", Tags=[{"Key": "foo", "Value": "bar"}]
60+
)
61+
62+
5063
def test_list(sagemaker_boto_client, datetime_obj):
5164
sagemaker_boto_client.list_experiments.return_value = {
5265
"ExperimentSummaries": [

tests/unit/test_trial.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,24 @@ def test_create(sagemaker_boto_client):
5050
)
5151

5252

53+
def test_create_with_tags(sagemaker_boto_client):
54+
sagemaker_boto_client.create_trial.return_value = {
55+
"Arn": "arn:aws:1234",
56+
"TrialName": "name-value",
57+
}
58+
tags = [{"Key": "foo", "Value": "bar"}]
59+
trial_obj = trial.Trial.create(
60+
trial_name="name-value",
61+
experiment_name="experiment-name-value",
62+
sagemaker_boto_client=sagemaker_boto_client,
63+
tags=tags,
64+
)
65+
assert trial_obj.trial_name == "name-value"
66+
sagemaker_boto_client.create_trial.assert_called_with(
67+
TrialName="name-value", ExperimentName="experiment-name-value", Tags=[{"Key": "foo", "Value": "bar"}]
68+
)
69+
70+
5371
def test_create_no_name(sagemaker_boto_client):
5472
sagemaker_boto_client.create_trial.return_value = {}
5573
trial.Trial.create(experiment_name="experiment-name-value", sagemaker_boto_client=sagemaker_boto_client)

tests/unit/test_trial_component.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,19 @@ def test_create(sagemaker_boto_client):
3535
assert "bazz" == obj.trial_component_arn
3636

3737

38+
def test_create_with_tags(sagemaker_boto_client):
39+
sagemaker_boto_client.create_trial_component.return_value = {
40+
"TrialComponentArn": "bazz",
41+
}
42+
tags = [{"Key": "foo", "Value": "bar"}]
43+
obj = trial_component.TrialComponent.create(
44+
trial_component_name="foo", display_name="bar", sagemaker_boto_client=sagemaker_boto_client, tags=tags
45+
)
46+
sagemaker_boto_client.create_trial_component.assert_called_with(
47+
TrialComponentName="foo", DisplayName="bar", Tags=[{"Key": "foo", "Value": "bar"}]
48+
)
49+
50+
3851
def test_load(sagemaker_boto_client):
3952
now = datetime.datetime.now(datetime.timezone.utc)
4053

0 commit comments

Comments
 (0)