Skip to content

Commit 34da28b

Browse files
C-AchardMMathisLab
andcommitted
Fixed model names mismatch
- TRAILMAP is now TRAILMAP_MS, as intended initially - TRAILMAP_MS is now TRAILMAP_PyTorch.py, was incorrectly renamed - Fixed "Pytorch" typos - Improved logging function - Fixed redundant imports Co-Authored-By: Mackenzie Mathis <[email protected]>
1 parent 272449e commit 34da28b

File tree

10 files changed

+21
-30
lines changed

10 files changed

+21
-30
lines changed

napari_cellseg3d/dev_scripts/weight_conversion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55

6-
from napari_cellseg3d.models.model_TRAILMAP import get_net
6+
from napari_cellseg3d.models.model_TRAILMAP_PyTorch import get_net
77
from napari_cellseg3d.models.unet.model import UNet3D
88

99
# not sure this actually works when put here

napari_cellseg3d/log_utility.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,20 @@ def replace_last_line(self, text):
5656
finally:
5757
self.lock.release()
5858

59-
def print_and_log(self, text):
59+
def print_and_log(self, text, printing=True):
6060
"""Utility used to both print to terminal and log text to a QTextEdit
6161
item in a thread-safe manner. Use only for important user info.
6262
6363
Args:
6464
text (str): Text to be printed and logged
65+
printing (bool): Whether to print the message as well or not using print(). Defaults to True.
6566
6667
"""
6768
self.lock.acquire()
6869
try:
69-
print(text)
70-
# causes issue if you clik on terminal (tied to CMD QuickEdit mode)
70+
if printing:
71+
print(text)
72+
# causes issue if you clik on terminal (tied to CMD QuickEdit mode on Windows)
7173
self.moveCursor(QTextCursor.End)
7274
self.insertPlainText(f"\n{text}")
7375
self.verticalScrollBar().setValue(

napari_cellseg3d/model_framework.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
from napari_cellseg3d import utils
1515
from napari_cellseg3d.log_utility import Log
1616
from napari_cellseg3d.models import model_SegResNet as SegResNet
17-
from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP
17+
from napari_cellseg3d.models import model_TRAILMAP_PyTorch as TRAILMAP
1818
from napari_cellseg3d.models import model_VNet as VNet
19-
from napari_cellseg3d.models import TRAILMAP_MS as TMAP
19+
from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS
2020
from napari_cellseg3d.plugin_base import BasePluginFolder
2121

2222
warnings.formatwarning = utils.format_Warning
@@ -63,8 +63,8 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
6363
self.models_dict = {
6464
"VNet": VNet,
6565
"SegResNet": SegResNet,
66-
"TRAILMAP pre-trained": TRAILMAP,
67-
"TRAILMAP_MS": TMAP,
66+
"TRAILMAP_PyTorch": TRAILMAP,
67+
"TRAILMAP_MS": TRAILMAP_MS,
6868
}
6969
"""dict: dictionary of available models, with string for widget display as key
7070

napari_cellseg3d/model_workers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def show_progress(count, block_size, total_size):
114114
tar.extractall(pretrained_folder_path)
115115
else:
116116
raise ValueError(
117-
f"Unknown model. `modelname` should be one of {', '.join(neturls)}"
117+
f"Unknown model: {model_name}. Should be one of {', '.join(neturls)}"
118118
)
119119

120120

@@ -695,7 +695,7 @@ def log_parameters(self):
695695
self.log("-" * 20)
696696

697697
def train(self):
698-
"""Trains the Pytorch model for the given number of epochs, with the selected model and data,
698+
"""Trains the PyTorch model for the given number of epochs, with the selected model and data,
699699
using the chosen batch size, validation interval, loss function, and number of samples.
700700
Will perform validation once every :py:obj:`val_interval` and save results if the mean dice is better
701701

napari_cellseg3d/models/model_SegResNet.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
import os
2-
31
from monai.networks.nets import SegResNetVAE
42

5-
from napari_cellseg3d import utils
63

74

85
def get_net():

napari_cellseg3d/models/model_TRAILMAP.py renamed to napari_cellseg3d/models/model_TRAILMAP_MS.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
import os
2-
3-
from napari_cellseg3d import utils
41
from napari_cellseg3d.models.unet.model import UNet3D
52

63

74
def get_weights_file():
8-
# original model from Liqun Luo lab, transfered to pytorch
9-
return "TRAILMAP_PyTorch.pth"
5+
# original model from Liqun Luo lab, transferred to pytorch
6+
return "TRAILMAP_MS_best_metric_epoch_26.pth"
107

118

129
def get_net():

napari_cellseg3d/models/TRAILMAP_MS.py renamed to napari_cellseg3d/models/model_TRAILMAP_PyTorch.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
1-
import os
2-
31
import torch
42
from torch import nn
53

6-
from napari_cellseg3d import utils
74

85

96
def get_weights_file():
107
# model additionally trained on Mathis/Wyss mesoSPIM data
11-
return "TRAILMAP_MS_best_metric_epoch_26.pth"
12-
8+
return "TRAILMAP_PyTorch.pth"
9+
# FIXME currently incorrect, find good weights from TRAILMAP_test and upload them
1310

1411
def get_net():
15-
return TRAILMAP_MS(1, 1)
12+
return TRAILMAP_PyTorch(1, 1)
1613

1714

1815
def get_output(model, input):
@@ -26,7 +23,7 @@ def get_validation(model, val_inputs):
2623
return model(val_inputs)
2724

2825

29-
class TRAILMAP_MS(nn.Module):
26+
class TRAILMAP_PyTorch(nn.Module):
3027
def __init__(self, in_ch, out_ch):
3128
super().__init__()
3229
self.conv0 = self.encoderBlock(in_ch, 32, 3) # input
@@ -123,3 +120,4 @@ def outBlock(self, in_ch, out_ch, kernel_size, padding="same"):
123120
# nn.BatchNorm3d(out_ch),
124121
)
125122
return out
123+

napari_cellseg3d/models/model_VNet.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
import os
2-
31
from monai.inferers import sliding_window_inference
42
from monai.networks.nets import VNet
53

6-
from napari_cellseg3d import utils
74

85

96
def get_net():
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"TRAILMAP_MS": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP_MS.tar.gz",
3-
"TRAILMAP": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP.tar.gz",
3+
"TRAILMAP_PyTorch": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP.tar.gz",
44
"SegResNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/SegResNet.tar.gz",
55
"VNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/VNet.tar.gz"
66
}

napari_cellseg3d/plugin_model_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636

3737
class Trainer(ModelFramework):
38-
"""A plugin to train pre-defined Pytorch models for one-channel segmentation directly in napari.
38+
"""A plugin to train pre-defined PyTorch models for one-channel segmentation directly in napari.
3939
Features parameter selection for training, dynamic loss plotting and automatic saving of the best weights during
4040
training through validation."""
4141

0 commit comments

Comments
 (0)