|
| 1 | +import numpy as np |
| 2 | +import pandas as pd |
| 3 | +import torch |
| 4 | +import zarr |
| 5 | + |
| 6 | +from skimage.filters import gaussian |
| 7 | +from torch_em.util import ensure_tensor_with_channels |
| 8 | + |
| 9 | + |
# Process point labels stored as a CSV file in napari style (columns "axis-0", "axis-1", "axis-2").
# NOTE: `eps` is accepted for interface compatibility but is currently unused;
# the normalization below uses a fixed offset of 1e-7.
def process_labels(label_path, shape, sigma, eps, bb=None):
    """Render point annotations from a CSV file as a blurred detection target.

    Each point is rounded to the nearest voxel, set to 1 in an otherwise empty
    volume, smoothed with a Gaussian filter and rescaled so that the maximum of
    the result is (approximately) 4.

    Args:
        label_path: Path of the CSV file, read with ``pandas.read_csv``. It must
            provide the columns ``axis-0``, ``axis-1`` and ``axis-2`` (used as
            z, y and x) and have as many columns as the volume has dimensions.
        shape: Shape of the full volume. Only used when ``bb`` is None.
        sigma: Passed on to ``skimage.filters.gaussian``.
        eps: Currently unused.
        bb: Optional sequence of three slices (z, y, x) with explicit ``start``
            and ``stop``. If given, the output covers only this bounding box;
            coordinates are shifted into the box and points outside of it are
            dropped.

    Returns:
        The smoothed and rescaled label volume, with the shape of the bounding
        box if ``bb`` is given and ``shape`` otherwise.
    """
    points = pd.read_csv(label_path)

    # Use one consistent `is not None` check so that the offsets below are
    # always defined whenever a bounding box is applied.
    if bb is not None:
        (z_min, z_max), (y_min, y_max), (x_min, x_max) = [(s.start, s.stop) for s in bb]
        offsets = (z_min, y_min, x_min)
        shape = (z_max - z_min, y_max - y_min, x_max - x_min)

    assert len(points.columns) == len(shape), (
        f"Expected {len(shape)} columns in {label_path}, got {len(points.columns)}"
    )

    # Work on plain arrays so that the data frame is never modified in place.
    coords = [points[name].to_numpy() for name in ("axis-0", "axis-1", "axis-2")]

    if bb is not None:
        # Shift into the local coordinate system of the box and keep only the
        # points that fall inside of it.
        coords = [coord - offset for coord, offset in zip(coords, offsets)]
        inside = np.logical_and.reduce([
            np.logical_and(coord >= 0, coord < size) for coord, size in zip(coords, shape)
        ])
        coords = [coord[inside] for coord in coords]

    # Round to voxel indices; the clipping guards against points that round up
    # to the upper border (and, without a bounding box, against points that lie
    # outside of the volume).
    indices = tuple(
        np.clip(np.round(coord).astype("int"), 0, size - 1)
        for coord, size in zip(coords, shape)
    )

    labels = np.zeros(shape, dtype="float32")
    labels[indices] = 1
    labels = gaussian(labels, sigma)
    # TODO better normalization?
    labels /= (labels.max() + 1e-7)
    labels *= 4
    return labels
| 48 | + |
| 49 | + |
class DetectionDataset(torch.utils.data.Dataset):
    """Dataset yielding random 3D patches of a zarr volume with point-detection targets.

    For every sample a random bounding box of size ``patch_shape`` is drawn, the
    raw data inside the box is read from the zarr container and the matching
    target is rendered from the point annotations by ``process_labels``.

    Args:
        raw_path: Path of the zarr container with the raw data.
        label_path: Path of the CSV file with the point annotations.
        patch_shape: Spatial patch shape; must have three entries (z, y, x).
        raw_key: Key of the raw dataset inside the zarr container.
        raw_transform: Optional callable applied to the raw patch.
        label_transform: Optional callable applied to the label patch.
        transform: Optional callable applied jointly to ``(raw, labels)``.
        dtype: Torch dtype of the returned raw tensor.
        label_dtype: Torch dtype of the returned label tensor.
        n_samples: Length of the dataset. If None, it is derived from the ratio
            of the data shape and the patch shape.
        sampler: Not supported yet; requesting a sample with a sampler set
            raises ``NotImplementedError``.
        eps: Forwarded to ``process_labels``.
        sigma: Forwarded to ``process_labels`` for the Gaussian smoothing.
            NOTE(review): the default of None is presumably not a valid sigma
            for the filter - pass an explicit value (TODO confirm).
        **kwargs: Accepted and ignored.
    """
    # Reserved for sampler-based rejection sampling, which is not implemented yet.
    max_sampling_attempts = 500

    @staticmethod
    def compute_len(shape, patch_shape):
        """Return the default dataset length for the given data and patch shape.

        This is 1 if ``patch_shape`` is None and otherwise the product of the
        per-axis ratios ``shape / patch_shape``, truncated to an integer.
        """
        if patch_shape is None:
            return 1
        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
        return n_samples

    def __init__(
        self,
        raw_path,
        label_path,
        patch_shape,
        raw_key,
        raw_transform=None,
        label_transform=None,
        transform=None,
        dtype=torch.float32,
        label_dtype=torch.float32,
        n_samples=None,
        sampler=None,
        eps=1e-8,
        sigma=None,
        **kwargs,
    ):
        self.raw_path = raw_path
        self.label_path = label_path
        self.raw_key = raw_key
        self._ndim = 3

        assert len(patch_shape) == self._ndim
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.transform = transform
        self.sampler = sampler

        self.dtype = dtype
        self.label_dtype = label_dtype

        self.eps = eps
        self.sigma = sigma

        with zarr.open(self.raw_path, mode="r") as f:
            self.shape = f[self.raw_key].shape

        if n_samples is None:
            self._len = self.compute_len(self.shape, self.patch_shape)
        else:
            self._len = n_samples

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        """Number of spatial dimensions of the samples (always 3)."""
        return self._ndim

    def _sample_bounding_box(self, shape):
        """Draw a random box of size ``self.patch_shape`` inside ``shape``.

        Returns a tuple of slices, one per patch axis.

        Raises:
            NotImplementedError: If the data is smaller than the patch along
                any axis, because padding is not supported.
        """
        if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
            raise NotImplementedError(
                f"Image padding is not supported yet. Data shape {shape}, patch shape {self.patch_shape}"
            )
        # The upper bound of randint is exclusive, so `+ 1` is needed to make
        # the last valid start position (sh - psh) reachable.
        bb_start = [
            np.random.randint(0, sh - psh + 1)
            for sh, psh in zip(shape, self.patch_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))

    def _get_sample(self, index):
        """Load one random raw patch and render the matching label patch.

        ``index`` is ignored: the patch position is drawn at random.

        Returns:
            The raw patch (channel-first if the raw data has channels) and the
            label patch as numpy arrays.
        """
        if self.sampler is not None:
            # TODO implement rejection sampling: re-draw the bounding box until
            # the sampler accepts the patch, giving up with a RuntimeError after
            # `max_sampling_attempts` attempts.
            raise NotImplementedError("Sampler not implemented")

        # Open read-only, consistent with __init__; the data is never modified.
        raw = zarr.open(self.raw_path, mode="r")[self.raw_key]
        shape = raw.shape

        # Determine the spatial shape BEFORE sampling the bounding box, so that
        # a channel axis is never mistaken for a spatial axis.
        have_raw_channels = raw.ndim == 4  # 3D with channels
        prefix_box = tuple()
        if have_raw_channels:
            # Heuristic: a short last axis is treated as a trailing channel
            # axis, otherwise the channels are assumed to come first.
            if shape[-1] < 16:
                shape = shape[:-1]
            else:
                shape = shape[1:]
                prefix_box = (slice(None), )

        bb = self._sample_bounding_box(shape)
        label = process_labels(self.label_path, shape, self.sigma, self.eps, bb=bb)

        have_label_channels = label.ndim == 4
        if have_label_channels:
            raise NotImplementedError("Multi-channel labels are not supported.")

        raw_patch = np.array(raw[prefix_box + bb])
        label_patch = np.array(label)

        if have_raw_channels and len(prefix_box) == 0:
            raw_patch = raw_patch.transpose((3, 0, 1, 2))  # Channels, Depth, Height, Width

        return raw_patch, label_patch

    def __getitem__(self, index):
        """Return a ``(raw, labels)`` pair of tensors for one random patch."""
        raw, labels = self._get_sample(index)

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels
| 180 | + |
| 181 | + |
if __name__ == "__main__":
    # Visual sanity check: show one raw volume next to the detection target
    # rendered from its point annotations.
    import napari

    image_file = "training_data/images/10.1L_mid_IHCribboncount_5_Z.zarr"
    annotation_file = "training_data/labels/10.1L_mid_IHCribboncount_5_Z.csv"

    # Load the complete raw volume into memory.
    # The names `raw` and `labels` are kept on purpose: napari may derive the
    # layer names from the variable names passed to `add_image`.
    raw = zarr.open(image_file, "r")["raw"][:]
    labels = process_labels(annotation_file, shape=raw.shape, sigma=1, eps=1e-7)

    viewer = napari.Viewer()
    viewer.add_image(raw)
    viewer.add_image(labels)
    napari.run()
0 commit comments