|
4 | 4 | from pathlib import Path |
5 | 5 | from typing import TYPE_CHECKING |
6 | 6 |
|
| 7 | +from monai.data import CacheDataset |
| 8 | + |
7 | 9 | # MONAI |
8 | 10 | from monai.metrics import DiceMetric |
| 11 | +from monai.transforms import ( |
| 12 | + AddChanneld, |
| 13 | + Compose, |
| 14 | + EnsureChannelFirstd, |
| 15 | + EnsureTyped, |
| 16 | + LoadImaged, |
| 17 | + Orientationd, |
| 18 | + SpatialPadd, |
| 19 | +) |
9 | 20 |
|
10 | 21 | # local |
11 | 22 | from napari_cellseg3d import config, utils |
@@ -94,6 +105,53 @@ def __init__( |
94 | 105 | self.eval_dataloader: DataLoader = None |
95 | 106 | self.data_shape = None |
96 | 107 |
|
| 108 | + def get_dataset(self, train_transforms): |
| 109 | + """Creates a Dataset applying some transforms/augmentation on the data using the MONAI library. |
| 110 | +
|
| 111 | + Args: |
| 112 | + train_transforms (monai.transforms.Compose): The transforms to apply to the data |
| 113 | +
|
| 114 | + Returns: |
| 115 | + (tuple): A tuple containing the shape of the data and the dataset |
| 116 | + """ |
| 117 | + train_files = self.config.train_data_dict |
| 118 | + |
| 119 | + first_volume = LoadImaged(keys=["image"])(train_files[0]) |
| 120 | + first_volume_shape = first_volume["image"].shape |
| 121 | + |
| 122 | + if len(first_volume_shape) != 3: |
| 123 | + raise ValueError( |
| 124 | + f"Expected 3D volumes, got {len(first_volume_shape)} dimensions" |
| 125 | + ) |
| 126 | + |
| 127 | + # Transforms to be applied to each volume |
| 128 | + load_single_images = Compose( |
| 129 | + [ |
| 130 | + LoadImaged(keys=["image"]), |
| 131 | + # EnsureChannelFirstd( |
| 132 | + # keys=["image"], |
| 133 | + # channel_dim="no_channel", |
| 134 | + # strict_check=False, |
| 135 | + # ), |
| 136 | + AddChanneld(keys=["image"]), |
| 137 | + Orientationd(keys=["image"], axcodes="PLI"), |
| 138 | + # SpatialPadd( |
| 139 | + # keys=["image"], |
| 140 | + # spatial_size=(utils.get_padding_dim(first_volume_shape)), |
| 141 | + # ), |
| 142 | + EnsureTyped(keys=["image"]), |
| 143 | + # RemapTensord(keys=["image"], new_min=0.0, new_max=100.0), |
| 144 | + ] |
| 145 | + ) |
| 146 | + |
| 147 | + # Create the dataset |
| 148 | + dataset = CacheDataset( |
| 149 | + data=train_files, |
| 150 | + transform=Compose([load_single_images, train_transforms]), |
| 151 | + ) |
| 152 | + |
| 153 | + return first_volume_shape, dataset |
| 154 | + |
97 | 155 |
|
98 | 156 | def get_colab_worker( |
99 | 157 | worker_config: config.WNetTrainingWorkerConfig, |
|
0 commit comments