Skip to content

Commit 5fee1ab

Browse files
committed
Add disassociate param under the TC delete method
1 parent 7a78b7f commit 5fee1ab

File tree

4 files changed

+79
-2
lines changed

4 files changed

+79
-2
lines changed

src/smexperiments/trial_component.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,39 @@ def save(self):
103103
"""Save the state of this TrialComponent to SageMaker."""
104104
return self._invoke_api(self._boto_update_method, self._boto_update_members)
105105

106-
def delete(self):
107-
"""Delete this TrialComponent from SageMaker."""
106+
def delete(self, force_disassociate=None):
107+
"""Delete this TrialComponent from SageMaker.
108+
109+
Args:
110+
force_disassociate (boolean): Indicates whether to force disassociate the trial component with the trials
111+
before deletion. If set to true, force disassociate the trial component with associated trials first, then
112+
delete the trial component. If it's not set or set to false, it will delete the trial component directory
113+
without disassociation.
114+
"""
115+
if force_disassociate:
116+
next_token = None
117+
118+
while True:
119+
if next_token:
120+
list_trials_response = self.sagemaker_boto_client.list_trials(
121+
TrialComponentName=self.trial_component_name, NextToken=next_token
122+
)
123+
else:
124+
list_trials_response = self.sagemaker_boto_client.list_trials(
125+
TrialComponentName=self.trial_component_name
126+
)
127+
128+
# Disassociate the trials and trial components
129+
for trial in list_trials_response["TrialSummaries"]:
130+
self.sagemaker_boto_client.disassociate_trial_component(
131+
TrialName=trial["TrialName"], TrialComponentName=self.trial_component_name
132+
)
133+
134+
if "NextToken" in list_trials_response:
135+
next_token = list_trials_response["NextToken"]
136+
else:
137+
break
138+
108139
self._invoke_api(self._boto_delete_method, self._boto_delete_members)
109140

110141
@classmethod

tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,20 @@ def trial_component_obj(sagemaker_boto_client):
125125
trial_component_obj.delete()
126126

127127

128+
@pytest.fixture
129+
def trial_component_with_force_disassociation_obj(trials, sagemaker_boto_client):
130+
trial_component_obj = trial_component.TrialComponent.create(
131+
trial_component_name=name(), sagemaker_boto_client=sagemaker_boto_client
132+
)
133+
for trial in trials:
134+
sagemaker_boto_client.associate_trial_component(
135+
TrialName=trial.trial_name, TrialComponentName=trial_component_obj.trial_component_name
136+
)
137+
yield trial_component_obj
138+
time.sleep(0.5)
139+
trial_component_obj.delete(force_disassociate=True)
140+
141+
128142
@pytest.fixture
129143
def trials(experiment_obj, sagemaker_boto_client):
130144
trial_objs = []

tests/integ/test_trial_component.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ def test_create_tags(trial_component_obj, sagemaker_boto_client):
3131
assert actual_tags == trial_component_obj.tags
3232

3333

34+
def test_delete_with_disassociate(trial_component_with_disassociation_obj, sagemaker_boto_client):
35+
assert trial_component_with_disassociation_obj.trial_component_name
36+
37+
38+
def test_delete_with_force_disassociate(trial_component_with_force_disassociation_obj, sagemaker_boto_client):
39+
assert trial_component_with_force_disassociation_obj.trial_component_name
40+
trials = sagemaker_boto_client.list_trials(
41+
TrialComponentName=trial_component_with_force_disassociation_obj.trial_component_name
42+
)["TrialSummaries"]
43+
assert len(trials) == 3
44+
45+
3446
def test_save(trial_component_obj, sagemaker_boto_client):
3547
trial_component_obj.display_name = str(uuid.uuid4())
3648
trial_component_obj.status = api_types.TrialComponentStatus(primary_status="InProgress", message="Message")

tests/unit/test_trial_component.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,26 @@ def test_delete(sagemaker_boto_client):
253253
sagemaker_boto_client.delete_trial_component.assert_called_with(TrialComponentName="foo")
254254

255255

256+
def test_delete_with_force_disassociate(sagemaker_boto_client):
257+
obj = trial_component.TrialComponent(sagemaker_boto_client, trial_component_name="foo", display_name="bar")
258+
sagemaker_boto_client.delete_trial_component.return_value = {}
259+
260+
sagemaker_boto_client.list_trials.side_effect = [
261+
{"TrialSummaries": [{"TrialName": "trial-1"}, {"TrialName": "trial-2"}], "NextToken": "a"},
262+
{"TrialSummaries": [{"TrialName": "trial-3"}, {"TrialName": "trial-4"}]},
263+
]
264+
265+
obj.delete(force_disassociate=True)
266+
expected_calls = [
267+
unittest.mock.call(TrialName="trial-1", TrialComponentName="foo"),
268+
unittest.mock.call(TrialName="trial-2", TrialComponentName="foo"),
269+
unittest.mock.call(TrialName="trial-3", TrialComponentName="foo"),
270+
unittest.mock.call(TrialName="trial-4", TrialComponentName="foo"),
271+
]
272+
assert expected_calls == sagemaker_boto_client.disassociate_trial_component.mock_calls
273+
sagemaker_boto_client.delete_trial_component.assert_called_with(TrialComponentName="foo")
274+
275+
256276
def test_boto_ignore():
257277
obj = trial_component.TrialComponent(sagemaker_boto_client, trial_component_name="foo", display_name="bar")
258278
assert obj._boto_ignore() == ["ResponseMetadata", "CreatedBy"]

0 commit comments

Comments
 (0)