Skip to content

Commit 9952f88

Browse files
committed
feature: add delete_all method to Trial
1 parent bd48bdc commit 9952f88

16 files changed

+303
-53
lines changed

scripts/version.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,11 @@ def parse(version):
4949
if not match:
5050
raise ValueError(f"invalid version: {version}")
5151

52-
return Version(int(match.group("major") or 0), int(match.group("minor") or 0), int(match.group("patch") or 0),)
52+
return Version(
53+
int(match.group("major") or 0),
54+
int(match.group("minor") or 0),
55+
int(match.group("patch") or 0),
56+
)
5357

5458

5559
def next_version_from_current_version(current_version, increment_type):

src/smexperiments/_base_types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class ApiObject(object):
2121
to dicts into/from a Python object with standard python members. Clients invoke to_boto on an instance
2222
of ApiObject to transform the ApiObject into a boto representation. Clients invoke from_boto on a sub-class of
2323
ApiObject to instantiate an instance of that class from a boto representation.
24-
"""
24+
"""
2525

2626
# A map from boto 'UpperCamelCase' name to member name. If a boto name does not appear in this dict then
2727
# it is converted to lower_snake_case.
@@ -81,7 +81,8 @@ def __hash__(self):
8181
def __repr__(self):
8282
"""Returns a string representation of this ApiObject."""
8383
return "{}({})".format(
84-
type(self).__name__, ",".join(["{}={}".format(k, repr(v)) for k, v in vars(self).items()]),
84+
type(self).__name__,
85+
",".join(["{}={}".format(k, repr(v)) for k, v in vars(self).items()]),
8586
)
8687

8788

src/smexperiments/_boto_functions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,12 @@ def from_boto(boto_dict, boto_name_to_member_name, member_name_to_type):
8383
def to_boto(member_vars, member_name_to_boto_name, member_name_to_type):
8484
"""Convert a dict of of snake case names to values into a boto UpperCamelCase representation.
8585
86-
Args:
87-
member_vars dict[str, ?]: A map from snake case name to value.
88-
member_name_to_boto_name dict[str, ?]: A map from snake_case name to boto name.
86+
Args:
87+
member_vars dict[str, ?]: A map from snake case name to value.
88+
member_name_to_boto_name dict[str, ?]: A map from snake_case name to boto name.
8989
90-
Returns:
91-
dict: boto dict converted to snake case
90+
Returns:
91+
dict: boto dict converted to snake case
9292
9393
"""
9494
to_boto_values = {}

src/smexperiments/experiment.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ class Experiment(_base_types.Record):
5959
_boto_update_members = ["experiment_name", "description", "display_name"]
6060
_boto_delete_members = ["experiment_name"]
6161

62+
MAX_DELETE_ALL_ATTEMPTS = 3
63+
6264
def save(self):
6365
"""Save the state of this Experiment to SageMaker.
6466
@@ -91,7 +93,9 @@ def load(cls, experiment_name, sagemaker_boto_client=None):
9193
sagemaker.experiments.experiment.Experiment: A SageMaker ``Experiment`` object
9294
"""
9395
return cls._construct(
94-
cls._boto_load_method, experiment_name=experiment_name, sagemaker_boto_client=sagemaker_boto_client,
96+
cls._boto_load_method,
97+
experiment_name=experiment_name,
98+
sagemaker_boto_client=sagemaker_boto_client,
9599
)
96100

97101
@classmethod
@@ -119,7 +123,12 @@ def create(cls, experiment_name=None, description=None, tags=None, sagemaker_bot
119123

120124
@classmethod
121125
def list(
122-
cls, created_before=None, created_after=None, sort_by=None, sort_order=None, sagemaker_boto_client=None,
126+
cls,
127+
created_before=None,
128+
created_after=None,
129+
sort_by=None,
130+
sort_order=None,
131+
sagemaker_boto_client=None,
123132
):
124133
"""
125134
List experiments. Returns experiments in the account matching the specified criteria.
@@ -152,7 +161,12 @@ def list(
152161

153162
@classmethod
154163
def search(
155-
cls, search_expression=None, sort_by=None, sort_order=None, max_results=None, sagemaker_boto_client=None,
164+
cls,
165+
search_expression=None,
166+
sort_by=None,
167+
sort_order=None,
168+
max_results=None,
169+
sagemaker_boto_client=None,
156170
):
157171
"""
158172
Search experiments. Returns SearchResults in the account matching the search criteria.
@@ -243,11 +257,11 @@ def delete_all(self, action):
243257
"associated trials, trial components."
244258
)
245259

246-
delete_count = 0
260+
delete_attempt_count = 0
247261
last_exception = None
248262
while True:
249-
if delete_count == 3:
250-
raise Exception("Fail to delete, please try again.") from last_exception
263+
if delete_attempt_count == self.MAX_DELETE_ALL_ATTEMPTS:
264+
raise Exception("Failed to delete, please try again.") from last_exception
251265
try:
252266
for trial_summary in self.list_trials():
253267
t = trial.Trial.load(
@@ -269,4 +283,4 @@ def delete_all(self, action):
269283
except Exception as ex:
270284
last_exception = ex
271285
finally:
272-
delete_count = delete_count + 1
286+
delete_attempt_count = delete_attempt_count + 1

src/smexperiments/metrics.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,5 +137,6 @@ def __str__(self):
137137

138138
def __repr__(self):
139139
return "{}({})".format(
140-
type(self).__name__, ",".join(["{}={}".format(k, repr(v)) for k, v in vars(self).items()]),
140+
type(self).__name__,
141+
",".join(["{}={}".format(k, repr(v)) for k, v in vars(self).items()]),
141142
)

src/smexperiments/tracker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
class Tracker(object):
27-
""""A SageMaker Experiments Tracker.
27+
"""A SageMaker Experiments Tracker.
2828
2929
Use a tracker object to record experiment information to a SageMaker trial component.
3030

src/smexperiments/training_job.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
class TrainingJob(_base_types.Record):
1919
@classmethod
2020
def search(
21-
cls, search_expression=None, sort_by=None, sort_order=None, max_results=None, sagemaker_boto_client=None,
21+
cls,
22+
search_expression=None,
23+
sort_by=None,
24+
sort_order=None,
25+
max_results=None,
26+
sagemaker_boto_client=None,
2227
):
2328
"""
2429
Search Training Job. Returns SearchResults in the account matching the search criteria.

src/smexperiments/trial.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Contains the Trial class."""
1414

1515
from smexperiments import api_types, _base_types, trial_component, _utils, tracker
16+
import time
1617

1718

1819
class Trial(_base_types.Record):
@@ -58,6 +59,8 @@ class Trial(_base_types.Record):
5859
_boto_update_members = ["trial_name", "display_name"]
5960
_boto_delete_members = ["trial_name"]
6061

62+
MAX_DELETE_ALL_ATTEMPTS = 3
63+
6164
@classmethod
6265
def _boto_ignore(cls):
6366
return super(Trial, cls)._boto_ignore() + ["CreatedBy"]
@@ -170,7 +173,12 @@ def list(
170173

171174
@classmethod
172175
def search(
173-
cls, search_expression=None, sort_by=None, sort_order=None, max_results=None, sagemaker_boto_client=None,
176+
cls,
177+
search_expression=None,
178+
sort_by=None,
179+
sort_order=None,
180+
max_results=None,
181+
sagemaker_boto_client=None,
174182
):
175183
"""
176184
Search experiments. Returns SearchResults in the account matching the search criteria.
@@ -270,3 +278,36 @@ def list_trial_components(
270278
next_token=next_token,
271279
sagemaker_boto_client=self.sagemaker_boto_client,
272280
)
281+
282+
def delete_all(self, action):
283+
"""
284+
Force to delete the trial and associated trial components under.
285+
286+
Args:
287+
action (str): pass in string '--force' to confirm delete the trial and all associated trial components.
288+
"""
289+
if action != "--force":
290+
raise ValueError(
291+
"Must confirm with string '--force' in order to delete the trial and " "associated trial components."
292+
)
293+
294+
delete_attempt_count = 0
295+
last_exception = None
296+
while True:
297+
if delete_attempt_count == self.MAX_DELETE_ALL_ATTEMPTS:
298+
raise Exception("Failed to delete, please try again.") from last_exception
299+
try:
300+
for trial_component_summary in self.list_trial_components():
301+
tc = trial_component.TrialComponent.load(
302+
sagemaker_boto_client=self.sagemaker_boto_client,
303+
trial_component_name=trial_component_summary.trial_component_name,
304+
)
305+
tc.delete(force_disassociate=True)
306+
# to prevent throttling
307+
time.sleep(1.2)
308+
self.delete()
309+
break
310+
except Exception as ex:
311+
last_exception = ex
312+
finally:
313+
delete_attempt_count = delete_attempt_count + 1

src/smexperiments/trial_component.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,12 @@ def list(
228228

229229
@classmethod
230230
def search(
231-
cls, search_expression=None, sort_by=None, sort_order=None, max_results=None, sagemaker_boto_client=None,
231+
cls,
232+
search_expression=None,
233+
sort_by=None,
234+
sort_order=None,
235+
max_results=None,
236+
sagemaker_boto_client=None,
232237
):
233238
"""
234239
Search experiments. Returns SearchResults in the account matching the search criteria.

tests/conftest.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,22 +123,29 @@ def complex_experiment_obj(sagemaker_boto_client):
123123
trial_name3 = name()
124124

125125
next_trial1 = trial.Trial.create(
126-
trial_name=trial_name1, experiment_name=experiment_obj_name, sagemaker_boto_client=sagemaker_boto_client,
126+
trial_name=trial_name1,
127+
experiment_name=experiment_obj_name,
128+
sagemaker_boto_client=sagemaker_boto_client,
127129
)
128130
trial_objs.append(next_trial1)
129131
next_trial2 = trial.Trial.create(
130-
trial_name=trial_name2, experiment_name=experiment_obj_name, sagemaker_boto_client=sagemaker_boto_client,
132+
trial_name=trial_name2,
133+
experiment_name=experiment_obj_name,
134+
sagemaker_boto_client=sagemaker_boto_client,
131135
)
132136
trial_objs.append(next_trial2)
133137
next_trial3 = trial.Trial.create(
134-
trial_name=trial_name3, experiment_name=experiment_obj_name, sagemaker_boto_client=sagemaker_boto_client,
138+
trial_name=trial_name3,
139+
experiment_name=experiment_obj_name,
140+
sagemaker_boto_client=sagemaker_boto_client,
135141
)
136142
trial_objs.append(next_trial3)
137143

138144
# create trial components
139145
trial_component_name = name()
140146
trial_component_obj = trial_component.TrialComponent.create(
141-
trial_component_name=trial_component_name, sagemaker_boto_client=sagemaker_boto_client,
147+
trial_component_name=trial_component_name,
148+
sagemaker_boto_client=sagemaker_boto_client,
142149
)
143150

144151
# associate the trials with trial components
@@ -181,7 +188,9 @@ def trial_obj(sagemaker_boto_client, experiment_obj):
181188
@pytest.fixture
182189
def trial_component_obj(sagemaker_boto_client):
183190
trial_component_obj = trial_component.TrialComponent.create(
184-
trial_component_name=name(), sagemaker_boto_client=sagemaker_boto_client, tags=TAGS,
191+
trial_component_name=name(),
192+
sagemaker_boto_client=sagemaker_boto_client,
193+
tags=TAGS,
185194
)
186195
yield trial_component_obj
187196
time.sleep(0.5)
@@ -346,7 +355,10 @@ def training_job_name(sagemaker_boto_client, training_role_arn, docker_image, tr
346355
"DataSource": {"S3DataSource": {"S3Uri": training_s3_uri, "S3DataType": "S3Prefix"}},
347356
}
348357
],
349-
AlgorithmSpecification={"TrainingImage": docker_image, "TrainingInputMode": "File",},
358+
AlgorithmSpecification={
359+
"TrainingImage": docker_image,
360+
"TrainingInputMode": "File",
361+
},
350362
RoleArn=training_role_arn,
351363
ResourceConfig={"InstanceType": "ml.m5.large", "InstanceCount": 1, "VolumeSizeInGB": 10},
352364
StoppingCondition={"MaxRuntimeInSeconds": 900},

0 commit comments

Comments
 (0)