Skip to content

Commit 73c1093

Browse files
committed
add delete_all() method under experiment
1 parent 4df3bd6 commit 73c1093

File tree

4 files changed

+203
-2
lines changed

4 files changed

+203
-2
lines changed

src/smexperiments/experiment.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
# language governing permissions and limitations under the License.
1313

1414
"""Contains the SageMaker Experiment class."""
15-
from smexperiments import _base_types, api_types, trial, _utils
15+
from smexperiments import _base_types, api_types, trial, _utils, trial_component
16+
import time
1617

1718

1819
class Experiment(_base_types.Record):
@@ -227,3 +228,44 @@ def create_trial(self, trial_name=None, trial_name_prefix="SageMakerTrial"):
227228
experiment_name=self.experiment_name,
228229
sagemaker_boto_client=self.sagemaker_boto_client,
229230
)
231+
232+
def delete_all(self, action):
233+
"""
234+
Force to delete the experiment and associated trials, trial components under the experiment.
235+
236+
Args:
237+
action (str): pass in string '--force' to confirm recursively delete all the experiments, trials,
238+
and trial components.
239+
"""
240+
if action != "--force":
241+
raise ValueError(
242+
"Must confirm with string '--force' in order to delete the experiment and "
243+
"associated trials, trial components."
244+
)
245+
246+
delete_count = 0
247+
last_exception_message = None
248+
while True:
249+
if delete_count == 3:
250+
raise Exception("Fail to delete because" + last_exception_message + ", please try again.")
251+
try:
252+
for trial_summary in self.list_trials():
253+
t = trial.Trial.load(
254+
sagemaker_boto_client=self.sagemaker_boto_client, trial_name=trial_summary.trial_name
255+
)
256+
for trial_component_summary in t.list_trial_components():
257+
tc = trial_component.TrialComponent.load(
258+
sagemaker_boto_client=self.sagemaker_boto_client,
259+
trial_component_name=trial_component_summary.trial_component_name,
260+
)
261+
tc.delete(force_disassociate=True)
262+
t.remove_trial_component(tc)
263+
# to prevent throttling
264+
time.sleep(0.2)
265+
t.delete()
266+
self.delete()
267+
break
268+
except Exception as ex:
269+
last_exception_message = ex
270+
finally:
271+
delete_count = delete_count + 1

tests/conftest.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,75 @@ def tempdir():
9494
def experiment_obj(sagemaker_boto_client):
9595
description = "{}-{}".format("description", str(uuid.uuid4()))
9696
boto3.set_stream_logger("", logging.INFO)
97+
experiment_name = name()
9798
experiment_obj = experiment.Experiment.create(
98-
experiment_name=name(), description=description, sagemaker_boto_client=sagemaker_boto_client, tags=TAGS
99+
experiment_name=experiment_name, description=description, sagemaker_boto_client=sagemaker_boto_client, tags=TAGS
99100
)
100101
yield experiment_obj
101102
time.sleep(0.5)
102103
experiment_obj.delete()
104+
with pytest.raises(sagemaker_boto_client.exceptions.ResourceNotFound):
105+
sagemaker_boto_client.describe_experiment(ExperimentName=experiment_name)
106+
107+
108+
@pytest.fixture
109+
def complex_experiment_obj(sagemaker_boto_client):
110+
description = "{}-{}".format("description", str(uuid.uuid4()))
111+
boto3.set_stream_logger("", logging.INFO)
112+
113+
# create experiment
114+
experiment_obj_name = name()
115+
experiment_obj = experiment.Experiment.create(
116+
experiment_name=experiment_obj_name, description=description, sagemaker_boto_client=sagemaker_boto_client
117+
)
118+
119+
# create trials
120+
trial_objs = []
121+
trial_name1 = name()
122+
trial_name2 = name()
123+
trial_name3 = name()
124+
125+
next_trial1 = trial.Trial.create(
126+
trial_name=trial_name1, experiment_name=experiment_obj_name, sagemaker_boto_client=sagemaker_boto_client,
127+
)
128+
trial_objs.append(next_trial1)
129+
next_trial2 = trial.Trial.create(
130+
trial_name=trial_name2, experiment_name=experiment_obj_name, sagemaker_boto_client=sagemaker_boto_client,
131+
)
132+
trial_objs.append(next_trial2)
133+
next_trial3 = trial.Trial.create(
134+
trial_name=trial_name3, experiment_name=experiment_obj_name, sagemaker_boto_client=sagemaker_boto_client,
135+
)
136+
trial_objs.append(next_trial3)
137+
138+
# create trial components
139+
trial_component_name = name()
140+
trial_component_obj = trial_component.TrialComponent.create(
141+
trial_component_name=trial_component_name, sagemaker_boto_client=sagemaker_boto_client,
142+
)
143+
144+
# associate the trials with trial components
145+
for trial_obj in trial_objs:
146+
sagemaker_boto_client.associate_trial_component(
147+
TrialName=trial_obj.trial_name, TrialComponentName=trial_component_obj.trial_component_name
148+
)
149+
time.sleep(1.0)
150+
yield experiment_obj
151+
experiment_obj.delete_all(action="--force")
152+
153+
# load experiment and verify experiment got deleted
154+
with pytest.raises(sagemaker_boto_client.exceptions.ResourceNotFound):
155+
sagemaker_boto_client.describe_experiment(ExperimentName=experiment_obj_name)
156+
# load trials and verify trials got deleted
157+
with pytest.raises(sagemaker_boto_client.exceptions.ResourceNotFound):
158+
sagemaker_boto_client.describe_trial(TrialName=trial_name1)
159+
with pytest.raises(sagemaker_boto_client.exceptions.ResourceNotFound):
160+
sagemaker_boto_client.describe_trial(TrialName=trial_name2)
161+
with pytest.raises(sagemaker_boto_client.exceptions.ResourceNotFound):
162+
sagemaker_boto_client.describe_trial(TrialName=trial_name3)
163+
# load trial component and verify trial component got deleted
164+
with pytest.raises(sagemaker_boto_client.exceptions.ResourceNotFound):
165+
sagemaker_boto_client.describe_trial_component(TrialComponentName=trial_component_name)
103166

104167

105168
@pytest.fixture

tests/integ/test_experiment.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
import datetime
14+
import pytest
1415

1516
from tests.helpers import name
1617
from smexperiments import experiment, trial
@@ -125,3 +126,12 @@ def test_list_trials(experiment_obj, trials):
125126
trial_names = [trial_obj.trial_name for trial_obj in trials]
126127
assert set(trial_names) == set([s.trial_name for s in experiment_obj.list_trials()])
127128
assert trial_names # sanity test
129+
130+
131+
def test_delete_all(complex_experiment_obj):
132+
assert complex_experiment_obj.experiment_name
133+
134+
135+
def test_delete_all_fails(experiment_obj):
136+
with pytest.raises(ValueError):
137+
experiment_obj.delete_all(action="test")

tests/unit/test_experiment.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,89 @@ def test_delete(sagemaker_boto_client):
210210
sagemaker_boto_client.delete_experiment.return_value = {}
211211
obj.delete()
212212
sagemaker_boto_client.delete_experiment.assert_called_with(ExperimentName="foo")
213+
214+
215+
def test_delete_all_with_incorrect_action_name(sagemaker_boto_client):
216+
obj = experiment.Experiment(sagemaker_boto_client, experiment_name="foo", description="bar")
217+
with pytest.raises(ValueError):
218+
obj.delete_all(action="abc")
219+
220+
221+
def test_delete_all(sagemaker_boto_client):
222+
obj = experiment.Experiment(sagemaker_boto_client, experiment_name="foo", description="bar")
223+
sagemaker_boto_client.list_trials.return_value = {
224+
"TrialSummaries": [
225+
{"TrialName": "trial-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
226+
{"TrialName": "trial-2", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
227+
]
228+
}
229+
sagemaker_boto_client.describe_trial.side_effect = [
230+
{"Trialname": "trial-1", "ExperimentName": "experiment-name-value"},
231+
{"Trialname": "trial-2", "ExperimentName": "experiment-name-value"},
232+
]
233+
sagemaker_boto_client.list_trial_components.side_effect = [
234+
{
235+
"TrialComponentSummaries": [
236+
{
237+
"TrialComponentName": "trial-component-1",
238+
"CreationTime": datetime_obj,
239+
"LastModifiedTime": datetime_obj,
240+
},
241+
{
242+
"TrialComponentName": "trial-component-2",
243+
"CreationTime": datetime_obj,
244+
"LastModifiedTime": datetime_obj,
245+
},
246+
]
247+
},
248+
{
249+
"TrialComponentSummaries": [
250+
{
251+
"TrialComponentName": "trial-component-3",
252+
"CreationTime": datetime_obj,
253+
"LastModifiedTime": datetime_obj,
254+
},
255+
{
256+
"TrialComponentName": "trial-component-4",
257+
"CreationTime": datetime_obj,
258+
"LastModifiedTime": datetime_obj,
259+
},
260+
]
261+
},
262+
]
263+
264+
sagemaker_boto_client.describe_trial_component.side_effect = [
265+
{"TrialComponentName": "trial-component-1"},
266+
{"TrialComponentName": "trial-component-2"},
267+
{"TrialComponentName": "trial-component-3"},
268+
{"TrialComponentName": "trial-component-4"},
269+
]
270+
271+
sagemaker_boto_client.delete_trial_component.return_value = {}
272+
sagemaker_boto_client.delete_trial.return_value = {}
273+
sagemaker_boto_client.delete_experiment.return_value = {}
274+
275+
obj.delete_all(action="--force")
276+
277+
sagemaker_boto_client.delete_experiment.assert_called_with(ExperimentName="foo")
278+
279+
delete_trial_expected_calls = [
280+
unittest.mock.call(TrialName="trial-1"),
281+
unittest.mock.call(TrialName="trial-2"),
282+
]
283+
assert delete_trial_expected_calls == sagemaker_boto_client.delete_trial.mock_calls
284+
285+
delete_trial_component_expected_calls = [
286+
unittest.mock.call(TrialComponentName="trial-component-1"),
287+
unittest.mock.call(TrialComponentName="trial-component-2"),
288+
unittest.mock.call(TrialComponentName="trial-component-3"),
289+
unittest.mock.call(TrialComponentName="trial-component-4"),
290+
]
291+
assert delete_trial_component_expected_calls == sagemaker_boto_client.delete_trial_component.mock_calls
292+
293+
294+
def test_delete_all_fail(sagemaker_boto_client):
295+
obj = experiment.Experiment(sagemaker_boto_client, experiment_name="foo", description="bar")
296+
sagemaker_boto_client.list_trials.side_effect = Exception
297+
with pytest.raises(Exception):
298+
obj.delete_all(action="--force")

0 commit comments

Comments
 (0)