|
1 | 1 | from __future__ import absolute_import
|
2 | 2 |
|
3 |
| -import uuid |
4 | 3 | import time
|
| 4 | +import uuid |
| 5 | +from contextlib import contextmanager |
5 | 6 |
|
6 | 7 | import pytest
|
7 | 8 |
|
8 | 9 | from sagemaker.analytics import ExperimentAnalytics
|
9 | 10 |
|
10 | 11 |
|
11 |
| -@pytest.mark.canary_quick |
12 |
| -def test_experiment_analytics(sagemaker_session): |
| 12 | +@contextmanager |
| 13 | +def experiment(sagemaker_session): |
13 | 14 | sm = sagemaker_session.sagemaker_client
|
| 15 | + trials = {} # for resource cleanup |
14 | 16 |
|
15 | 17 | experiment_name = "experiment-" + str(uuid.uuid4())
|
16 |
| - sm.create_experiment(ExperimentName=experiment_name) |
17 |
| - |
18 |
| - for i in range(5): |
19 |
| - trial_name = "trial-" + str(uuid.uuid4()) |
20 |
| - sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name) |
21 |
| - trial_component_name = "tc-" + str(uuid.uuid4()) |
22 |
| - sm.create_trial_component(TrialComponentName=trial_component_name, DisplayName="Training") |
23 |
| - sm.update_trial_component( |
24 |
| - TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}} |
25 |
| - ) |
26 |
| - sm.associate_trial_component(TrialComponentName=trial_component_name, TrialName=trial_name) |
27 |
| - |
28 |
| - time.sleep(15) # wait for search to get updated |
29 |
| - |
30 |
| - analytics = ExperimentAnalytics( |
31 |
| - experiment_name=experiment_name, sagemaker_session=sagemaker_session |
32 |
| - ) |
| 18 | + try: |
| 19 | + sm.create_experiment(ExperimentName=experiment_name) |
33 | 20 |
|
34 |
| - assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
| 21 | + # Search returns 10 results by default. Add 20 trials to verify pagination. |
| 22 | + for i in range(20): |
| 23 | + trial_name = "trial-" + str(uuid.uuid4()) |
| 24 | + sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name) |
35 | 25 |
|
| 26 | + trial_component_name = "tc-" + str(uuid.uuid4()) |
| 27 | + trials[trial_name] = trial_component_name |
36 | 28 |
|
37 |
| -def test_experiment_analytics_pagination(sagemaker_session): |
38 |
| - sm = sagemaker_session.sagemaker_client |
39 |
| - |
40 |
| - experiment_name = "experiment" + str(uuid.uuid4()) |
41 |
| - sm.create_experiment(ExperimentName=experiment_name) |
42 |
| - |
43 |
| - # Search returns 10 results by default. Add 20 trials to verify pagination, |
44 |
| - for i in range(20): |
45 |
| - trial_name = "trial-" + str(uuid.uuid4()) |
46 |
| - sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name) |
47 |
| - trial_component_name = "tc-" + str(uuid.uuid4()) |
48 |
| - sm.create_trial_component(TrialComponentName=trial_component_name, DisplayName="Training") |
49 |
| - sm.update_trial_component( |
50 |
| - TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}} |
51 |
| - ) |
52 |
| - sm.associate_trial_component(TrialComponentName=trial_component_name, TrialName=trial_name) |
| 29 | + sm.create_trial_component( |
| 30 | + TrialComponentName=trial_component_name, DisplayName="Training" |
| 31 | + ) |
| 32 | + sm.update_trial_component( |
| 33 | + TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}} |
| 34 | + ) |
| 35 | + sm.associate_trial_component( |
| 36 | + TrialComponentName=trial_component_name, TrialName=trial_name |
| 37 | + ) |
53 | 38 |
|
54 |
| - time.sleep(15) # wait for search to get updated TODO [owen-t]: Replace with retry |
| 39 | + time.sleep(15) # wait for search to get updated |
55 | 40 |
|
56 |
| - analytics = ExperimentAnalytics( |
57 |
| - experiment_name=experiment_name, sagemaker_session=sagemaker_session |
58 |
| - ) |
| 41 | + yield experiment_name |
| 42 | + finally: |
| 43 | + _delete_resources(sm, experiment_name, trials) |
59 | 44 |
|
60 |
| - assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
61 |
| - assert ( |
62 |
| - len(analytics.dataframe()) > 10 |
63 |
| - ) # TODO [owen-t] Replace with == 20 and put test in retry block |
64 | 45 |
|
| 46 | +@pytest.mark.canary_quick |
| 47 | +def test_experiment_analytics(sagemaker_session): |
| 48 | + with experiment(sagemaker_session) as experiment_name: |
| 49 | + analytics = ExperimentAnalytics( |
| 50 | + experiment_name=experiment_name, sagemaker_session=sagemaker_session |
| 51 | + ) |
65 | 52 |
|
66 |
| -def test_experiment_analytics_search_by_nested_filter(sagemaker_session): |
67 |
| - sm = sagemaker_session.sagemaker_client |
| 53 | + assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
68 | 54 |
|
69 |
| - experiment_name = "experiment" + str(uuid.uuid4()) |
70 |
| - sm.create_experiment(ExperimentName=experiment_name) |
71 | 55 |
|
72 |
| - for i in range(20): |
73 |
| - trial_name = "trial-" + str(uuid.uuid4()) |
74 |
| - sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name) |
75 |
| - trial_component_name = "tc-" + str(uuid.uuid4()) |
76 |
| - sm.create_trial_component(TrialComponentName=trial_component_name, DisplayName="Training") |
77 |
| - sm.update_trial_component( |
78 |
| - TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}} |
| 56 | +def test_experiment_analytics_pagination(sagemaker_session): |
| 57 | + with experiment(sagemaker_session) as experiment_name: |
| 58 | + analytics = ExperimentAnalytics( |
| 59 | + experiment_name=experiment_name, sagemaker_session=sagemaker_session |
79 | 60 | )
|
80 |
| - sm.associate_trial_component(TrialComponentName=trial_component_name, TrialName=trial_name) |
81 | 61 |
|
82 |
| - time.sleep(15) # wait for search to get updated TODO [owen-t]: Replace with retry |
| 62 | + assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
| 63 | + assert ( |
| 64 | + len(analytics.dataframe()) > 10 |
| 65 | + ) # TODO [owen-t] Replace with == 20 and put test in retry block |
83 | 66 |
|
84 |
| - search_exp = { |
85 |
| - "Filters": [ |
86 |
| - {"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name}, |
87 |
| - {"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"}, |
88 |
| - ] |
89 |
| - } |
90 | 67 |
|
91 |
| - analytics = ExperimentAnalytics( |
92 |
| - sagemaker_session=sagemaker_session, search_expression=search_exp |
93 |
| - ) |
| 68 | +def test_experiment_analytics_search_by_nested_filter(sagemaker_session): |
| 69 | + with experiment(sagemaker_session) as experiment_name: |
| 70 | + search_exp = { |
| 71 | + "Filters": [ |
| 72 | + {"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name}, |
| 73 | + {"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"}, |
| 74 | + ] |
| 75 | + } |
| 76 | + |
| 77 | + analytics = ExperimentAnalytics( |
| 78 | + sagemaker_session=sagemaker_session, search_expression=search_exp |
| 79 | + ) |
94 | 80 |
|
95 |
| - assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
96 |
| - assert ( |
97 |
| - len(analytics.dataframe()) > 5 |
98 |
| - ) # TODO [owen-t] Replace with == 10 and put test in retry block |
| 81 | + assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
| 82 | + assert ( |
| 83 | + len(analytics.dataframe()) > 5 |
| 84 | + ) # TODO [owen-t] Replace with == 10 and put test in retry block |
99 | 85 |
|
100 | 86 |
|
101 | 87 | def test_experiment_analytics_search_by_nested_filter_sort_ascending(sagemaker_session):
|
102 |
| - sm = sagemaker_session.sagemaker_client |
| 88 | + with experiment(sagemaker_session) as experiment_name: |
| 89 | + search_exp = { |
| 90 | + "Filters": [ |
| 91 | + {"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name}, |
| 92 | + {"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"}, |
| 93 | + ] |
| 94 | + } |
| 95 | + |
| 96 | + analytics = ExperimentAnalytics( |
| 97 | + sagemaker_session=sagemaker_session, |
| 98 | + search_expression=search_exp, |
| 99 | + sort_by="Parameters.hp1", |
| 100 | + sort_order="Ascending", |
| 101 | + ) |
| 102 | + |
| 103 | + assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
| 104 | + assert ( |
| 105 | + len(analytics.dataframe()) > 5 |
| 106 | + ) # TODO [owen-t] Replace with == 10 and put test in retry block |
| 107 | + assert list(analytics.dataframe()["hp1"].values) == sorted( |
| 108 | + analytics.dataframe()["hp1"].values |
| 109 | + ) |
103 | 110 |
|
104 |
| - experiment_name = "experiment" + str(uuid.uuid4()) |
105 |
| - sm.create_experiment(ExperimentName=experiment_name) |
106 | 111 |
|
107 |
| - for i in range(20): |
108 |
| - trial_name = "trial-" + str(uuid.uuid4()) |
109 |
| - sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name) |
110 |
| - trial_component_name = "tc-" + str(uuid.uuid4()) |
111 |
| - sm.create_trial_component(TrialComponentName=trial_component_name, DisplayName="Training") |
112 |
| - sm.update_trial_component( |
113 |
| - TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}} |
| 112 | +def test_experiment_analytics_search_by_nested_filter_sort_descending(sagemaker_session): |
| 113 | + with experiment(sagemaker_session) as experiment_name: |
| 114 | + search_exp = { |
| 115 | + "Filters": [ |
| 116 | + {"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name}, |
| 117 | + {"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"}, |
| 118 | + ] |
| 119 | + } |
| 120 | + |
| 121 | + analytics = ExperimentAnalytics( |
| 122 | + sagemaker_session=sagemaker_session, |
| 123 | + search_expression=search_exp, |
| 124 | + sort_by="Parameters.hp1", |
114 | 125 | )
|
115 |
| - sm.associate_trial_component(TrialComponentName=trial_component_name, TrialName=trial_name) |
116 | 126 |
|
117 |
| - time.sleep(15) # wait for search to get updated TODO [owen-t]: Replace with retry |
| 127 | + assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
| 128 | + assert ( |
| 129 | + len(analytics.dataframe()) > 5 |
| 130 | + ) # TODO [owen-t] Replace with == 10 and put test in retry block |
| 131 | + assert ( |
| 132 | + list(analytics.dataframe()["hp1"].values) |
| 133 | + == sorted(analytics.dataframe()["hp1"].values)[::-1] |
| 134 | + ) |
118 | 135 |
|
119 |
| - search_exp = { |
120 |
| - "Filters": [ |
121 |
| - {"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name}, |
122 |
| - {"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"}, |
123 |
| - ] |
124 |
| - } |
125 | 136 |
|
126 |
| - analytics = ExperimentAnalytics( |
127 |
| - sagemaker_session=sagemaker_session, |
128 |
| - search_expression=search_exp, |
129 |
| - sort_by="Parameters.hp1", |
130 |
| - sort_order="Ascending", |
131 |
| - ) |
| 137 | +def _delete_resources(sagemaker_client, experiment_name, trials): |
| 138 | + for trial, tc in trials.items(): |
| 139 | + with _ignore_resource_not_found(sagemaker_client): |
| 140 | + sagemaker_client.disassociate_trial_component(TrialName=trial, TrialComponentName=tc) |
132 | 141 |
|
133 |
| - assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
134 |
| - assert ( |
135 |
| - len(analytics.dataframe()) > 5 |
136 |
| - ) # TODO [owen-t] Replace with == 10 and put test in retry block |
137 |
| - assert list(analytics.dataframe()["hp1"].values) == sorted(analytics.dataframe()["hp1"].values) |
| 142 | + with _ignore_resource_not_found(sagemaker_client): |
| 143 | + sagemaker_client.delete_trial_component(TrialComponentName=tc) |
138 | 144 |
|
| 145 | + with _ignore_resource_not_found(sagemaker_client): |
| 146 | + sagemaker_client.delete_trial(TrialName=trial) |
139 | 147 |
|
140 |
| -def test_experiment_analytics_search_by_nested_filter_sort_descending(sagemaker_session): |
141 |
| - sm = sagemaker_session.sagemaker_client |
| 148 | + with _ignore_resource_not_found(sagemaker_client): |
| 149 | + sagemaker_client.delete_experiment(ExperimentName=experiment_name) |
142 | 150 |
|
143 |
| - experiment_name = "experiment" + str(uuid.uuid4()) |
144 |
| - sm.create_experiment(ExperimentName=experiment_name) |
145 | 151 |
|
146 |
| - for i in range(20): |
147 |
| - trial_name = "trial-" + str(uuid.uuid4()) |
148 |
| - sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name) |
149 |
| - trial_component_name = "tc-" + str(uuid.uuid4()) |
150 |
| - sm.create_trial_component(TrialComponentName=trial_component_name, DisplayName="Training") |
151 |
| - sm.update_trial_component( |
152 |
| - TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}} |
153 |
| - ) |
154 |
| - sm.associate_trial_component(TrialComponentName=trial_component_name, TrialName=trial_name) |
155 |
| - |
156 |
| - time.sleep(15) # wait for search to get updated TODO [owen-t]: Replace with retry |
157 |
| - |
158 |
| - search_exp = { |
159 |
| - "Filters": [ |
160 |
| - {"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name}, |
161 |
| - {"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"}, |
162 |
| - ] |
163 |
| - } |
164 |
| - |
165 |
| - analytics = ExperimentAnalytics( |
166 |
| - sagemaker_session=sagemaker_session, search_expression=search_exp, sort_by="Parameters.hp1" |
167 |
| - ) |
168 |
| - |
169 |
| - assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
170 |
| - assert ( |
171 |
| - len(analytics.dataframe()) > 5 |
172 |
| - ) # TODO [owen-t] Replace with == 10 and put test in retry block |
173 |
| - assert ( |
174 |
| - list(analytics.dataframe()["hp1"].values) |
175 |
| - == sorted(analytics.dataframe()["hp1"].values)[::-1] |
176 |
| - ) |
| 152 | +@contextmanager |
| 153 | +def _ignore_resource_not_found(sagemaker_client): |
| 154 | + try: |
| 155 | + yield |
| 156 | + except sagemaker_client.exceptions.ResourceNotFound: |
| 157 | + pass |
0 commit comments