Skip to content

Commit ede3f39

Browse files
committed
🎨 formatting
1 parent cb5116d commit ede3f39

File tree

4 files changed

+66
-56
lines changed

4 files changed

+66
-56
lines changed

napari_cellseg3d/model_workers.py

Lines changed: 53 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -167,19 +167,19 @@ class InferenceWorker(GeneratorWorker):
167167
Inherits from :py:class:`napari.qt.threading.GeneratorWorker`"""
168168

169169
def __init__(
170-
self,
171-
device,
172-
model_dict,
173-
weights_dict,
174-
images_filepaths,
175-
results_path,
176-
filetype,
177-
transforms,
178-
instance,
179-
use_window,
180-
window_infer_size,
181-
keep_on_cpu,
182-
stats_csv,
170+
self,
171+
device,
172+
model_dict,
173+
weights_dict,
174+
images_filepaths,
175+
results_path,
176+
filetype,
177+
transforms,
178+
instance,
179+
use_window,
180+
window_infer_size,
181+
keep_on_cpu,
182+
stats_csv,
183183
):
184184
"""Initializes a worker for inference with the arguments needed by the :py:func:`~inference` function.
185185
@@ -227,7 +227,7 @@ def __init__(
227227
self.instance_params = instance
228228
self.use_window = use_window
229229
self.window_infer_size = window_infer_size
230-
self.window_overlap_percentage = 0.8,
230+
self.window_overlap_percentage = (0.8,)
231231
self.keep_on_cpu = keep_on_cpu
232232
self.stats_to_csv = stats_csv
233233
"""These attributes are all arguments of :py:func:~inference, please see that for reference"""
@@ -508,14 +508,14 @@ def inference(self):
508508

509509
# File output save name : original-name_model_date+time_number.filetype
510510
file_path = (
511-
self.results_path
512-
+ "/"
513-
+ f"Prediction_{image_id}_"
514-
+ original_filename
515-
+ "_"
516-
+ self.model_dict["name"]
517-
+ f"_{time}_"
518-
+ self.filetype
511+
self.results_path
512+
+ "/"
513+
+ f"Prediction_{image_id}_"
514+
+ original_filename
515+
+ "_"
516+
+ self.model_dict["name"]
517+
+ f"_{time}_"
518+
+ self.filetype
519519
)
520520

521521
# print(filename)
@@ -556,14 +556,14 @@ def method(image):
556556
instance_labels = method(to_instance)
557557

558558
instance_filepath = (
559-
self.results_path
560-
+ "/"
561-
+ f"Instance_seg_labels_{image_id}_"
562-
+ original_filename
563-
+ "_"
564-
+ self.model_dict["name"]
565-
+ f"_{time}_"
566-
+ self.filetype
559+
self.results_path
560+
+ "/"
561+
+ f"Instance_seg_labels_{image_id}_"
562+
+ original_filename
563+
+ "_"
564+
+ self.model_dict["name"]
565+
+ f"_{time}_"
566+
+ self.filetype
567567
)
568568

569569
imwrite(instance_filepath, instance_labels)
@@ -606,23 +606,23 @@ class TrainingWorker(GeneratorWorker):
606606
Inherits from :py:class:`napari.qt.threading.GeneratorWorker`"""
607607

608608
def __init__(
609-
self,
610-
device,
611-
model_dict,
612-
weights_path,
613-
data_dicts,
614-
validation_percent,
615-
max_epochs,
616-
loss_function,
617-
learning_rate,
618-
val_interval,
619-
batch_size,
620-
results_path,
621-
sampling,
622-
num_samples,
623-
sample_size,
624-
do_augmentation,
625-
deterministic,
609+
self,
610+
device,
611+
model_dict,
612+
weights_path,
613+
data_dicts,
614+
validation_percent,
615+
max_epochs,
616+
loss_function,
617+
learning_rate,
618+
val_interval,
619+
batch_size,
620+
results_path,
621+
sampling,
622+
num_samples,
623+
sample_size,
624+
do_augmentation,
625+
deterministic,
626626
):
627627
"""Initializes a worker for inference with the arguments needed by the :py:func:`~train` function. Note: See :py:func:`~train`
628628
@@ -853,10 +853,10 @@ def train(self):
853853

854854
self.train_files, self.val_files = (
855855
self.data_dicts[
856-
0: int(len(self.data_dicts) * self.validation_percent)
856+
0 : int(len(self.data_dicts) * self.validation_percent)
857857
],
858858
self.data_dicts[
859-
int(len(self.data_dicts) * self.validation_percent):
859+
int(len(self.data_dicts) * self.validation_percent) :
860860
],
861861
)
862862

@@ -1017,10 +1017,10 @@ def train(self):
10171017
if self.device.type == "cuda":
10181018
self.log("Memory Usage:")
10191019
alloc_mem = round(
1020-
torch.cuda.memory_allocated(0) / 1024 ** 3, 1
1020+
torch.cuda.memory_allocated(0) / 1024**3, 1
10211021
)
10221022
reserved_mem = round(
1023-
torch.cuda.memory_reserved(0) / 1024 ** 3, 1
1023+
torch.cuda.memory_reserved(0) / 1024**3, 1
10241024
)
10251025
self.log(f"Allocated: {alloc_mem}GB")
10261026
self.log(f"Cached: {reserved_mem}GB")
@@ -1102,7 +1102,7 @@ def train(self):
11021102
yield train_report
11031103

11041104
weights_filename = (
1105-
f"{model_name}_best_metric" + f"_epoch_{epoch + 1}.pth"
1105+
f"{model_name}_best_metric" + f"_epoch_{epoch + 1}.pth"
11061106
)
11071107

11081108
if metric > best_metric:
@@ -1143,6 +1143,7 @@ def train(self):
11431143

11441144
# self.close()
11451145

1146+
11461147
# def this_is_fine(self):
11471148
# import numpy as np
11481149
#

napari_cellseg3d/models/model_SegResNet.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33

44
def get_net(input_image_size, dropout_prob=None):
5-
return SegResNetVAE(input_image_size, out_channels=1, dropout_prob=dropout_prob)
5+
return SegResNetVAE(
6+
input_image_size, out_channels=1, dropout_prob=dropout_prob
7+
)
68

79

810
def get_weights_file():

napari_cellseg3d/models/model_SwinUNetR.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@ def get_weights_file():
77

88

99
def get_net(img_size, use_checkpoint=True):
10-
return SwinUNETR(img_size, in_channels=1, out_channels=1, feature_size=48, use_checkpoint=use_checkpoint)
10+
return SwinUNETR(
11+
img_size,
12+
in_channels=1,
13+
out_channels=1,
14+
feature_size=48,
15+
use_checkpoint=use_checkpoint,
16+
)
1117

1218

1319
def get_output(model, input):

napari_cellseg3d/plugin_model_inference.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
9797
######################
9898
######################
9999
# TODO : better way to handle SegResNet size reqs ?
100-
self.model_input_size = ui.IntIncrementCounter(min=1, max=1024, default=128)
100+
self.model_input_size = ui.IntIncrementCounter(
101+
min=1, max=1024, default=128
102+
)
101103
self.model_choice.currentIndexChanged.connect(
102104
self.toggle_display_model_input_size
103105
)
@@ -226,7 +228,6 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
226228
"Size of the window to run inference with (in pixels)"
227229
)
228230

229-
230231
self.keep_data_on_cpu_box.setToolTip(
231232
"If enabled, data will be kept on the RAM rather than the VRAM.\nCan avoid out of memory issues with CUDA"
232233
)

0 commit comments

Comments
 (0)