Skip to content

Commit e28ca38

Browse files
committed
Fixed model name error
- Replaced all TRAILMAP_PyTorch with TRAILMAP
1 parent 34da28b commit e28ca38

File tree

4 files changed

+7
-7
lines changed

4 files changed

+7
-7
lines changed

napari_cellseg3d/dev_scripts/weight_conversion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55

6-
from napari_cellseg3d.models.model_TRAILMAP_PyTorch import get_net
6+
from napari_cellseg3d.models.model_TRAILMAP import get_net
77
from napari_cellseg3d.models.unet.model import UNet3D
88

99
# not sure this actually works when put here

napari_cellseg3d/model_framework.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from napari_cellseg3d import utils
1515
from napari_cellseg3d.log_utility import Log
1616
from napari_cellseg3d.models import model_SegResNet as SegResNet
17-
from napari_cellseg3d.models import model_TRAILMAP_PyTorch as TRAILMAP
17+
from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP
1818
from napari_cellseg3d.models import model_VNet as VNet
1919
from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS
2020
from napari_cellseg3d.plugin_base import BasePluginFolder
@@ -63,7 +63,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
6363
self.models_dict = {
6464
"VNet": VNet,
6565
"SegResNet": SegResNet,
66-
"TRAILMAP_PyTorch": TRAILMAP,
66+
"TRAILMAP": TRAILMAP,
6767
"TRAILMAP_MS": TRAILMAP_MS,
6868
}
6969
"""dict: dictionary of available models, with string for widget display as key

napari_cellseg3d/models/model_TRAILMAP_PyTorch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66
def get_weights_file():
77
# model additionally trained on Mathis/Wyss mesoSPIM data
8-
return "TRAILMAP_PyTorch.pth"
8+
return "TRAILMAP.pth"
99
# FIXME currently incorrect, find good weights from TRAILMAP_test and upload them
1010

1111
def get_net():
12-
return TRAILMAP_PyTorch(1, 1)
12+
return TRAILMAP(1, 1)
1313

1414

1515
def get_output(model, input):
@@ -23,7 +23,7 @@ def get_validation(model, val_inputs):
2323
return model(val_inputs)
2424

2525

26-
class TRAILMAP_PyTorch(nn.Module):
26+
class TRAILMAP(nn.Module):
2727
def __init__(self, in_ch, out_ch):
2828
super().__init__()
2929
self.conv0 = self.encoderBlock(in_ch, 32, 3) # input
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"TRAILMAP_MS": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP_MS.tar.gz",
3-
"TRAILMAP_PyTorch": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP.tar.gz",
3+
"TRAILMAP": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP.tar.gz",
44
"SegResNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/SegResNet.tar.gz",
55
"VNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/VNet.tar.gz"
66
}

0 commit comments

Comments
 (0)