Skip to content

Commit b647d98

Browse files
committed
Move issues with transforms to colab script + disable pad/channelfirst
1 parent e3286ce commit b647d98

File tree

2 files changed

+68
-10
lines changed

2 files changed

+68
-10
lines changed

napari_cellseg3d/code_models/worker_training.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,11 @@ def get_patch_dataset(self, train_transforms):
201201
patch_func = Compose(
202202
[
203203
LoadImaged(keys=["image"], image_only=True),
204-
# EnsureChannelFirstd(
205-
# keys=["image"],
206-
# channel_dim="no_channel",
207-
# strict_check=False,
208-
# ),
204+
EnsureChannelFirstd(
205+
keys=["image"],
206+
channel_dim="no_channel",
207+
strict_check=False,
208+
),
209209
RandSpatialCropSamplesd(
210210
keys=["image"],
211211
roi_size=(
@@ -287,11 +287,11 @@ def get_dataset(self, train_transforms):
287287
load_single_images = Compose(
288288
[
289289
LoadImaged(keys=["image"]),
290-
# EnsureChannelFirstd(
291-
# keys=["image"],
292-
# channel_dim="no_channel",
293-
# strict_check=False,
294-
# ),
290+
EnsureChannelFirstd(
291+
keys=["image"],
292+
channel_dim="no_channel",
293+
strict_check=False,
294+
),
295295
Orientationd(keys=["image"], axcodes="PLI"),
296296
SpatialPadd(
297297
keys=["image"],

napari_cellseg3d/dev_scripts/colab_training.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,19 @@
44
from pathlib import Path
55
from typing import TYPE_CHECKING
66

7+
from monai.data import CacheDataset
8+
79
# MONAI
810
from monai.metrics import DiceMetric
11+
from monai.transforms import (
12+
AddChanneld,
13+
Compose,
14+
EnsureChannelFirstd,
15+
EnsureTyped,
16+
LoadImaged,
17+
Orientationd,
18+
SpatialPadd,
19+
)
920

1021
# local
1122
from napari_cellseg3d import config, utils
@@ -94,6 +105,53 @@ def __init__(
94105
self.eval_dataloader: DataLoader = None
95106
self.data_shape = None
96107

108+
def get_dataset(self, train_transforms):
109+
"""Creates a Dataset applying some transforms/augmentation on the data using the MONAI library.
110+
111+
Args:
112+
train_transforms (monai.transforms.Compose): The transforms to apply to the data
113+
114+
Returns:
115+
(tuple): A tuple containing the shape of the data and the dataset
116+
"""
117+
train_files = self.config.train_data_dict
118+
119+
first_volume = LoadImaged(keys=["image"])(train_files[0])
120+
first_volume_shape = first_volume["image"].shape
121+
122+
if len(first_volume_shape) != 3:
123+
raise ValueError(
124+
f"Expected 3D volumes, got {len(first_volume_shape)} dimensions"
125+
)
126+
127+
# Transforms to be applied to each volume
128+
load_single_images = Compose(
129+
[
130+
LoadImaged(keys=["image"]),
131+
# EnsureChannelFirstd(
132+
# keys=["image"],
133+
# channel_dim="no_channel",
134+
# strict_check=False,
135+
# ),
136+
AddChanneld(keys=["image"]),
137+
Orientationd(keys=["image"], axcodes="PLI"),
138+
# SpatialPadd(
139+
# keys=["image"],
140+
# spatial_size=(utils.get_padding_dim(first_volume_shape)),
141+
# ),
142+
EnsureTyped(keys=["image"]),
143+
# RemapTensord(keys=["image"], new_min=0.0, new_max=100.0),
144+
]
145+
)
146+
147+
# Create the dataset
148+
dataset = CacheDataset(
149+
data=train_files,
150+
transform=Compose([load_single_images, train_transforms]),
151+
)
152+
153+
return first_volume_shape, dataset
154+
97155

98156
def get_colab_worker(
99157
worker_config: config.WNetTrainingWorkerConfig,

0 commit comments

Comments
 (0)