Skip to content

Commit 0d6ec22

Browse files
committed
fix(datasets): move infer datset, add err catch
1 parent 7c77a2a commit 0d6ec22

File tree

7 files changed

+74
-76
lines changed

7 files changed

+74
-76
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .folder_dataset import FolderDataset
21
from .folder_dataset_train import SegmentationFolderDataset
2+
from .hdf5_dataset import SegmentationHDF5Dataset
33

4-
__all__ = ["FolderDataset", "SegmentationFolderDataset"]
4+
__all__ = ["FolderDataset", "SegmentationFolderDataset", "SegmentationHDF5Dataset"]

cellseg_models_pytorch/datasets/_base_dataset.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,20 @@
33
import numpy as np
44
from torch.utils.data import Dataset
55

6-
from ..transforms import (
7-
IMG_TRANSFORMS,
8-
INST_TRANSFORMS,
9-
NORM_TRANSFORMS,
10-
apply_each,
11-
compose,
12-
to_tensorv3,
13-
)
6+
try:
7+
from ..transforms.albu_transforms import (
8+
IMG_TRANSFORMS,
9+
INST_TRANSFORMS,
10+
NORM_TRANSFORMS,
11+
apply_each,
12+
compose,
13+
to_tensorv3,
14+
)
15+
except ModuleNotFoundError:
16+
raise ModuleNotFoundError(
17+
"To use the `csmp.dataset` module, the albumentations lib is needed. "
18+
"Install with `pip install albumentations`"
19+
)
1420

1521
__all__ = ["TrainDatasetBase"]
1622

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,47 @@
1-
# import pytest
2-
# import torch
3-
4-
# from cellseg_models_pytorch.datasets import SegmentationHDF5Dataset
5-
6-
# img_transforms = ["rigid", "blur"]
7-
# inst_transforms = ["smooth_dist"]
8-
9-
10-
# @pytest.mark.optional
11-
# @pytest.mark.parametrize("return_inst", [True, False])
12-
# @pytest.mark.parametrize("return_type", [True, False])
13-
# @pytest.mark.parametrize("return_sem", [True, False])
14-
# @pytest.mark.parametrize("normalization", [None, "minmax"])
15-
# def test_hdf5_dataset(hdf5db, return_inst, return_type, return_sem, normalization):
16-
# ds = SegmentationHDF5Dataset(
17-
# path=hdf5db,
18-
# img_transforms=img_transforms,
19-
# inst_transforms=inst_transforms,
20-
# normalization=normalization,
21-
# return_inst=return_inst,
22-
# return_type=return_type,
23-
# return_sem=return_sem,
24-
# )
25-
26-
# out = next(iter(ds))
27-
28-
# if return_inst:
29-
# assert "inst" in out.keys()
30-
# assert out["inst"].dtype == torch.int64
31-
# else:
32-
# assert "binary" not in out.keys()
33-
34-
# if return_type:
35-
# assert "type" in out.keys()
36-
# assert out["type"].dtype == torch.int64
37-
# else:
38-
# assert "type" not in out.keys()
39-
40-
# if return_sem:
41-
# assert "sem" in out.keys()
42-
# assert out["sem"].dtype == torch.int64
43-
# else:
44-
# assert "sem" not in out.keys()
45-
46-
# assert "smoothdist" in out.keys()
47-
# assert out["image"].dtype == torch.float32
1+
import pytest
2+
import torch
3+
4+
from cellseg_models_pytorch.datasets.hdf5_dataset import SegmentationHDF5Dataset
5+
6+
img_transforms = ["rigid", "blur"]
7+
inst_transforms = ["smooth_dist"]
8+
9+
10+
@pytest.mark.optional
11+
@pytest.mark.parametrize("return_inst", [True, False])
12+
@pytest.mark.parametrize("return_type", [True, False])
13+
@pytest.mark.parametrize("return_sem", [True, False])
14+
@pytest.mark.parametrize("normalization", [None, "minmax"])
15+
def test_hdf5_dataset(hdf5db, return_inst, return_type, return_sem, normalization):
16+
ds = SegmentationHDF5Dataset(
17+
path=hdf5db,
18+
img_transforms=img_transforms,
19+
inst_transforms=inst_transforms,
20+
normalization=normalization,
21+
return_inst=return_inst,
22+
return_type=return_type,
23+
return_sem=return_sem,
24+
)
25+
26+
out = next(iter(ds))
27+
28+
if return_inst:
29+
assert "inst" in out.keys()
30+
assert out["inst"].dtype == torch.int64
31+
else:
32+
assert "binary" not in out.keys()
33+
34+
if return_type:
35+
assert "type" in out.keys()
36+
assert out["type"].dtype == torch.int64
37+
else:
38+
assert "type" not in out.keys()
39+
40+
if return_sem:
41+
assert "sem" in out.keys()
42+
assert out["sem"].dtype == torch.int64
43+
else:
44+
assert "sem" not in out.keys()
45+
46+
assert "smoothdist" in out.keys()
47+
assert out["image"].dtype == torch.float32
Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1+
from .folder_dataset import FolderDataset
12
from .post_processor import PostProcessor
23
from .predictor import Predictor
34
from .resize_inferer import ResizeInferer
45
from .sliding_window_inferer import SlidingWindowInferer
56

6-
__all__ = ["Predictor", "PostProcessor", "ResizeInferer", "SlidingWindowInferer"]
7+
__all__ = [
8+
"Predictor",
9+
"PostProcessor",
10+
"ResizeInferer",
11+
"SlidingWindowInferer",
12+
"FolderDataset",
13+
]

cellseg_models_pytorch/inference/_base_inferer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
from torch.utils.data import DataLoader
1212
from tqdm import tqdm
1313

14-
from ..datasets import FolderDataset
1514
from ..utils import tensor_to_ndarray
1615
from ..utils.save_utils import mask2mat
16+
from .folder_dataset import FolderDataset
1717
from .post_processor import PostProcessor
1818
from .predictor import Predictor
1919

cellseg_models_pytorch/inference/tests/test_inference.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,14 @@
22

33
from cellseg_models_pytorch.inference import ResizeInferer, SlidingWindowInferer
44
from cellseg_models_pytorch.models import cellpose_plus
5-
from cellseg_models_pytorch.training import SegmentationExperiment
65

76

87
@pytest.mark.parametrize("batch_size", [1, 2])
98
def test_slidingwin_inference(img_dir, batch_size):
109
model = cellpose_plus(sem_classes=3, type_classes=3, long_skip="unet")
1110

12-
experiment = SegmentationExperiment(
13-
model=model,
14-
branch_losses={"cellpose": "mse_ssim", "sem": "ce_dice", "type": "ce_dice"},
15-
branch_metrics={"cellpose": [None], "sem": ["miou"], "type": ["miou"]},
16-
lookahead=False,
17-
)
18-
1911
inferer = SlidingWindowInferer(
20-
experiment,
12+
model,
2113
img_dir,
2214
out_activations={"sem": "softmax", "type": "softmax", "cellpose": "tanh"},
2315
out_boundary_weights={"sem": False, "type": False, "cellpose": True},
@@ -42,15 +34,8 @@ def test_slidingwin_inference(img_dir, batch_size):
4234
def test_resize_inference(img_dir, batch_size):
4335
model = cellpose_plus(sem_classes=3, type_classes=3, long_skip="unet")
4436

45-
experiment = SegmentationExperiment(
46-
model=model,
47-
branch_losses={"cellpose": "mse_ssim", "sem": "ce_dice", "type": "ce_dice"},
48-
branch_metrics={"cellpose": [None], "sem": ["miou"], "type": ["miou"]},
49-
lookahead=False,
50-
)
51-
5237
inferer = ResizeInferer(
53-
experiment,
38+
model,
5439
img_dir,
5540
out_activations={"sem": "softmax", "type": "softmax", "cellpose": "tanh"},
5641
out_boundary_weights={"sem": False, "type": False, "cellpose": True},

0 commit comments

Comments
 (0)