Skip to content

Commit 9363166

Browse files
committed
list trials under trial component
1 parent 1d52141 commit 9363166

File tree

3 files changed

+61
-4
lines changed

3 files changed

+61
-4
lines changed

src/smexperiments/trial_component.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Contains the TrialComponent class."""
14-
from smexperiments import _base_types, api_types
14+
from smexperiments import _base_types, api_types, trial
1515
import time
1616

1717

@@ -130,11 +130,11 @@ def delete(self, force_disassociate=None):
130130
)
131131

132132
# Disassociate the trials and trial components
133-
for trial in list_trials_response["TrialSummaries"]:
133+
for per_trial in list_trials_response["TrialSummaries"]:
134134
# to prevent DisassociateTrialComponent throttling
135135
time.sleep(1.2)
136136
self.sagemaker_boto_client.disassociate_trial_component(
137-
TrialName=trial["TrialName"], TrialComponentName=self.trial_component_name
137+
TrialName=per_trial["TrialName"], TrialComponentName=self.trial_component_name
138138
)
139139

140140
if "NextToken" in list_trials_response:
@@ -144,6 +144,17 @@ def delete(self, force_disassociate=None):
144144

145145
return self._invoke_api(self._boto_delete_method, self._boto_delete_members)
146146

147+
def list_trials(self):
148+
"""
149+
Load a list of trials that contains the same trial component name
150+
151+
Returns:
152+
A list of trials that contains the same trial component name
153+
"""
154+
return trial.Trial.list(
155+
trial_component_name=self.trial_component_name, sagemaker_boto_client=self.sagemaker_boto_client
156+
)
157+
147158
@classmethod
148159
def load(cls, trial_component_name, sagemaker_boto_client=None):
149160
"""Load an existing trial component and return an ``TrialComponent`` object representing it.

tests/integ/test_trial_component.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,27 @@ def test_list_sort(trial_components, sagemaker_boto_client):
112112
assert trial_component_names # sanity test
113113

114114

115+
def test_list_empty_trials(experiment_obj, trial_component_obj):
116+
actual_value = trial_component_obj.list_trials()
117+
assert len(list(actual_value)) == 0
118+
119+
120+
def test_list_trials(trial_obj, trial_component_obj, sagemaker_boto_client):
121+
trial_obj.add_trial_component(trial_component_obj)
122+
actual_value = list(trial_component_obj.list_trials())
123+
124+
assert len(actual_value) == 1
125+
assert actual_value[0].trial_name == trial_obj.trial_name
126+
trial_obj.remove_trial_component(trial_component_obj)
127+
128+
trial_components = list(
129+
trial_component.TrialComponent.list(
130+
sagemaker_boto_client=sagemaker_boto_client, trial_name=trial_obj.trial_name
131+
)
132+
)
133+
assert 0 == len(trial_components)
134+
135+
115136
def test_list_trial_components_by_experiment(experiment_obj, trial_component_obj, sagemaker_boto_client):
116137
trial_obj = experiment_obj.create_trial()
117138
trial_obj.add_trial_component(trial_component_obj)

tests/unit/test_trial_component.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
from smexperiments import trial_component, api_types
13+
from smexperiments import trial_component, api_types, trial
1414

1515
import datetime
1616
import pytest
@@ -285,6 +285,31 @@ def test_delete_with_force_disassociate(sagemaker_boto_client):
285285
sagemaker_boto_client.delete_trial_component.assert_called_with(TrialComponentName="foo")
286286

287287

288+
def test_list_trials(sagemaker_boto_client):
289+
sagemaker_boto_client.list_trials.return_value = {
290+
"TrialSummaries": [
291+
{
292+
"TrialName": "trial-1",
293+
"CreationTime": None,
294+
"LastModifiedTime": None,
295+
},
296+
{
297+
"TrialName": "trial-2",
298+
"CreationTime": None,
299+
"LastModifiedTime": None,
300+
},
301+
]
302+
}
303+
obj = trial_component.TrialComponent(sagemaker_boto_client, trial_component_name="foo", discompplay_name="bar")
304+
305+
expected = [
306+
api_types.TrialSummary(trial_name="trial-1", creation_time=None, last_modified_time=None),
307+
api_types.TrialSummary(trial_name="trial-2", creation_time=None, last_modified_time=None),
308+
]
309+
assert expected == list(obj.list_trials())
310+
sagemaker_boto_client.list_trials.assert_called_with(TrialComponentName="foo")
311+
312+
288313
def test_boto_ignore():
289314
obj = trial_component.TrialComponent(sagemaker_boto_client, trial_component_name="foo", display_name="bar")
290315
assert obj._boto_ignore() == ["ResponseMetadata", "CreatedBy"]

0 commit comments

Comments
 (0)