Skip to content

Commit fb1b130

Browse files
committed
Fix order for model deletion
1 parent d35da41 commit fb1b130

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

napari_cellseg3d/_tests/test_training.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from pathlib import Path
22

3+
import pytest
4+
35
from napari_cellseg3d._tests.fixtures import (
46
LogFixture,
57
LossFixture,
@@ -93,6 +95,22 @@ def test_unsupervised_training(make_napari_viewer_proxy):
9395
)
9496
assert isinstance(res, TrainingReport)
9597
assert not res.show_plot
98+
widget.worker._abort_requested = True
99+
res = next(
100+
widget.worker.train(
101+
provided_model=WNetFixture(),
102+
provided_optimizer=OptimizerFixture(),
103+
provided_loss=LossFixture(),
104+
)
105+
)
106+
assert isinstance(res, TrainingReport)
107+
assert not res.show_plot
108+
with pytest.raises(
109+
AttributeError,
110+
match="'WNetTrainingWorker' object has no attribute 'model'",
111+
):
112+
assert widget.worker.model is None
113+
96114
widget.worker.config.eval_volume_dict = [
97115
{"image": im_path_str, "label": im_path_str}
98116
]

napari_cellseg3d/code_models/worker_training.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,10 @@ def train(
559559
loss.backward(loss)
560560
optimizer.step()
561561

562+
yield TrainingReport(
563+
show_plot=False, weights=model.state_dict()
564+
)
565+
562566
if self._abort_requested:
563567
self.dataloader = None
564568
del self.dataloader
@@ -574,10 +578,6 @@ def train(
574578
del criterionW
575579
torch.cuda.empty_cache()
576580

577-
yield TrainingReport(
578-
show_plot=False, weights=model.state_dict()
579-
)
580-
581581
self.ncuts_losses.append(
582582
epoch_ncuts_loss / len(self.dataloader)
583583
)

0 commit comments

Comments
 (0)