Skip to content

Commit 2f5cbf4

Browse files
committed
fix: fix inst2gdf empty masks
1 parent 69d1c52 commit 2f5cbf4

File tree

9 files changed

+202
-337
lines changed

9 files changed

+202
-337
lines changed

cellseg_models_pytorch/models/base/_base_model_inst.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,6 @@ def from_pretrained(
5656
f" csmp-hub. One of {list(PRETRAINED[cls.model_name].keys())}."
5757
)
5858

59-
try:
60-
from safetensors.torch import load_model
61-
except ImportError:
62-
raise ImportError(
63-
"Please install `safetensors` package to load .safetensors files."
64-
)
65-
6659
enc_name, n_nuc_classes, state_dict = cls._get_state_dict(
6760
weights_path, device=device
6861
)
@@ -77,6 +70,12 @@ def from_pretrained(
7770
)
7871

7972
if weights_path.suffix == ".safetensors":
73+
try:
74+
from safetensors.torch import load_model
75+
except ImportError:
76+
raise ImportError(
77+
"Please install `safetensors` package to load .safetensors files."
78+
)
8079
load_model(model_inst.model, weights_path, device.type)
8180
else:
8281
model_inst.model.load_state_dict(state_dict, strict=True)

cellseg_models_pytorch/models/cellpose/cellpose_unet.py

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
__all__ = [
1414
"CellPoseUnet",
1515
"cellpose_nuclei",
16-
"cellpose_panoptic",
1716
"omnipose_nuclei",
18-
"omnipose_panoptic",
1917
]
2018

2119

@@ -249,34 +247,6 @@ def cellpose_nuclei(n_nuc_classes: int, **kwargs) -> nn.Module:
249247
return cellpose_unet
250248

251249

252-
def cellpose_panoptic(n_nuc_classes: int, n_tissue_classes: int, **kwargs) -> nn.Module:
253-
"""Initialize Cellpose for panoptic segmentation.
254-
255-
Cellpose:
256-
- https://www.nature.com/articles/s41592-020-01018-x
257-
258-
Parameters
259-
n_nuc_classes (int):
260-
Number of nuclei type classes.
261-
n_tissue_classes (int):
262-
Number of tissue type classes.
263-
**kwargs:
264-
Arbitrary key word args for the CellPoseUnet class.
265-
266-
Returns:
267-
nn.Module: The initialized Cellpose+ U-net model.
268-
"""
269-
cellpose_unet = CellPoseUnet(
270-
decoders=("type", "tissue"),
271-
heads={
272-
"type": {"nuc_cellpose": 2, "nuc_type": n_nuc_classes},
273-
"tissue": {"tissue_type": n_tissue_classes},
274-
},
275-
**kwargs,
276-
)
277-
return cellpose_unet
278-
279-
280250
def omnipose_nuclei(n_nuc_classes: int, **kwargs) -> nn.Module:
281251
"""Create the baseline Omnipose U-net for nuclei segmentation.
282252
@@ -300,33 +270,3 @@ def omnipose_nuclei(n_nuc_classes: int, **kwargs) -> nn.Module:
300270
cellpose_unet.aux_key = "omnipose"
301271

302272
return cellpose_unet
303-
304-
305-
def omnipose_panoptic(n_nuc_classes: int, n_tissue_classes: int, **kwargs) -> nn.Module:
306-
"""Create the Omnipose U-net with nuclei- and tissue segmentation decoders.
307-
308-
Omnipose:
309-
- https://www.biorxiv.org/content/10.1101/2021.11.03.467199v2
310-
311-
Parameters:
312-
n_nuc_classes (int):
313-
Number of nuclei type classes.
314-
n_tissue_classes (int):
315-
Number of tissue type classes.
316-
**kwargs:
317-
Arbitrary key word args for the CellPoseUnet class.
318-
319-
Returns:
320-
nn.Module: The initialized Cellpose+ U-net model.
321-
"""
322-
cellpose_unet = CellPoseUnet(
323-
decoders=("type", "tissue"),
324-
heads={
325-
"type": {"nuc_omnipose": 2, "nuc_type": n_nuc_classes},
326-
"tissue": {"tissue_type": n_tissue_classes},
327-
},
328-
**kwargs,
329-
)
330-
cellpose_unet.aux_key = "omnipose"
331-
332-
return cellpose_unet

cellseg_models_pytorch/models/cellvit/cellvit_unet.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
__all__ = [
1414
"CellVitSamUnet",
1515
"cellvit_nuclei",
16-
"cellvit_panoptic",
1716
]
1817

1918

@@ -263,39 +262,3 @@ def cellvit_nuclei(enc_name: str, n_nuc_classes: int, **kwargs) -> nn.Module:
263262
)
264263

265264
return cellvit_sam
266-
267-
268-
def cellvit_panoptic(
269-
enc_name: str, n_nuc_classes: int, n_tissue_classes: int, **kwargs
270-
) -> nn.Module:
271-
"""Initialaize CellVit for panoptic segmentation.
272-
273-
CellVit:
274-
- https://arxiv.org/abs/2306.15350
275-
276-
Parameters:
277-
enc_name (str):
278-
Name of the encoder. One of: "samvit_base_patch16", "samvit_base_patch16_224",
279-
"samvit_huge_patch16", "samvit_large_patch16"
280-
n_nuc_classes (int):
281-
Number of nuclei type classes.
282-
n_tissue_classes (int):
283-
Number of tissue type classes.
284-
**kwargs:
285-
Arbitrary key word args for the CellVitSAM class.
286-
287-
Returns:
288-
nn.Module: The initialized CellVitSAM+ model.
289-
"""
290-
cellvit_sam = CellVitSamUnet(
291-
enc_name=enc_name,
292-
decoders=("hovernet", "type", "tissue"),
293-
heads={
294-
"hovernet": {"nuc_hovernet": 2},
295-
"type": {"nuc_type": n_nuc_classes},
296-
"tissue": {"tissue_type": n_tissue_classes},
297-
},
298-
**kwargs,
299-
)
300-
301-
return cellvit_sam

cellseg_models_pytorch/models/cppnet/cppnet_unet.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
__all__ = [
1717
"CPPNetUnet",
1818
"cppnet_nuclei",
19-
"cppnet_panoptic",
2019
]
2120

2221

@@ -338,38 +337,3 @@ def cppnet_nuclei(n_rays: int, n_nuc_classes: int, **kwargs) -> nn.Module:
338337
)
339338

340339
return cppnet
341-
342-
343-
def cppnet_panoptic(
344-
n_rays: int, n_nuc_classes: int, n_tissue_classes: int, **kwargs
345-
) -> nn.Module:
346-
"""Initialaize CPP-Net for panoptic segmentation.
347-
348-
CPP-Net:
349-
- https://arxiv.org/abs/2102.06867
350-
351-
Parameters:
352-
n_rays (int):
353-
Number of rays predicted per each object
354-
n_nuc_classes (int):
355-
Number of nuclei type classes.
356-
n_tissue_classes (int):
357-
Number of tissue type classes.
358-
**kwargs:
359-
Arbitrary key word args for the CPPNet class.
360-
361-
Returns:
362-
nn.Module: The initialized CPP-Net model.
363-
"""
364-
cppnet = CPPNetUnet(
365-
decoders=("stardist", "type", "tissue"),
366-
heads={
367-
"stardist": {"nuc_stardist": n_rays, "nuc_binary": 1},
368-
"type": {"nuc_type": n_nuc_classes},
369-
"tissue": {"tissue_type": n_tissue_classes},
370-
},
371-
n_rays=n_rays,
372-
**kwargs,
373-
)
374-
375-
return cppnet

cellseg_models_pytorch/models/hovernet/hovernet_unet.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
__all__ = [
1414
"HoverNetUnet",
1515
"hovernet_nuclei",
16-
"hovernet_panoptic",
1716
]
1817

1918

@@ -244,37 +243,3 @@ def hovernet_nuclei(n_nuc_classes: int, **kwargs) -> nn.Module:
244243
)
245244

246245
return hovernet
247-
248-
249-
def hovernet_panoptic(
250-
n_nuc_classes: int,
251-
n_tissue_classes: int,
252-
**kwargs,
253-
) -> nn.Module:
254-
"""Initialaize HoverNet+ for panoptic segmentation.
255-
256-
HoVer-Net:
257-
- https://www.sciencedirect.com/science/article/pii/S1361841519301045?via%3Dihub
258-
259-
Parameters:
260-
n_nuc_classes (int):
261-
Number of nuclei type classes.
262-
n_tissue_classes (int):
263-
Number of tissue type classes.
264-
**kwargs:
265-
Arbitrary key word args for the HoverNet class.
266-
267-
Returns:
268-
nn.Module: The initialized HoVer-Net+ model.
269-
"""
270-
hovernet = HoverNetUnet(
271-
decoders=("hovernet", "type", "tissue"),
272-
heads={
273-
"hovernet": {"nuc_hovernet": 2},
274-
"type": {"nuc_type": n_nuc_classes},
275-
"tissue": {"tissue_type": n_tissue_classes},
276-
},
277-
**kwargs,
278-
)
279-
280-
return hovernet

cellseg_models_pytorch/models/stardist/stardist_unet.py

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from cellseg_models_pytorch.encoders import Encoder
1111
from cellseg_models_pytorch.models.stardist._conf import _create_stardist_args
1212

13-
__all__ = ["StarDistUnet", "stardist_nuclei", "stardist_panoptic"]
13+
__all__ = ["StarDistUnet", "stardist_nuclei"]
1414

1515

1616
class StarDistUnet(nn.ModuleDict):
@@ -247,40 +247,3 @@ def stardist_nuclei(n_rays: int, n_nuc_classes: int, **kwargs) -> nn.Module:
247247
)
248248

249249
return stardist_unet
250-
251-
252-
def stardist_panoptic(
253-
n_rays: int, n_nuc_classes: int, n_tissue_classes: int, **kwargs
254-
) -> nn.Module:
255-
"""Initialize Stardist model for panoptic segmentation.
256-
257-
Stardist:
258-
- https://arxiv.org/abs/1806.03535
259-
260-
Parameters:
261-
n_rays (int):
262-
Number of rays predicted per each object
263-
n_nuc_classes (int):
264-
Number of nuclei type classes.
265-
n_tissue_classes (int):
266-
Number of tissue type classes.
267-
**kwargs:
268-
Arbitrary key word args for the StarDistUnet class.
269-
270-
Returns:
271-
nn.Module: The initialized Panoptic Stardist model.
272-
"""
273-
stardist_unet = StarDistUnet(
274-
decoders=("stardist", "tissue"),
275-
heads={
276-
"stardist": {
277-
"nuc_stardist": n_rays,
278-
"nuc_binary": 1,
279-
"nuc_type": n_nuc_classes,
280-
},
281-
"tissue": {"tissue_type": n_tissue_classes},
282-
},
283-
**kwargs,
284-
)
285-
286-
return stardist_unet

0 commit comments

Comments
 (0)