Skip to content

Commit 3c1e023

Browse files
committed
🎨 formatting
1 parent 58b831e commit 3c1e023

File tree

3 files changed

+19
-3
lines changed

3 files changed

+19
-3
lines changed

napari_cellseg3d/model_workers.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,7 @@ def inference(self):
683683
try:
684684
dims = self.model_dict["segres_size"]
685685

686-
model = self.model_dict["class"].get_net()
686+
687687
if self.model_dict["name"] == "SegResNet":
688688
model = self.model_dict["class"].get_net()(
689689
input_image_size=[
@@ -694,6 +694,14 @@ def inference(self):
694694
out_channels=1,
695695
# dropout_prob=0.3,
696696
)
697+
elif self.model_dict["name"] == "SwinUNetR":
698+
model = self.model_dict["class"].get_net(
699+
img_size=[dims, dims, dims],
700+
use_checkpoint=False,
701+
)
702+
else:
703+
model = self.model_dict["class"].get_net()
704+
model = model.to(self.device)
697705

698706
self.log_parameters()
699707

napari_cellseg3d/models/model_SegResNet.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33

44
def get_net(input_image_size, dropout_prob=None):
5-
return SegResNetVAE(input_image_size, out_channels=1, dropout_prob=dropout_prob)
5+
return SegResNetVAE(
6+
input_image_size, out_channels=1, dropout_prob=dropout_prob
7+
)
68

79

810
def get_weights_file():

napari_cellseg3d/models/model_SwinUNetR.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@ def get_weights_file():
77

88

99
def get_net(img_size, use_checkpoint=True):
10-
return SwinUNETR(img_size, in_channels=1, out_channels=1, feature_size=48, use_checkpoint=use_checkpoint)
10+
return SwinUNETR(
11+
img_size,
12+
in_channels=1,
13+
out_channels=1,
14+
feature_size=48,
15+
use_checkpoint=use_checkpoint,
16+
)
1117

1218

1319
def get_output(model, input):

0 commit comments

Comments
 (0)