1- import glob
21import os
32import warnings
43
87from qtpy .QtWidgets import QLineEdit
98from qtpy .QtWidgets import QProgressBar
109from qtpy .QtWidgets import QSizePolicy
11- from qtpy .QtWidgets import QTabWidget
1210
1311# local
1412from napari_cellseg_annotator import interface as ui
1715from napari_cellseg_annotator .models import TRAILMAP_test as TMAP
1816from napari_cellseg_annotator .models import model_SegResNet as SegResNet
1917from napari_cellseg_annotator .models import model_VNet as VNet
18+ from napari_cellseg_annotator .plugin_base import BasePluginFolder
2019
2120warnings .formatwarning = utils .format_Warning
2221
2322
24- class ModelFramework (QTabWidget ):
23+ class ModelFramework (BasePluginFolder ):
2524 """A framework with buttons to use for loading images, labels, models, etc. for both inference and training"""
2625
2726 def __init__ (self , viewer : "napari.viewer.Viewer" ):
@@ -40,20 +39,21 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
4039 Args:
4140 viewer (napari.viewer.Viewer): viewer to load the widget in
4241 """
43- super ().__init__ ()
42+ super ().__init__ (viewer )
4443
4544 self ._viewer = viewer
46- """napari.viewer.Viewer: Viewer to display the widget in in"""
47-
48- self .images_filepaths = ["" ]
49- """array(str): paths to images for training or inference"""
50- self .labels_filepaths = ["" ]
51- """array(str): paths to labels for training"""
52- self .results_path = ""
53- """str: path to output folder,to save results in"""
45+
5446 self .model_path = ""
5547 """str: path to custom model defined by user"""
5648
49+ self ._default_path = [
50+ self .images_filepaths ,
51+ self .labels_filepaths ,
52+ self .model_path ,
53+ self .results_path ,
54+ ]
55+ """Update defaults from PluginBaseFolder with model_path"""
56+
5757 self .models_dict = {
5858 "VNet" : VNet ,
5959 "SegResNet" : SegResNet ,
@@ -63,43 +63,11 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
6363
6464 Currently implemented : SegResNet, VNet, TRAILMAP_test"""
6565
66- self ._default_path = [
67- self .images_filepaths ,
68- self .labels_filepaths ,
69- self .model_path ,
70- self .results_path ,
71- ]
72-
73- self .docked_widgets = []
74- """List of docked widgets (returned by :py:func:`viewer.window.add_dock_widget()),
75- can be used to remove docked widgets`"""
76-
7766 self .worker = None
7867 """Worker from model_workers.py, either inference or training"""
7968
8069 #######################################################
8170 # interface
82- self .btn_image_files = ui .make_button (
83- "Open" , self .load_image_dataset , self
84- )
85- self .lbl_image_files = QLineEdit ("Images directory" , self )
86- self .lbl_image_files .setReadOnly (True )
87-
88- self .btn_label_files = ui .make_button (
89- "Open" , self .load_label_dataset , self
90- )
91- self .lbl_label_files = QLineEdit ("Labels directory" , self )
92- self .lbl_label_files .setReadOnly (True )
93-
94- self .filetype_choice , self .lbl_filetype = ui .make_combobox (
95- [".tif" , ".tiff" ], label = "File format"
96- )
97-
98- self .btn_result_path = ui .make_button (
99- "Open" , self .load_results_path , self
100- )
101- self .lbl_result_path = QLineEdit ("Results directory" , self )
102- self .lbl_result_path .setReadOnly (True )
10371
10472 # TODO : implement custom model
10573 self .btn_model_path = ui .make_button (
@@ -175,7 +143,8 @@ def save_log(self):
175143 )
176144
177145 def display_status_report (self ):
178- """Adds a text log, a progress bar and a "save log" button on the left side of the viewer (usually when starting a worker)"""
146+ """Adds a text log, a progress bar and a "save log" button on the left side of the viewer
147+ (usually when starting a worker)"""
179148
180149 # if self.container_report is None or self.log is None:
181150 # warnings.warn(
@@ -223,32 +192,6 @@ def display_status_report(self):
223192 self .btn_save_log .setVisible (True )
224193 self .progress .setValue (0 )
225194
226- def update_default (self ):
227- """Update default path for smoother file dialogs"""
228- self ._default_path = [
229- path
230- for path in [
231- os .path .dirname (self .images_filepaths [0 ]),
232- os .path .dirname (self .labels_filepaths [0 ]),
233- self .model_path ,
234- self .results_path ,
235- ]
236- if (path != ["" ] and path != "" )
237- ]
238-
239- def load_dataset_paths (self ):
240- """Loads all image paths (as str) in a given folder for which the extension matches the set filetype
241-
242- Returns:
243- array(str): all loaded file paths
244- """
245- filetype = self .filetype_choice .currentText ()
246- directory = ui .open_file_dialog (self , self ._default_path , True )
247- # print(directory)
248- file_paths = sorted (glob .glob (os .path .join (directory , "*" + filetype )))
249- # print(file_paths)
250- return file_paths
251-
252195 def create_train_dataset_dict (self ):
253196 """Creates data dictionary for MONAI transforms and training.
254197
@@ -285,35 +228,6 @@ def get_loss(self, key):
285228 """Getter for loss function selected by user"""
286229 return self .loss_dict [key ]
287230
288- def load_image_dataset (self ):
289- """Show file dialog to set :py:attr:`~images_filepaths`"""
290- filenames = self .load_dataset_paths ()
291- # print(filenames)
292- if filenames != "" and filenames != ["" ] and filenames != []:
293- self .images_filepaths = filenames
294- # print(filenames)
295- path = os .path .dirname (filenames [0 ])
296- self .lbl_image_files .setText (path )
297- # print(path)
298- self ._default_path [0 ] = path
299-
300- def load_label_dataset (self ):
301- """Show file dialog to set :py:attr:`~labels_filepaths`"""
302- filenames = self .load_dataset_paths ()
303- if filenames != "" and filenames != ["" ]:
304- self .labels_filepaths = filenames
305- path = os .path .dirname (filenames [0 ])
306- self .lbl_label_files .setText (path )
307- self .update_default ()
308-
309- def load_results_path (self ):
310- """Show file dialog to set :py:attr:`~results_path`"""
311- dir = ui .open_file_dialog (self , self ._default_path , True )
312- if dir != "" and type (dir ) is str and os .path .isdir (dir ):
313- self .results_path = dir
314- self .lbl_result_path .setText (self .results_path )
315- self .update_default ()
316-
317231 def load_model_path (self ):
318232 """Show file dialog to set :py:attr:`model_path`"""
319233 dir = ui .open_file_dialog (self , self ._default_path )
@@ -340,20 +254,18 @@ def empty_cuda_cache(self):
340254 torch .cuda .empty_cache ()
341255 print ("Cache emptied" )
342256
343- def build (self ):
344- raise NotImplementedError ( "Should be defined in children classes" )
345-
346- def remove_docked_widgets ( self ):
347- """Removes docked widgets and resets checks for status report"""
348- if len (self .docked_widgets ) != 0 :
349- [
350- self ._viewer . window . remove_dock_widget ( w )
351- for w in self .docked_widgets
257+ def update_default (self ):
258+ """Update default path for smoother file dialogs, here with :py:attr:`~model_path` included"""
259+ self . _default_path = [
260+ path
261+ for path in [
262+ os . path . dirname (self .images_filepaths [ 0 ]),
263+ os . path . dirname ( self . labels_filepaths [ 0 ]),
264+ self .model_path ,
265+ self .results_path ,
352266 ]
353- self . docked_widgets = []
354- self . container_docked = False
267+ if ( path ! = ["" ] and path != "" )
268+ ]
355269
356- def close (self ):
357- """Close the widget and the docked widgets, if any"""
358- self .remove_docked_widgets ()
359- self ._viewer .window .remove_dock_widget (self )
270+ def build (self ):
271+ raise NotImplementedError ("Should be defined in children classes" )
0 commit comments