Skip to content

Commit c93feef

Browse files
authored
change: cleanup experiments, trials, and trial components in integ tests (#1244)
1 parent 72b4c4b commit c93feef

File tree

1 file changed

+116
-135
lines changed

1 file changed

+116
-135
lines changed
Lines changed: 116 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -1,176 +1,157 @@
11
from __future__ import absolute_import
22

3-
import uuid
43
import time
4+
import uuid
5+
from contextlib import contextmanager
56

67
import pytest
78

89
from sagemaker.analytics import ExperimentAnalytics
910

1011

11-
@pytest.mark.canary_quick
12-
def test_experiment_analytics(sagemaker_session):
12+
@contextmanager
13+
def experiment(sagemaker_session):
1314
sm = sagemaker_session.sagemaker_client
15+
trials = {} # for resource cleanup
1416

1517
experiment_name = "experiment-" + str(uuid.uuid4())
16-
sm.create_experiment(ExperimentName=experiment_name)
17-
18-
for i in range(5):
19-
trial_name = "trial-" + str(uuid.uuid4())
20-
sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name)
21-
trial_component_name = "tc-" + str(uuid.uuid4())
22-
sm.create_trial_component(TrialComponentName=trial_component_name, DisplayName="Training")
23-
sm.update_trial_component(
24-
TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}}
25-
)
26-
sm.associate_trial_component(TrialComponentName=trial_component_name, TrialName=trial_name)
27-
28-
time.sleep(15) # wait for search to get updated
29-
30-
analytics = ExperimentAnalytics(
31-
experiment_name=experiment_name, sagemaker_session=sagemaker_session
32-
)
18+
try:
19+
sm.create_experiment(ExperimentName=experiment_name)
3320

34-
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
21+
# Search returns 10 results by default. Add 20 trials to verify pagination.
22+
for i in range(20):
23+
trial_name = "trial-" + str(uuid.uuid4())
24+
sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name)
3525

26+
trial_component_name = "tc-" + str(uuid.uuid4())
27+
trials[trial_name] = trial_component_name
3628

37-
def test_experiment_analytics_pagination(sagemaker_session):
38-
sm = sagemaker_session.sagemaker_client
39-
40-
experiment_name = "experiment" + str(uuid.uuid4())
41-
sm.create_experiment(ExperimentName=experiment_name)
42-
43-
# Search returns 10 results by default. Add 20 trials to verify pagination,
44-
for i in range(20):
45-
trial_name = "trial-" + str(uuid.uuid4())
46-
sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name)
47-
trial_component_name = "tc-" + str(uuid.uuid4())
48-
sm.create_trial_component(TrialComponentName=trial_component_name, DisplayName="Training")
49-
sm.update_trial_component(
50-
TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}}
51-
)
52-
sm.associate_trial_component(TrialComponentName=trial_component_name, TrialName=trial_name)
29+
sm.create_trial_component(
30+
TrialComponentName=trial_component_name, DisplayName="Training"
31+
)
32+
sm.update_trial_component(
33+
TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}}
34+
)
35+
sm.associate_trial_component(
36+
TrialComponentName=trial_component_name, TrialName=trial_name
37+
)
5338

54-
time.sleep(15) # wait for search to get updated TODO [owen-t]: Replace with retry
39+
time.sleep(15) # wait for search to get updated
5540

56-
analytics = ExperimentAnalytics(
57-
experiment_name=experiment_name, sagemaker_session=sagemaker_session
58-
)
41+
yield experiment_name
42+
finally:
43+
_delete_resources(sm, experiment_name, trials)
5944

60-
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
61-
assert (
62-
len(analytics.dataframe()) > 10
63-
) # TODO [owen-t] Replace with == 20 and put test in retry block
6445

46+
@pytest.mark.canary_quick
47+
def test_experiment_analytics(sagemaker_session):
48+
with experiment(sagemaker_session) as experiment_name:
49+
analytics = ExperimentAnalytics(
50+
experiment_name=experiment_name, sagemaker_session=sagemaker_session
51+
)
6552

66-
def test_experiment_analytics_search_by_nested_filter(sagemaker_session):
67-
sm = sagemaker_session.sagemaker_client
53+
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
6854

69-
experiment_name = "experiment" + str(uuid.uuid4())
70-
sm.create_experiment(ExperimentName=experiment_name)
7155

72-
for i in range(20):
73-
trial_name = "trial-" + str(uuid.uuid4())
74-
sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name)
75-
trial_component_name = "tc-" + str(uuid.uuid4())
76-
sm.create_trial_component(TrialComponentName=trial_component_name, DisplayName="Training")
77-
sm.update_trial_component(
78-
TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}}
56+
def test_experiment_analytics_pagination(sagemaker_session):
57+
with experiment(sagemaker_session) as experiment_name:
58+
analytics = ExperimentAnalytics(
59+
experiment_name=experiment_name, sagemaker_session=sagemaker_session
7960
)
80-
sm.associate_trial_component(TrialComponentName=trial_component_name, TrialName=trial_name)
8161

82-
time.sleep(15) # wait for search to get updated TODO [owen-t]: Replace with retry
62+
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
63+
assert (
64+
len(analytics.dataframe()) > 10
65+
) # TODO [owen-t] Replace with == 20 and put test in retry block
8366

84-
search_exp = {
85-
"Filters": [
86-
{"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name},
87-
{"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"},
88-
]
89-
}
9067

91-
analytics = ExperimentAnalytics(
92-
sagemaker_session=sagemaker_session, search_expression=search_exp
93-
)
68+
def test_experiment_analytics_search_by_nested_filter(sagemaker_session):
69+
with experiment(sagemaker_session) as experiment_name:
70+
search_exp = {
71+
"Filters": [
72+
{"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name},
73+
{"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"},
74+
]
75+
}
76+
77+
analytics = ExperimentAnalytics(
78+
sagemaker_session=sagemaker_session, search_expression=search_exp
79+
)
9480

95-
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
96-
assert (
97-
len(analytics.dataframe()) > 5
98-
) # TODO [owen-t] Replace with == 10 and put test in retry block
81+
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
82+
assert (
83+
len(analytics.dataframe()) > 5
84+
) # TODO [owen-t] Replace with == 10 and put test in retry block
9985

10086

10187
def test_experiment_analytics_search_by_nested_filter_sort_ascending(sagemaker_session):
102-
sm = sagemaker_session.sagemaker_client
88+
with experiment(sagemaker_session) as experiment_name:
89+
search_exp = {
90+
"Filters": [
91+
{"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name},
92+
{"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"},
93+
]
94+
}
95+
96+
analytics = ExperimentAnalytics(
97+
sagemaker_session=sagemaker_session,
98+
search_expression=search_exp,
99+
sort_by="Parameters.hp1",
100+
sort_order="Ascending",
101+
)
102+
103+
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
104+
assert (
105+
len(analytics.dataframe()) > 5
106+
) # TODO [owen-t] Replace with == 10 and put test in retry block
107+
assert list(analytics.dataframe()["hp1"].values) == sorted(
108+
analytics.dataframe()["hp1"].values
109+
)
103110

104-
experiment_name = "experiment" + str(uuid.uuid4())
105-
sm.create_experiment(ExperimentName=experiment_name)
106111

107-
for i in range(20):
108-
trial_name = "trial-" + str(uuid.uuid4())
109-
sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name)
110-
trial_component_name = "tc-" + str(uuid.uuid4())
111-
sm.create_trial_component(TrialComponentName=trial_component_name, DisplayName="Training")
112-
sm.update_trial_component(
113-
TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}}
112+
def test_experiment_analytics_search_by_nested_filter_sort_descending(sagemaker_session):
113+
with experiment(sagemaker_session) as experiment_name:
114+
search_exp = {
115+
"Filters": [
116+
{"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name},
117+
{"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"},
118+
]
119+
}
120+
121+
analytics = ExperimentAnalytics(
122+
sagemaker_session=sagemaker_session,
123+
search_expression=search_exp,
124+
sort_by="Parameters.hp1",
114125
)
115-
sm.associate_trial_component(TrialComponentName=trial_component_name, TrialName=trial_name)
116126

117-
time.sleep(15) # wait for search to get updated TODO [owen-t]: Replace with retry
127+
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
128+
assert (
129+
len(analytics.dataframe()) > 5
130+
) # TODO [owen-t] Replace with == 10 and put test in retry block
131+
assert (
132+
list(analytics.dataframe()["hp1"].values)
133+
== sorted(analytics.dataframe()["hp1"].values)[::-1]
134+
)
118135

119-
search_exp = {
120-
"Filters": [
121-
{"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name},
122-
{"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"},
123-
]
124-
}
125136

126-
analytics = ExperimentAnalytics(
127-
sagemaker_session=sagemaker_session,
128-
search_expression=search_exp,
129-
sort_by="Parameters.hp1",
130-
sort_order="Ascending",
131-
)
137+
def _delete_resources(sagemaker_client, experiment_name, trials):
138+
for trial, tc in trials.items():
139+
with _ignore_resource_not_found(sagemaker_client):
140+
sagemaker_client.disassociate_trial_component(TrialName=trial, TrialComponentName=tc)
132141

133-
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
134-
assert (
135-
len(analytics.dataframe()) > 5
136-
) # TODO [owen-t] Replace with == 10 and put test in retry block
137-
assert list(analytics.dataframe()["hp1"].values) == sorted(analytics.dataframe()["hp1"].values)
142+
with _ignore_resource_not_found(sagemaker_client):
143+
sagemaker_client.delete_trial_component(TrialComponentName=tc)
138144

145+
with _ignore_resource_not_found(sagemaker_client):
146+
sagemaker_client.delete_trial(TrialName=trial)
139147

140-
def test_experiment_analytics_search_by_nested_filter_sort_descending(sagemaker_session):
141-
sm = sagemaker_session.sagemaker_client
148+
with _ignore_resource_not_found(sagemaker_client):
149+
sagemaker_client.delete_experiment(ExperimentName=experiment_name)
142150

143-
experiment_name = "experiment" + str(uuid.uuid4())
144-
sm.create_experiment(ExperimentName=experiment_name)
145151

146-
for i in range(20):
147-
trial_name = "trial-" + str(uuid.uuid4())
148-
sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name)
149-
trial_component_name = "tc-" + str(uuid.uuid4())
150-
sm.create_trial_component(TrialComponentName=trial_component_name, DisplayName="Training")
151-
sm.update_trial_component(
152-
TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}}
153-
)
154-
sm.associate_trial_component(TrialComponentName=trial_component_name, TrialName=trial_name)
155-
156-
time.sleep(15) # wait for search to get updated TODO [owen-t]: Replace with retry
157-
158-
search_exp = {
159-
"Filters": [
160-
{"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name},
161-
{"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"},
162-
]
163-
}
164-
165-
analytics = ExperimentAnalytics(
166-
sagemaker_session=sagemaker_session, search_expression=search_exp, sort_by="Parameters.hp1"
167-
)
168-
169-
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
170-
assert (
171-
len(analytics.dataframe()) > 5
172-
) # TODO [owen-t] Replace with == 10 and put test in retry block
173-
assert (
174-
list(analytics.dataframe()["hp1"].values)
175-
== sorted(analytics.dataframe()["hp1"].values)[::-1]
176-
)
152+
@contextmanager
153+
def _ignore_resource_not_found(sagemaker_client):
154+
try:
155+
yield
156+
except sagemaker_client.exceptions.ResourceNotFound:
157+
pass

0 commit comments

Comments
 (0)