Skip to content

Commit 29d5bb8

Browse files
committed
various fixes
Fixed saving in training, changed format of output a bit, added warning in pad for values slightly above power of two + updated test
1 parent 0c16964 commit 29d5bb8

File tree

13 files changed

+87
-48
lines changed

13 files changed

+87
-48
lines changed

src/napari_cellseg3d/_tests/test_dock_widget.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88

99
def test_prepare(make_napari_viewer):
10-
path_to_csv = Path(os.path.dirname(os.path.realpath(__file__)) + "/res")
1110
path_image = Path(
1211
os.path.dirname(os.path.realpath(__file__)) + "/res/test.tif"
1312
)
@@ -16,7 +15,7 @@ def test_prepare(make_napari_viewer):
1615
viewer.add_image(image)
1716
widget = Datamanager(viewer)
1817

19-
widget.prepare(path_to_csv, ".tif", "", False, False)
18+
widget.prepare(path_image, ".tif", "", False, False)
2019

2120
assert widget.filetype == ".tif"
2221
assert widget.as_folder == False

src/napari_cellseg3d/_tests/test_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,13 @@ def test_get_padding_dim():
6767

6868
assert pad == [2048, 32, 64]
6969

70+
tensor = torch.randn(65,70,80)
71+
size = tensor.size()
72+
73+
pad = utils.get_padding_dim(size)
74+
75+
assert pad == [128,128,128]
76+
7077

7178
def test_normalize_x():
7279

src/napari_cellseg3d/launch_review.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def quicksave():
187187
# quicksave()
188188
# viewer.window.close()
189189

190-
return dirname, quicksave() #, quicksave_quit()
190+
return dirname, quicksave() # , quicksave_quit()
191191

192192
# gui = file_widget.show(run=True) # dirpicker.show(run=True)
193193

src/napari_cellseg3d/log_utility.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ def __init__(self, parent):
2020

2121
self.lock = threading.Lock()
2222

23-
def receive_log(self, text):
24-
self.print_and_log(text)
23+
# def receive_log(self, text):
24+
# self.print_and_log(text)
2525

2626
def print_and_log(self, text):
2727
"""Utility used to both print to terminal and log text to a QTextEdit
@@ -31,10 +31,14 @@ def print_and_log(self, text):
3131
text (str): Text to be printed and logged
3232
3333
"""
34-
with self.lock:
34+
self.lock.acquire()
35+
try:
3536
print(text)
37+
# causes issues if you clik on terminal (tied to CMD QuickEdit mode)
3638
self.moveCursor(QTextCursor.End)
3739
self.insertPlainText(f"\n{text}")
3840
self.verticalScrollBar().setValue(
3941
self.verticalScrollBar().maximum()
4042
)
43+
finally:
44+
self.lock.release()

src/napari_cellseg3d/model_framework.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,13 +133,22 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
133133
def send_log(self, text):
134134
self.log.print_and_log(text)
135135

136-
def save_log(self):
137-
"""Saves the worker's log to disk at self.results_path when called"""
136+
def save_log(self, spec_path=None):
137+
"""Saves the worker's log to disk at self.results_path when called
138+
139+
Args:
140+
spec_path: if specified, saves to path instead of self.results_path
141+
"""
138142
log = self.log.toPlainText()
139143

144+
if spec_path is None:
145+
path = self.results_path
146+
else:
147+
path = spec_path
148+
140149
if len(log) != 0:
141150
with open(
142-
self.results_path + f"/Log_report_{utils.get_date_time()}.txt",
151+
path + f"/Log_report_{utils.get_date_time()}.txt",
143152
"x",
144153
) as f:
145154
f.write(log)

src/napari_cellseg3d/model_workers.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,13 @@ def inference(self):
200200
self.log("\nChecking dimensions...")
201201
pad = utils.get_padding_dim(check)
202202
# print(pad)
203-
dims =128
204-
dims=64
203+
dims = 128
204+
dims = 64
205205

206206
model = self.model_dict["class"].get_net()
207207
if self.model_dict["name"] == "SegResNet":
208208
model = self.model_dict["class"].get_net()(
209-
input_image_size=[dims,dims,dims], # TODO FIX !
209+
input_image_size=[dims, dims, dims], # TODO FIX !
210210
out_channels=1,
211211
# dropout_prob=0.3,
212212
)
@@ -511,7 +511,7 @@ def train(self):
511511
512512
* data_dicts : dict from :py:func:`Trainer.create_train_dataset_dict`
513513
514-
* max_epochs : the amout of epochs to train for
514+
* max_epochs : the amount of epochs to train for
515515
516516
* loss_function : the loss function to use for training
517517
@@ -757,7 +757,7 @@ def train(self):
757757
f"* {step}/{len(train_ds) // train_loader.batch_size}, "
758758
f"Train loss: {loss.detach().item():.4f}"
759759
)
760-
yield {"plot":False, "weights": model.state_dict()}
760+
yield {"plot": False, "weights": model.state_dict()}
761761

762762
epoch_loss /= step
763763
epoch_loss_values.append(epoch_loss)
@@ -820,11 +820,12 @@ def train(self):
820820
if metric > best_metric:
821821
best_metric = metric
822822
best_metric_epoch = epoch + 1
823+
self.log("Saving best metric model")
823824
torch.save(
824825
model.state_dict(),
825826
os.path.join(self.results_path, weights_filename),
826827
)
827-
self.log("Saved best metric model")
828+
self.log("Saving complete")
828829
self.log(
829830
f"Current epoch: {epoch + 1}, Current mean dice: {metric:.4f}"
830831
f"\nBest mean dice: {best_metric:.4f} "

src/napari_cellseg3d/plugin_convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def build(self):
164164
ui.add_blank(layout=layout, widget=self)
165165
layout.addWidget(self.lbl_error)
166166

167-
ui.make_scrollable(layout, self, min_wh=[230, 300], base_wh=[230, 350])
167+
ui.make_scrollable(layout, self, min_wh=[230, 400], base_wh=[230, 450])
168168

169169
def folder_to_semantic(self):
170170
"""Converts folder of labels to semantic labels"""

src/napari_cellseg3d/plugin_crop.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,14 +283,14 @@ def add_crop_sliders(
283283
print(f"Crop variables")
284284
print(image_stack.shape)
285285

286-
287-
288286
# define crop sizes and boundaries for the image
289287
crop_sizes = [self._crop_size_x, self._crop_size_y, self._crop_size_z]
290288
for i in range(len(crop_sizes)):
291289
if crop_sizes[i] > image_stack.shape[i]:
292290
crop_sizes[i] = image_stack.shape[i]
293-
warnings.warn(f"WARNING : Crop dimension in axis {i} was too large at {crop_sizes[i]}, it was set to {image_stack.shape[i]}")
291+
warnings.warn(
292+
f"WARNING : Crop dimension in axis {i} was too large at {crop_sizes[i]}, it was set to {image_stack.shape[i]}"
293+
)
294294
cropx, cropy, cropz = crop_sizes
295295
# shapez, shapey, shapex = image_stack.shape
296296
ends = np.asarray(image_stack.shape) - np.asarray(crop_sizes) + 1

src/napari_cellseg3d/plugin_dock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def create(self, label_dir, model_type, filename=None):
162162
if self.filename is not None:
163163
filename = self.filename
164164
else:
165-
filename="image"
165+
filename = "image"
166166
labels = [str(filename) for i in range(self.image_dims[0])]
167167

168168
df = pd.DataFrame(

src/napari_cellseg3d/plugin_model_training.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from napari_cellseg3d.model_workers import TrainingWorker
3333

3434
NUMBER_TABS = 3
35+
DEFAULT_PATCH_SIZE = 60
3536

3637

3738
class Trainer(ModelFramework):
@@ -116,6 +117,7 @@ def __init__(
116117
self.data_path = ""
117118
self.label_path = ""
118119
self.results_path = ""
120+
self.results_path_folder = ""
119121
######################
120122
######################
121123
######################
@@ -252,7 +254,9 @@ def __init__(
252254
]
253255
"""Close buttons list for each tab"""
254256

255-
self.patch_size_widgets = ui.make_n_spinboxes(3, 10, 1023, 120)
257+
self.patch_size_widgets = ui.make_n_spinboxes(
258+
3, 10, 1024, DEFAULT_PATCH_SIZE
259+
)
256260

257261
self.patch_size_lbl = [
258262
QLabel(f"Size of patch in {axis} :") for axis in "xyz"
@@ -651,23 +655,24 @@ def build(self):
651655
ui.make_scrollable(
652656
contained_layout=data_tab_layout,
653657
containing_widget=data_tab,
654-
min_wh=[100, 150],
658+
min_wh=[200, 300],
655659
) # , max_wh=[200,1000])
656660
self.addTab(data_tab, "Data")
657661

658662
ui.make_scrollable(
659663
contained_layout=augment_tab_l,
660664
containing_widget=augment_tab_w,
661-
min_wh=[100, 200],
665+
min_wh=[200, 300],
662666
)
663667
self.addTab(augment_tab_w, "Augmentation")
664668

665669
ui.make_scrollable(
666670
contained_layout=train_tab_layout,
667671
containing_widget=train_tab,
668-
min_wh=[100, 200],
672+
min_wh=[200, 300],
669673
)
670674
self.addTab(train_tab, "Training")
675+
self.setMinimumSize(220, 200)
671676

672677
def show_dialog_lab(self):
673678
"""Shows the dialog to load label files in a path, loads them (see :doc:model_framework) and changes the widget
@@ -760,12 +765,12 @@ def start(self):
760765
"class": self.get_model(self.model_choice.currentText()),
761766
"name": self.model_choice.currentText(),
762767
}
763-
self.results_path = (
768+
self.results_path_folder = (
764769
self.results_path
765770
+ f"/{model_dict['name']}_results_{utils.get_date_time()}"
766771
)
767772
os.makedirs(
768-
self.results_path, exist_ok=False
773+
self.results_path_folder, exist_ok=False
769774
) # avoid overwrite where possible
770775

771776
if self.use_transfer_choice.isChecked():
@@ -777,7 +782,7 @@ def start(self):
777782
weights_path = None
778783

779784
self.log.print_and_log(
780-
f"Notice : Saving results to : {self.results_path}"
785+
f"Notice : Saving results to : {self.results_path_folder}"
781786
)
782787

783788
self.worker = TrainingWorker(
@@ -790,7 +795,7 @@ def start(self):
790795
learning_rate=self.learning_rate,
791796
val_interval=self.val_interval,
792797
batch_size=self.batch_size,
793-
results_path=self.results_path,
798+
results_path=self.results_path_folder,
794799
sampling=self.patch_choice.isChecked(),
795800
num_samples=self.num_samples,
796801
sample_size=self.patch_size,
@@ -812,11 +817,13 @@ def start(self):
812817
self.worker.errored.connect(self.on_error)
813818

814819
if self.worker.is_running:
820+
self.log.print_and_log("*" * 20)
815821
self.log.print_and_log(
816-
f"Stop requested at {utils.get_time()}. \nWaiting for next validation step..."
822+
f"Stop requested at {utils.get_time()}. \nWaiting for next yielding step..."
817823
)
818824
self.stop_requested = True
819-
self.btn_start.setText("Stopping... Please wait for next saving")
825+
self.btn_start.setText("Stopping... Please wait")
826+
self.log.print_and_log("*" * 20)
820827
self.worker.quit()
821828
else:
822829
self.worker.start()
@@ -833,32 +840,35 @@ def on_start(self):
833840

834841
def on_finish(self):
835842
"""Catches finished signal from worker"""
843+
self.log.print_and_log("*" * 20)
836844
self.log.print_and_log(f"\nWorker finished at {utils.get_time()}")
837845

838-
self.log.print_and_log(f"Saving last loss plot at {self.results_path}")
846+
self.log.print_and_log(f"Saving in {self.results_path_folder}")
847+
self.log.print_and_log(f"Saving last loss plot")
839848

840849
if self.canvas is not None:
841850
self.canvas.figure.savefig(
842851
(
843-
self.results_path
852+
self.results_path_folder
844853
+ f"/final_metric_plots_{utils.get_time_filepath()}.png"
845854
),
846855
format="png",
847856
)
848857

849-
self.log.print_and_log("Auto-saving log")
850-
self.save_log()
858+
self.log.print_and_log("Saving log")
859+
self.save_log(spec_path=self.results_path_folder)
851860

852861
self.log.print_and_log("Done")
853862
self.log.print_and_log("*" * 10)
854863

855864
self.btn_start.setText("Start")
856865
[btn.setVisible(True) for btn in self.close_buttons]
857866

867+
del self.worker
858868
self.worker = None
869+
self.results_path_folder = ""
859870
self.empty_cuda_cache()
860871

861-
self.results_path = ""
862872
# self.clean_cache() # trying to fix memory leak
863873

864874
def on_error(self):
@@ -880,13 +890,17 @@ def on_yield(data, widget):
880890
widget.update_loss_plot(data["losses"], data["val_metrics"])
881891

882892
if widget.stop_requested:
893+
widget.log.print_and_log(
894+
"Saving weights from aborted training in results folder"
895+
)
883896
torch.save(
884897
data["weights"],
885898
os.path.join(
886-
widget.results_path,
899+
widget.results_path_folder,
887900
f"latest_weights_aborted_training_{utils.get_time_filepath()}.pth",
888901
),
889902
)
903+
widget.log.print_and_log("Saving complete")
890904
widget.stop_requested = False
891905

892906
# def clean_cache(self):
@@ -942,7 +956,7 @@ def plot_loss(self, loss, dice_metric):
942956
)
943957
self.canvas.draw_idle()
944958

945-
plot_path = self.results_path + "/Loss_plots"
959+
plot_path = self.results_path_folder + "/Loss_plots"
946960
os.makedirs(plot_path, exist_ok=True)
947961
if self.canvas is not None:
948962
self.canvas.figure.savefig(

0 commit comments

Comments
 (0)