diff --git a/batchgeneratorsv2/transforms/intensity/gaussian_noise.py b/batchgeneratorsv2/transforms/intensity/gaussian_noise.py
index 54fa523..fb1eb6d 100644
--- a/batchgeneratorsv2/transforms/intensity/gaussian_noise.py
+++ b/batchgeneratorsv2/transforms/intensity/gaussian_noise.py
@@ -29,20 +29,21 @@ def get_parameters(self, **data_dict) -> dict:
     def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
         if sum(params['apply_to_channel']) == 0:
             return img
-        gaussian_noise = self._sample_gaussian_noise(img.shape, **params)
+        gaussian_noise = self._sample_gaussian_noise(img, **params)
         img[params['apply_to_channel']] += gaussian_noise
         return img
 
-    def _sample_gaussian_noise(self, img_shape: Tuple[int, ...], **params):
+    def _sample_gaussian_noise(self, img: torch.Tensor, **params):
+        img_shape = img.shape
         if not isinstance(params['sigmas'], list):
             num_channels = sum(params['apply_to_channel'])
             # gaussian = torch.tile(torch.normal(0, params['sigmas'], size=(1, *img_shape[1:])),
             #                       (num_channels, *[1]*(len(img_shape) - 1)))
-            gaussian = torch.normal(0, params['sigmas'], size=(1, *img_shape[1:]))
+            gaussian = torch.normal(0, params['sigmas'], size=(1, *img_shape[1:]), device=img.device)
             gaussian.expand((num_channels, *[-1]*(len(img_shape) - 1)))
         else:
             gaussian = [
-                torch.normal(0, i, size=(1, *img_shape[1:])) for i in params['sigmas']
+                torch.normal(0, i, size=(1, *img_shape[1:]), device=img.device) for i in params['sigmas']
             ]
             gaussian = torch.cat(gaussian, dim=0)
         return gaussian
diff --git a/batchgeneratorsv2/transforms/noise/gaussian_blur.py b/batchgeneratorsv2/transforms/noise/gaussian_blur.py
index 1a707f4..d68b3a2 100644
--- a/batchgeneratorsv2/transforms/noise/gaussian_blur.py
+++ b/batchgeneratorsv2/transforms/noise/gaussian_blur.py
@@ -64,6 +64,7 @@ def blur_dimension(img: torch.Tensor, sigma: float, dim_to_blur: int, force_use_
 
     # Apply convolution
     # remember that weights are [c_out, c_in, ...]
+    kernel = kernel.to(img_padded.device)
     img_blurred = conv_op(img_padded[None], kernel.expand(img_padded.shape[0], *[-1] * (kernel.ndim - 1)),
                           groups=img_padded.shape[0])[0]
     return img_blurred
diff --git a/batchgeneratorsv2/transforms/spatial/spatial.py b/batchgeneratorsv2/transforms/spatial/spatial.py
index 6da7c17..48550f1 100644
--- a/batchgeneratorsv2/transforms/spatial/spatial.py
+++ b/batchgeneratorsv2/transforms/spatial/spatial.py
@@ -39,8 +39,8 @@ def __init__(self,
         """
         magnitude must be given in pixels!
         deformation scale is given as a paercentage of the edge length
-
-        padding_mode_image: see torch grid_sample documentation. This currently applies to image and regression target 
+
+        padding_mode_image: see torch grid_sample documentation. This currently applies to image and regression target
         because both call self._apply_to_image. Can be "zeros", "reflection", "border"
         """
         super().__init__()
@@ -158,6 +158,24 @@ def get_parameters(self, **data_dict) -> dict:
             'center_location_in_pixels': center_location_in_pixels
         }
 
+    def prepare_grid(self, img: torch.Tensor, **params) -> torch.Tensor:
+        grid = _create_centered_identity_grid2(self.patch_size, img.device)
+        if params['elastic_offsets'] is not None:
+            grid += params['elastic_offsets']
+        if params['affine'] is not None:
+            grid = torch.matmul(grid, torch.from_numpy(params['affine']).float().to(img.device))
+
+        # we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center position
+        # only do this if we elastic deform
+        if self.center_deformation and params['elastic_offsets'] is not None:
+            mn = grid.mean(dim=list(range(img.ndim - 1)))
+        else:
+            mn = 0
+
+        new_center = torch.Tensor([c - s / 2 for c, s in zip(params['center_location_in_pixels'], img.shape[1:])]).to(img.device)
+        grid += (new_center - mn)
+        return grid
+
     def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
         if params['affine'] is None and params['elastic_offsets'] is None:
             # No spatial transformation is being done. Round grid_center and crop without having to interpolate.
@@ -181,23 +199,7 @@ def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
                                            pad_kwargs=pad_kwargs)
             return img
         else:
-            grid = _create_centered_identity_grid2(self.patch_size)
-
-            # we deform first, then rotate
-            if params['elastic_offsets'] is not None:
-                grid += params['elastic_offsets']
-            if params['affine'] is not None:
-                grid = torch.matmul(grid, torch.from_numpy(params['affine']).float())
-
-            # we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center position
-            # only do this if we elastic deform
-            if self.center_deformation and params['elastic_offsets'] is not None:
-                mn = grid.mean(dim=list(range(img.ndim - 1)))
-            else:
-                mn = 0
-
-            new_center = torch.Tensor([c - s / 2 for c, s in zip(params['center_location_in_pixels'], img.shape[1:])])
-            grid += (new_center - mn)
+            grid = self.prepare_grid(img, **params)
             # print(f'grid sample with pad mode {self.padding_mode_image}')
             return grid_sample(img[None], _convert_my_grid_to_grid_sample_grid(grid, img.shape[1:])[None], mode='bilinear', padding_mode=self.padding_mode_image, align_corners=False)[0]
 
@@ -215,24 +217,8 @@ def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.
                                                      pad_kwargs={'value': 0})
             return segmentation
         else:
-            grid = _create_centered_identity_grid2(self.patch_size)
-
-            # we deform first, then rotate
-            if params['elastic_offsets'] is not None:
-                grid += params['elastic_offsets']
-            if params['affine'] is not None:
-                grid = torch.matmul(grid, torch.from_numpy(params['affine']).float())
-            # we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center coordinate
-            if self.center_deformation and params['elastic_offsets'] is not None:
-                mn = grid.mean(dim=list(range(segmentation.ndim - 1)))
-            else:
-                mn = 0
-
-            new_center = torch.Tensor([c - s / 2 for c, s in zip(params['center_location_in_pixels'], segmentation.shape[1:])])
-
-            grid += (new_center - mn)
-            grid = _convert_my_grid_to_grid_sample_grid(grid, segmentation.shape[1:])
+            grid = self.prepare_grid(segmentation, **params)
 
             if self.mode_seg == 'nearest':
                 result_seg = grid_sample(
@@ -243,7 +229,7 @@ def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.
                     align_corners=False
                 )[0].to(segmentation.dtype)
             else:
-                result_seg = torch.zeros((segmentation.shape[0], *self.patch_size), dtype=segmentation.dtype)
+                result_seg = torch.zeros((segmentation.shape[0], *self.patch_size), dtype=segmentation.dtype, device=segmentation.device)
                 if self.bg_style_seg_sampling:
                     for c in range(segmentation.shape[0]):
                         labels = torch.from_numpy(np.sort(pd.unique(segmentation[c].numpy().ravel())))
@@ -269,12 +255,21 @@ def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.
                             align_corners=False
                         )[0][0] >= 0.5] = u
             else:
+                # start_event = torch.cuda.Event(enable_timing=True)
+                # end_event = torch.cuda.Event(enable_timing=True)
+                # start_event.record()
+
+                # Etienne: this code is quite crazy, why not simply doing a one-hot encoding and then doing grid_sample?
                 for c in range(segmentation.shape[0]):
-                    labels = torch.from_numpy(np.sort(pd.unique(segmentation[c].numpy().ravel())))
-                    #torch.where(torch.bincount(segmentation.ravel()) > 0)[0].to(segmentation.dtype)
-                    tmp = torch.zeros((len(labels), *self.patch_size), dtype=torch.float16)
+                    # labels = torch.from_numpy(np.sort(pd.unique(segmentation[c].numpy().ravel())))
+                    # torch.where(torch.bincount(segmentation.ravel()) > 0)[0].to(segmentation.dtype)
+                    if segmentation.device == torch.device('cpu'):
+                        labels = torch.from_numpy(np.sort(pd.unique(segmentation[c].numpy().ravel())))
+                    else:
+                        labels = torch.sort(torch.unique(segmentation[c]))[0]
+                    tmp = torch.zeros((len(labels), *self.patch_size), dtype=torch.float16, device=segmentation.device)
                     scale_factor = 1000
-                    done_mask = torch.zeros(*self.patch_size, dtype=torch.bool)
+                    done_mask = torch.zeros(*self.patch_size, dtype=torch.bool, device=segmentation.device)
                     for i, u in enumerate(labels):
                         tmp[i] = grid_sample(((segmentation[c] == u).float() * scale_factor)[None, None], grid[None], mode=self.mode_seg,
                                              padding_mode=self.border_mode_seg, align_corners=False)[0][0]
@@ -284,6 +279,11 @@ def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.
                     if not torch.all(done_mask):
                         result_seg[c][~done_mask] = labels[tmp[:, ~done_mask].argmax(0)]
                     del tmp
+
+                # end_event.record()
+                # torch.cuda.synchronize()
+                # gpu_aug_time = start_event.elapsed_time(end_event) / 1000  # Convert ms to seconds
+                # print(f"{gpu_aug_time:.3f}s for spatial augmentation")
 
             del grid
             return result_seg.contiguous()
@@ -341,8 +341,8 @@ def create_affine_matrix_2d(rotation_angle, scaling_factors):
 
 
 # return grid
-def _create_centered_identity_grid2(size: Union[Tuple[int, ...], List[int]]) -> torch.Tensor:
-    space = [torch.linspace((1 - s) / 2, (s - 1) / 2, s) for s in size]
+def _create_centered_identity_grid2(size: Union[Tuple[int, ...], List[int]], device: torch.device) -> torch.Tensor:
+    space = [torch.linspace((1 - s) / 2, (s - 1) / 2, s).to(device) for s in size]
     grid = torch.meshgrid(space, indexing="ij")
     grid = torch.stack(grid, -1)
     return grid
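
Review note on the inline "Etienne:" comment above (one-hot encoding followed by a single grid_sample per channel): below is a minimal sketch of that alternative, not what this PR implements. It assumes grid is already in grid_sample's normalized [-1, 1] convention (i.e. after _convert_my_grid_to_grid_sample_grid) and that labels are stored as integers; the function name resample_seg_onehot and its local variables are illustrative only.

import torch
import torch.nn.functional as F

def resample_seg_onehot(segmentation: torch.Tensor, grid: torch.Tensor,
                        mode: str = 'bilinear', padding_mode: str = 'zeros') -> torch.Tensor:
    # segmentation: (C, *spatial) tensor of integer labels
    # grid: (*patch_size, ndim) sampling grid in grid_sample's [-1, 1] convention
    result = torch.empty((segmentation.shape[0], *grid.shape[:-1]),
                         dtype=segmentation.dtype, device=segmentation.device)
    for c in range(segmentation.shape[0]):
        # map arbitrary label values to 0..num_labels-1 so the one-hot stays small
        labels, mapped = torch.unique(segmentation[c], return_inverse=True)
        # one-hot encode and move the label axis into the channel position expected by grid_sample
        onehot = F.one_hot(mapped, num_classes=len(labels)).movedim(-1, 0).float()
        # interpolate all label channels in a single grid_sample call
        sampled = F.grid_sample(onehot[None], grid[None], mode=mode,
                                padding_mode=padding_mode, align_corners=False)[0]
        # assign each output voxel the label with the highest interpolated weight
        result[c] = labels[sampled.argmax(0)]
    return result

Whether this is preferable mostly comes down to memory: it materializes a (num_labels, *patch_size) float volume per channel in one go, whereas the current loop samples one label at a time and can stop early via done_mask.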