File tree Expand file tree Collapse file tree 2 files changed +22
-4
lines changed
Expand file tree Collapse file tree 2 files changed +22
-4
lines changed Original file line number Diff line number Diff line change 11from pathlib import Path
22
3+ import pytest
4+
35from napari_cellseg3d ._tests .fixtures import (
46 LogFixture ,
57 LossFixture ,
@@ -93,6 +95,22 @@ def test_unsupervised_training(make_napari_viewer_proxy):
9395 )
9496 assert isinstance (res , TrainingReport )
9597 assert not res .show_plot
98+ widget .worker ._abort_requested = True
99+ res = next (
100+ widget .worker .train (
101+ provided_model = WNetFixture (),
102+ provided_optimizer = OptimizerFixture (),
103+ provided_loss = LossFixture (),
104+ )
105+ )
106+ assert isinstance (res , TrainingReport )
107+ assert not res .show_plot
108+ with pytest .raises (
109+ AttributeError ,
110+ match = "'WNetTrainingWorker' object has no attribute 'model'" ,
111+ ):
112+ assert widget .worker .model is None
113+
96114 widget .worker .config .eval_volume_dict = [
97115 {"image" : im_path_str , "label" : im_path_str }
98116 ]
Original file line number Diff line number Diff line change @@ -559,6 +559,10 @@ def train(
559559 loss .backward (loss )
560560 optimizer .step ()
561561
562+ yield TrainingReport (
563+ show_plot = False , weights = model .state_dict ()
564+ )
565+
562566 if self ._abort_requested :
563567 self .dataloader = None
564568 del self .dataloader
@@ -574,10 +578,6 @@ def train(
574578 del criterionW
575579 torch .cuda .empty_cache ()
576580
577- yield TrainingReport (
578- show_plot = False , weights = model .state_dict ()
579- )
580-
581581 self .ncuts_losses .append (
582582 epoch_ncuts_loss / len (self .dataloader )
583583 )
You can’t perform that action at this time.
0 commit comments