4040from tifffile import imwrite
4141
4242# local
43+ from napari_cellseg3d .model_instance_seg import (
44+ binary_watershed ,
45+ binary_connected ,
46+ )
4347from napari_cellseg3d import utils
4448
4549"""
@@ -64,6 +68,7 @@ class LogSignal(WorkerBaseSignals):
6468
6569 log_signal = Signal (str )
6670 """qtpy.QtCore.Signal: signal to be sent when some text should be logged"""
71+
6772 # Should not be an instance variable but a class variable, not defined in __init__, see
6873 # https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect
6974
@@ -155,6 +160,9 @@ def log(self, text):
155160
156161 def log_parameters (self ):
157162
163+ self .log ("-" * 20 )
164+ self .log ("Parameters summary :" )
165+
158166 self .log (f"Model is : { self .model_dict ['name' ]} " )
159167 if self .transforms ["thresh" ][0 ]:
160168 self .log (
@@ -173,10 +181,19 @@ def log_parameters(self):
173181 else :
174182 self .log (f"Dataset loaded on { self .device } " )
175183
184+ if self .transforms ["zoom" ][0 ]:
185+ self .log (
186+ f"Anisotropy parameters are : { self .transforms ['zoom' ][1 ]} microns in x,y,z"
187+ )
188+
176189 if self .instance_params ["do_instance" ]:
177- # TODO move instance seg
178- self .log (f"Instance segmentation enabled" )
190+ self .log (
191+ f"Instance segmentation enabled, method : { self .instance_params ['method' ]} \n "
192+ f"Probability threshold is { self .instance_params ['threshold' ]:.2f} \n "
193+ f"Objects smaller than { self .instance_params ['size_small' ]} pixels will be removed"
194+ )
179195 # self.log(f"")
196+ self .log ("-" * 20 )
180197
181198 def inference (self ):
182199 """
@@ -234,8 +251,7 @@ def inference(self):
234251 self .log ("\n Checking dimensions..." )
235252 pad = utils .get_padding_dim (check )
236253 # print(pad)
237- dims = 128
238- # dims = 64 # TODO
254+ dims = self .model_dict ["segres_size" ]
239255
240256 model = self .model_dict ["class" ].get_net ()
241257 if self .model_dict ["name" ] == "SegResNet" :
@@ -304,7 +320,7 @@ def inference(self):
304320 for i , inf_data in enumerate (inference_loader ):
305321
306322 self .log ("-" * 10 )
307- self .log (f"Inference started on image n°{ i + 1 } ..." )
323+ self .log (f"Inference started on image n°{ i + 1 } ..." )
308324
309325 inputs = inf_data ["image" ]
310326 # print(inputs.shape)
@@ -350,6 +366,7 @@ def inference(self):
350366 out = post_process_transforms (out )
351367 out = np .array (out ).astype (np .float32 )
352368 out = np .squeeze (out )
369+ to_instance = out # avoid post processing since thresholding is done there anyway
353370
354371 # batch_len = out.shape[1]
355372 # print("trying to check len")
@@ -391,8 +408,31 @@ def inference(self):
391408 self .log (
392409 f"\n Running instance segmentation for image n°{ image_id } "
393410 )
394- method = self .instance_params ["method" ]
395- instance_labels = method (out )
411+
412+ threshold = self .instance_params ["threshold" ]
413+ size_small = self .instance_params ["size_small" ]
414+ method_name = self .instance_params ["method" ]
415+
416+ if method_name == "Watershed" :
417+
418+ def method (image ):
419+ return binary_watershed (
420+ image , threshold , size_small
421+ )
422+
423+ elif method_name == "Connected components" :
424+
425+ def method (image ):
426+ return binary_connected (
427+ image , threshold , size_small
428+ )
429+
430+ else :
431+ raise NotImplementedError (
432+ "Selected instance segmentation method is not defined"
433+ )
434+
435+ instance_labels = method (to_instance )
396436
397437 instance_filepath = (
398438 self .results_path
@@ -526,10 +566,11 @@ def log(self, text):
526566
527567 def log_parameters (self ):
528568
529- self .log ("\n Parameters summary :\n " )
569+ self .log ("-" * 20 )
570+ self .log ("Parameters summary :\n " )
530571
531572 self .log (
532- f"Percentage of dataset used for validation : { self .validation_percent * 100 } %"
573+ f"Percentage of dataset used for validation : { self .validation_percent * 100 } %"
533574 )
534575 self .log ("-" * 10 )
535576 self .log ("Training files :\n " )
@@ -892,7 +933,7 @@ def train(self):
892933 yield train_report
893934
894935 weights_filename = (
895- f"{ model_name } _best_metric" + f"_epoch_{ epoch + 1 } .pth"
936+ f"{ model_name } _best_metric" + f"_epoch_{ epoch + 1 } .pth"
896937 )
897938
898939 if metric > best_metric :
0 commit comments