Skip to content

Commit fed6774

Browse files
committed
add list trial components method on trial object
add trial and experiment name filtering parameters to list trial components
1 parent 0ca1fb3 commit fed6774

File tree

7 files changed

+277
-7
lines changed

7 files changed

+277
-7
lines changed

src/smexperiments/trial.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,36 @@ def remove_trial_component(self, tc):
173173
self.sagemaker_boto_client.disassociate_trial_component(
174174
TrialName=self.trial_name, TrialComponentName=trial_component_name
175175
)
176+
177+
def list_trial_components(
178+
self,
179+
created_before=None,
180+
created_after=None,
181+
sort_by=None,
182+
sort_order=None,
183+
max_results=None,
184+
next_token=None):
185+
"""List trial components in this trial matching the specified criteria.
186+
187+
Args:
188+
created_before (datetime.datetime, optional): Return trials created before this instant.
189+
created_after (datetime.datetime, optional): Return trials created after this instant.
190+
sort_by (str, optional): Which property to sort results by. One of 'Name',
191+
'CreationTime'.
192+
sort_order (str, optional): One of 'Ascending', or 'Descending'.
193+
max_results (int, optional): maximum number of trial components to retrieve
194+
next_token (str, optional): token for next page of results
195+
Returns:
196+
collections.Iterator[smexperiments.api_types.TrialComponentSummary] : An iterator over
197+
trials matching the criteria.
198+
"""
199+
return trial_component.TrialComponent.list(
200+
trial_name=self.trial_name,
201+
created_before=created_before,
202+
created_after=created_after,
203+
sort_by=sort_by,
204+
sort_order=sort_order,
205+
max_results=max_results,
206+
next_token=next_token,
207+
sagemaker_boto_client=self.sagemaker_boto_client,
208+
)

src/smexperiments/trial_component.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,18 @@ def create(cls, trial_component_name, display_name=None, sagemaker_boto_client=N
110110
sagemaker_boto_client=sagemaker_boto_client)
111111

112112
@classmethod
113-
def list(cls, source_arn=None, created_before=None, created_after=None,
114-
sort_by=None, sort_order=None, sagemaker_boto_client=None):
113+
def list(
114+
cls,
115+
source_arn=None,
116+
created_before=None,
117+
created_after=None,
118+
sort_by=None,
119+
sort_order=None,
120+
sagemaker_boto_client=None,
121+
trial_name=None,
122+
experiment_name=None,
123+
max_results=None,
124+
next_token=None):
115125
"""
116126
Return a list of trial component summaries.
117127
@@ -124,6 +134,11 @@ def list(cls, source_arn=None, created_before=None, created_after=None,
124134
sort_order (str, optional): One of 'Ascending', or 'Descending'.
125135
sagemaker_boto_client (SageMaker.Client, optional) : Boto3 client for SageMaker.
126136
If not supplied, a default boto3 client will be created and used.
137+
trial_name (str, optional): Name of a Trial
138+
experiment_name (str, optional): Name of an Experiment
139+
max_results (int, optional): maximum number of trial components to retrieve
140+
next_token (str, optional): token for next page of results
141+
127142
Returns:
128143
collections.Iterator[smexperiments.api_types.TrialComponentSummary]: An iterator
129144
over ``TrialComponentSummary`` objects.
@@ -137,4 +152,8 @@ def list(cls, source_arn=None, created_before=None, created_after=None,
137152
created_after=created_after,
138153
sort_by=sort_by,
139154
sort_order=sort_order,
140-
sagemaker_boto_client=sagemaker_boto_client)
155+
sagemaker_boto_client=sagemaker_boto_client,
156+
trial_name=trial_name,
157+
experiment_name=experiment_name,
158+
max_results=max_results,
159+
next_token=next_token)

tests/integ-jobs/test_track_from_training_job.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# language governing permissions and limitations under the License.
1313

1414
import sys
15-
1615
import boto3
1716

1817
from tests.helpers import *

tests/integ/__init__.py

Whitespace-only changes.

tests/unit/__init__.py

Whitespace-only changes.
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
import pytest
14+
import unittest.mock
15+
import datetime
16+
17+
from smexperiments import trial, api_types
18+
19+
20+
@pytest.fixture
21+
def sagemaker_boto_client():
22+
return unittest.mock.Mock()
23+
24+
25+
@pytest.fixture
26+
def datetime_obj():
27+
return datetime.datetime(2017, 6, 16, 15, 55, 0)
28+
29+
def test_list_trial_components(sagemaker_boto_client, datetime_obj):
30+
sagemaker_boto_client.list_trial_components.return_value = {
31+
"TrialComponentSummaries": [
32+
{
33+
"TrialComponentName": "trial-component-1",
34+
"CreationTime": datetime_obj,
35+
"LastModifiedTime": datetime_obj,
36+
},
37+
{
38+
"TrialComponentName": "trial-component-2",
39+
"CreationTime": datetime_obj,
40+
"LastModifiedTime": datetime_obj,
41+
},
42+
]
43+
}
44+
expected = [
45+
api_types.TrialComponentSummary(
46+
trial_component_name="trial-component-1",
47+
creation_time=datetime_obj,
48+
last_modified_time=datetime_obj,
49+
),
50+
api_types.TrialComponentSummary(
51+
trial_component_name="trial-component-2",
52+
creation_time=datetime_obj,
53+
last_modified_time=datetime_obj,
54+
),
55+
]
56+
57+
trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)
58+
59+
assert expected == list(trial_obj.list_trial_components())
60+
61+
62+
def test_list_trial_components_empty(sagemaker_boto_client):
63+
sagemaker_boto_client.list_trial_components.return_value = {"TrialComponentSummaries": []}
64+
trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)
65+
assert list(trial_obj.list_trial_components()) == []
66+
67+
68+
def test_list_trial_components_single(sagemaker_boto_client, datetime_obj):
69+
trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)
70+
sagemaker_boto_client.list_trial_components.return_value = {
71+
"TrialComponentSummaries": [
72+
{
73+
"TrialComponentName": "trial-component-foo",
74+
"CreationTime": datetime_obj,
75+
"LastModifiedTime": datetime_obj
76+
}
77+
]
78+
}
79+
80+
assert list(trial_obj.list_trial_components()) == [
81+
api_types.TrialComponentSummary(
82+
trial_component_name="trial-component-foo",
83+
creation_time=datetime_obj,
84+
last_modified_time=datetime_obj
85+
)
86+
]
87+
88+
89+
def test_list_trial_components_two_values(sagemaker_boto_client, datetime_obj):
90+
trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)
91+
sagemaker_boto_client.list_trial_components.return_value = {
92+
"TrialComponentSummaries": [
93+
{"TrialComponentName": "trial-component-foo-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
94+
{"TrialComponentName": "trial-component-foo-2", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
95+
]
96+
}
97+
98+
assert list(trial_obj.list_trial_components()) == [
99+
api_types.TrialComponentSummary(
100+
trial_component_name="trial-component-foo-1",
101+
creation_time=datetime_obj,
102+
last_modified_time=datetime_obj
103+
),
104+
api_types.TrialComponentSummary(
105+
trial_component_name="trial-component-foo-2",
106+
creation_time=datetime_obj,
107+
last_modified_time=datetime_obj
108+
),
109+
]
110+
111+
112+
def test_next_token(sagemaker_boto_client, datetime_obj):
113+
trial_obj = trial.Trial(sagemaker_boto_client)
114+
sagemaker_boto_client.list_trial_components.side_effect = [
115+
{
116+
"TrialComponentSummaries": [
117+
{
118+
"TrialComponentName": "trial-component-foo-1",
119+
"CreationTime": datetime_obj,
120+
"LastModifiedTime": datetime_obj,
121+
},
122+
{
123+
"TrialComponentName": "trial-component-foo-2",
124+
"CreationTime": datetime_obj,
125+
"LastModifiedTime": datetime_obj,
126+
},
127+
],
128+
"NextToken": "foo",
129+
},
130+
{
131+
"TrialComponentSummaries": [
132+
{
133+
"TrialComponentName": "trial-component-foo-3",
134+
"CreationTime": datetime_obj,
135+
"LastModifiedTime": datetime_obj,
136+
}
137+
]
138+
},
139+
]
140+
141+
assert list(trial_obj.list_trial_components()) == [
142+
api_types.TrialComponentSummary(
143+
trial_component_name="trial-component-foo-1", creation_time=datetime_obj, last_modified_time=datetime_obj
144+
),
145+
api_types.TrialComponentSummary(
146+
trial_component_name="trial-component-foo-2", creation_time=datetime_obj, last_modified_time=datetime_obj
147+
),
148+
api_types.TrialComponentSummary(
149+
trial_component_name="trial-component-foo-3", creation_time=datetime_obj, last_modified_time=datetime_obj
150+
),
151+
]
152+
153+
sagemaker_boto_client.list_trial_components.assert_any_call(**{})
154+
sagemaker_boto_client.list_trial_components.assert_any_call(NextToken="foo")
155+
156+
157+
def test_list_trial_components_call_args(sagemaker_boto_client):
158+
created_before = datetime.datetime(1999, 10, 12, 0, 0, 0)
159+
created_after = datetime.datetime(1990, 10, 12, 0, 0, 0)
160+
trial_name = 'foo-trial'
161+
next_token = 'thetoken'
162+
max_results = 99
163+
164+
trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)
165+
trial_obj.trial_name=trial_name
166+
167+
sagemaker_boto_client.list_trial_components.return_value = {}
168+
assert [] == list(
169+
trial_obj.list_trial_components(
170+
created_after=created_after,
171+
created_before=created_before,
172+
next_token=next_token,
173+
max_results=max_results)
174+
)
175+
sagemaker_boto_client.list_trial_components.assert_called_with(
176+
CreatedBefore=created_before,
177+
CreatedAfter=created_after,
178+
TrialName=trial_name,
179+
NextToken=next_token,
180+
MaxResults=max_results,
181+
)

tests/unit/test_trial_component.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,12 @@ def test_list(sagemaker_boto_client):
165165
last_modified_by={}
166166
) for i in range(20)
167167
]
168-
result = list(trial_component.TrialComponent.list(sagemaker_boto_client=sagemaker_boto_client,
169-
source_arn='foo', sort_by='CreationTime',
170-
sort_order='Ascending'))
168+
result = list(trial_component.TrialComponent.list(
169+
sagemaker_boto_client=sagemaker_boto_client,
170+
source_arn='foo',
171+
sort_by='CreationTime',
172+
sort_order='Ascending'))
173+
171174
assert expected == result
172175
expected_calls= [unittest.mock.call(SortBy='CreationTime', SortOrder='Ascending', SourceArn='foo'),
173176
unittest.mock.call(NextToken='100', SortBy='CreationTime', SortOrder='Ascending', SourceArn='foo')]
@@ -181,6 +184,41 @@ def test_list_empty(sagemaker_boto_client):
181184
assert [] == list(trial_component.TrialComponent.list(sagemaker_boto_client=sagemaker_boto_client))
182185

183186

187+
def test_list_trial_components_call_args(sagemaker_boto_client):
188+
created_before = datetime.datetime(1999, 10, 12, 0, 0, 0)
189+
created_after = datetime.datetime(1990, 10, 12, 0, 0, 0)
190+
trial_name = 'foo-trial'
191+
experiment_name = 'foo-experiment'
192+
next_token = 'thetoken'
193+
max_results = 99
194+
195+
sagemaker_boto_client.list_trial_components.return_value = {}
196+
assert [] == list(
197+
trial_component.TrialComponent.list(
198+
sagemaker_boto_client=sagemaker_boto_client,
199+
trial_name=trial_name,
200+
experiment_name=experiment_name,
201+
created_before=created_before,
202+
created_after=created_after,
203+
next_token=next_token,
204+
max_results=max_results,
205+
sort_by='CreationTime',
206+
sort_order='Ascending')
207+
)
208+
209+
expected_calls = [unittest.mock.call(
210+
TrialName='foo-trial',
211+
ExperimentName='foo-experiment',
212+
CreatedBefore=created_before,
213+
CreatedAfter=created_after,
214+
SortBy='CreationTime',
215+
SortOrder='Ascending',
216+
NextToken='thetoken',
217+
MaxResults=99,
218+
)]
219+
assert expected_calls == sagemaker_boto_client.list_trial_components.mock_calls
220+
221+
184222
def test_save(sagemaker_boto_client):
185223
obj = trial_component.TrialComponent(sagemaker_boto_client, trial_component_name='foo', display_name='bar')
186224
sagemaker_boto_client.update_trial_component.return_value = {}

0 commit comments

Comments
 (0)