Skip to content

Commit a5e8f78

Browse files
authored
Merge pull request #79 from yzhu0/addRemoveTCSummary
fix: able to add and remove trail component summary
2 parents 204934a + 539dd64 commit a5e8f78

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

src/smexperiments/trial.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,14 +202,16 @@ def add_trial_component(self, tc):
202202
A trial component may belong to many trials and a trial may have many trial components.
203203
204204
Args:
205-
tc: (tracker.Tracker|trial_component.TrialComponent|str) The trial component to add. Can be
206-
one of a Tracker instance, a TrialComponent instance, or a string containing the name of
205+
tc: (tracker.Tracker|trial_component.TrialComponent|api_types.TrialComponentSummary|str) The trial component
206+
to add. Can be one of a Tracker instance, a TrialComponent instance, or a string containing the name of
207207
the trial component to add.
208208
"""
209209
if isinstance(tc, tracker.Tracker):
210210
trial_component_name = tc.trial_component.trial_component_name
211211
elif isinstance(tc, trial_component.TrialComponent):
212212
trial_component_name = tc.trial_component_name
213+
elif isinstance(tc, api_types.TrialComponentSummary):
214+
trial_component_name = tc.trial_component_name
213215
else:
214216
trial_component_name = str(tc)
215217
self.sagemaker_boto_client.associate_trial_component(
@@ -220,14 +222,16 @@ def remove_trial_component(self, tc):
220222
"""Remove the specified trial component from this trial.
221223
222224
Args:
223-
tc: (tracker.Tracker|trial_component.TrialComponent|str) The trial component to remove. Can be
224-
one of a Tracker instance, a TrialComponent instance, or a string containing the name of
225-
the trial component to remove.
225+
tc: (tracker.Tracker|trial_component.TrialComponent|api_types.TrialComponentSummary|str) The trial
226+
component to remove. Can be one of a Tracker instance, a TrialComponent instance, or a string containing
227+
the name of the trial component to remove.
226228
"""
227229
if isinstance(tc, tracker.Tracker):
228230
trial_component_name = tc.trial_component.trial_component_name
229231
elif isinstance(tc, trial_component.TrialComponent):
230232
trial_component_name = tc.trial_component_name
233+
elif isinstance(tc, api_types.TrialComponentSummary):
234+
trial_component_name = tc.trial_component_name
231235
else:
232236
trial_component_name = str(tc)
233237
self.sagemaker_boto_client.disassociate_trial_component(

tests/unit/test_trial.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,31 @@ def test_add_trial_component(sagemaker_boto_client):
9999
)
100100

101101

102+
def test_add_trial_component_from_trial_component_summary(sagemaker_boto_client):
103+
t = trial.Trial(sagemaker_boto_client)
104+
t.trial_name = "bar"
105+
tcs = api_types.TrialComponentSummary()
106+
tcs.trial_component_name = "tcs-foo"
107+
t.add_trial_component(tcs)
108+
sagemaker_boto_client.associate_trial_component.assert_called_with(TrialName="bar", TrialComponentName="tcs-foo")
109+
110+
102111
def test_remove_trial_component(sagemaker_boto_client):
103112
t = trial.Trial(sagemaker_boto_client)
104113
t.trial_name = "bar"
105114
t.remove_trial_component("foo")
106115
sagemaker_boto_client.disassociate_trial_component.assert_called_with(TrialName="bar", TrialComponentName="foo")
107116

108117

118+
def test_remove_trial_component_from_trial_component_summary(sagemaker_boto_client):
119+
t = trial.Trial(sagemaker_boto_client)
120+
t.trial_name = "bar"
121+
tcs = api_types.TrialComponentSummary()
122+
tcs.trial_component_name = "tcs-foo"
123+
t.remove_trial_component(tcs)
124+
sagemaker_boto_client.disassociate_trial_component.assert_called_with(TrialName="bar", TrialComponentName="tcs-foo")
125+
126+
109127
def test_remove_trial_component_from_tracker(sagemaker_boto_client):
110128
t = trial.Trial(sagemaker_boto_client)
111129
t.trial_name = "bar"

0 commit comments

Comments
 (0)