9 changes: 5 additions & 4 deletions batchgeneratorsv2/transforms/intensity/gaussian_noise.py
@@ -29,20 +29,21 @@ def get_parameters(self, **data_dict) -> dict:
     def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
         if sum(params['apply_to_channel']) == 0:
             return img
-        gaussian_noise = self._sample_gaussian_noise(img.shape, **params)
+        gaussian_noise = self._sample_gaussian_noise(img, **params)
         img[params['apply_to_channel']] += gaussian_noise
         return img
 
-    def _sample_gaussian_noise(self, img_shape: Tuple[int, ...], **params):
+    def _sample_gaussian_noise(self, img: torch.Tensor, **params):
+        img_shape = img.shape
         if not isinstance(params['sigmas'], list):
             num_channels = sum(params['apply_to_channel'])
             # gaussian = torch.tile(torch.normal(0, params['sigmas'], size=(1, *img_shape[1:])),
             #                       (num_channels, *[1]*(len(img_shape) - 1)))
-            gaussian = torch.normal(0, params['sigmas'], size=(1, *img_shape[1:]))
+            gaussian = torch.normal(0, params['sigmas'], size=(1, *img_shape[1:]), device=img.device)
             gaussian.expand((num_channels, *[-1]*(len(img_shape) - 1)))
         else:
             gaussian = [
-                torch.normal(0, i, size=(1, *img_shape[1:])) for i in params['sigmas']
+                torch.normal(0, i, size=(1, *img_shape[1:]), device=img.device) for i in params['sigmas']
             ]
             gaussian = torch.cat(gaussian, dim=0)
         return gaussian
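The device argument threaded through here means the noise is sampled directly on whatever device the image lives on, instead of being created on the CPU and copied over. A minimal sketch of the pattern, using a hypothetical standalone helper (not the transform's actual API):

    import torch

    def add_gaussian_noise(img: torch.Tensor, sigma: float) -> torch.Tensor:
        # Sampling with device=img.device avoids a CPU allocation plus a
        # host-to-device copy when img is on the GPU.
        noise = torch.normal(0, sigma, size=img.shape, device=img.device)
        return img + noise

    # Variant sharing one noise field across all channels: expand returns a
    # view, so it must be assigned back to have any effect.
    def add_shared_noise(img: torch.Tensor, sigma: float) -> torch.Tensor:
        noise = torch.normal(0, sigma, size=(1, *img.shape[1:]), device=img.device)
        noise = noise.expand(img.shape[0], *[-1] * (img.ndim - 1))
        return img + noise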
1 change: 1 addition & 0 deletions batchgeneratorsv2/transforms/noise/gaussian_blur.py
@@ -64,6 +64,7 @@ def blur_dimension(img: torch.Tensor, sigma: float, dim_to_blur: int, force_use_
 
     # Apply convolution
     # remember that weights are [c_out, c_in, ...]
+    kernel = kernel.to(img_padded.device)
     img_blurred = conv_op(img_padded[None], kernel.expand(img_padded.shape[0], *[-1] * (kernel.ndim - 1)), groups=img_padded.shape[0])[0]
     return img_blurred
 
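This one-line fix moves the blur kernel onto the same device as the (already padded) input before the grouped convolution; conv ops require input and weights to share a device. A minimal sketch of the idea in 1D, using 'same' padding for brevity where the transform pads explicitly:

    import torch
    import torch.nn.functional as F

    def blur_1d(img: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
        # img: (C, W); kernel: (1, 1, k). .to() is a no-op if devices already match.
        kernel = kernel.to(img.device)
        # groups=C blurs each channel independently with the same kernel
        return F.conv1d(img[None], kernel.expand(img.shape[0], -1, -1),
                        padding='same', groups=img.shape[0])[0]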
86 changes: 43 additions & 43 deletions batchgeneratorsv2/transforms/spatial/spatial.py
@@ -39,8 +39,8 @@ def __init__(self,
         """
         magnitude must be given in pixels!
         deformation scale is given as a percentage of the edge length
-        padding_mode_image: see torch grid_sample documentation. This currently applies to image and regression target
 
+        padding_mode_image: see torch grid_sample documentation. This currently applies to image and regression target
+        because both call self._apply_to_image. Can be "zeros", "reflection", "border"
         """
         super().__init__()
@@ -158,6 +158,24 @@ def get_parameters(self, **data_dict) -> dict:
             'center_location_in_pixels': center_location_in_pixels
         }
 
+    def prepare_grid(self, img: torch.Tensor, **params) -> torch.Tensor:
+        grid = _create_centered_identity_grid2(self.patch_size, img.device)
+        if params['elastic_offsets'] is not None:
+            grid += params['elastic_offsets']
+        if params['affine'] is not None:
+            grid = torch.matmul(grid, torch.from_numpy(params['affine']).float().to(img.device))
+
+        # we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center position
+        # only do this if we elastic deform
+        if self.center_deformation and params['elastic_offsets'] is not None:
+            mn = grid.mean(dim=list(range(img.ndim - 1)))
+        else:
+            mn = 0
+
+        new_center = torch.Tensor([c - s / 2 for c, s in zip(params['center_location_in_pixels'], img.shape[1:])]).to(img.device)
+        grid += (new_center - mn)
+        return grid
+
     def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
         if params['affine'] is None and params['elastic_offsets'] is None:
             # No spatial transformation is being done. Round grid_center and crop without having to interpolate.
@@ -181,23 +199,7 @@ def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
                 pad_kwargs=pad_kwargs)
             return img
         else:
-            grid = _create_centered_identity_grid2(self.patch_size)
-
-            # we deform first, then rotate
-            if params['elastic_offsets'] is not None:
-                grid += params['elastic_offsets']
-            if params['affine'] is not None:
-                grid = torch.matmul(grid, torch.from_numpy(params['affine']).float())
-
-            # we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center position
-            # only do this if we elastic deform
-            if self.center_deformation and params['elastic_offsets'] is not None:
-                mn = grid.mean(dim=list(range(img.ndim - 1)))
-            else:
-                mn = 0
-
-            new_center = torch.Tensor([c - s / 2 for c, s in zip(params['center_location_in_pixels'], img.shape[1:])])
-            grid += (new_center - mn)
+            grid = self.prepare_grid(img, **params)
             # print(f'grid sample with pad mode {self.padding_mode_image}')
             return grid_sample(img[None], _convert_my_grid_to_grid_sample_grid(grid, img.shape[1:])[None],
                                mode='bilinear', padding_mode=self.padding_mode_image, align_corners=False)[0]
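The helper composes the sampling grid in pixel space; only at the very end is it converted for grid_sample, which expects coordinates in [-1, 1] with the axis order reversed. A self-contained 2D sketch of the whole pipeline; the final normalization is an assumption about what _convert_my_grid_to_grid_sample_grid does, derived from grid_sample's align_corners=False convention:

    import torch
    import torch.nn.functional as F

    def warp_2d(img: torch.Tensor, affine: torch.Tensor, patch_size, center):
        # img: (C, H, W); affine: 2x2 float; center: target crop center in pixels
        space = [torch.linspace((1 - s) / 2, (s - 1) / 2, s, device=img.device)
                 for s in patch_size]
        grid = torch.stack(torch.meshgrid(*space, indexing='ij'), -1)  # (h, w, 2)
        grid = grid @ affine.to(img.device)      # rotate/scale about the patch center
        shift = torch.tensor([c - s / 2 for c, s in zip(center, img.shape[1:])],
                             device=img.device)
        grid = grid + shift                      # place the patch at `center`
        # pixel offsets -> normalized [-1, 1] coordinates, flipped to (x, y) order
        sizes = torch.tensor(tuple(img.shape[1:]), device=img.device, dtype=grid.dtype)
        grid = (2 * grid / sizes).flip(-1)
        return F.grid_sample(img[None], grid[None], mode='bilinear',
                             padding_mode='zeros', align_corners=False)[0]

With affine = torch.eye(2) and center = (H / 2, W / 2) this degenerates to sampling at exact pixel centers, i.e. essentially a center crop, which is why the code above short-circuits that case and crops without interpolating.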
@@ -215,24 +217,8 @@ def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
                 pad_kwargs={'value': 0})
             return segmentation
         else:
-            grid = _create_centered_identity_grid2(self.patch_size)
-
-            # we deform first, then rotate
-            if params['elastic_offsets'] is not None:
-                grid += params['elastic_offsets']
-            if params['affine'] is not None:
-                grid = torch.matmul(grid, torch.from_numpy(params['affine']).float())
-
-            # we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center coordinate
-            if self.center_deformation and params['elastic_offsets'] is not None:
-                mn = grid.mean(dim=list(range(segmentation.ndim - 1)))
-            else:
-                mn = 0
-
-            new_center = torch.Tensor([c - s / 2 for c, s in zip(params['center_location_in_pixels'], segmentation.shape[1:])])
-
-            grid += (new_center - mn)
-            grid = _convert_my_grid_to_grid_sample_grid(grid, segmentation.shape[1:])
+            grid = self.prepare_grid(segmentation, **params)
 
             if self.mode_seg == 'nearest':
                 result_seg = grid_sample(
@@ -243,7 +229,7 @@ def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
                     align_corners=False
                 )[0].to(segmentation.dtype)
             else:
-                result_seg = torch.zeros((segmentation.shape[0], *self.patch_size), dtype=segmentation.dtype)
+                result_seg = torch.zeros((segmentation.shape[0], *self.patch_size), dtype=segmentation.dtype, device=segmentation.device)
                 if self.bg_style_seg_sampling:
                     for c in range(segmentation.shape[0]):
                         labels = torch.from_numpy(np.sort(pd.unique(segmentation[c].numpy().ravel())))
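The 'nearest' branch above can resample labels in one call because nearest-neighbor lookup only ever copies an existing value: casting integer labels to float and back cannot produce labels that were never there, unlike bilinear interpolation. A minimal sketch, assuming a grid already in grid_sample's normalized format:

    import torch
    import torch.nn.functional as F

    def resample_labels_nearest(seg: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
        # seg: (C, H, W) integer labels; grid: (h, w, 2) normalized coordinates
        out = F.grid_sample(seg[None].float(), grid[None], mode='nearest',
                            padding_mode='zeros', align_corners=False)[0]
        return out.to(seg.dtype)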
@@ -269,12 +255,21 @@ def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
                                 align_corners=False
                             )[0][0] >= 0.5] = u
                 else:
+                    # start_event = torch.cuda.Event(enable_timing=True)
+                    # end_event = torch.cuda.Event(enable_timing=True)
+                    # start_event.record()
+
+                    # Etienne: this code is quite crazy, why not simply doing a one-hot encoding and then doing grid_sample?
                     for c in range(segmentation.shape[0]):
-                        labels = torch.from_numpy(np.sort(pd.unique(segmentation[c].numpy().ravel())))
-                        #torch.where(torch.bincount(segmentation.ravel()) > 0)[0].to(segmentation.dtype)
-                        tmp = torch.zeros((len(labels), *self.patch_size), dtype=torch.float16)
+                        # labels = torch.from_numpy(np.sort(pd.unique(segmentation[c].numpy().ravel())))
+                        # torch.where(torch.bincount(segmentation.ravel()) > 0)[0].to(segmentation.dtype)
+                        if segmentation.device == torch.device('cpu'):
+                            labels = torch.from_numpy(np.sort(pd.unique(segmentation[c].numpy().ravel())))
+                        else:
+                            labels = torch.sort(torch.unique(segmentation[c]))[0]
+                        tmp = torch.zeros((len(labels), *self.patch_size), dtype=torch.float16, device=segmentation.device)
                         scale_factor = 1000
-                        done_mask = torch.zeros(*self.patch_size, dtype=torch.bool)
+                        done_mask = torch.zeros(*self.patch_size, dtype=torch.bool, device=segmentation.device)
                         for i, u in enumerate(labels):
                             tmp[i] = grid_sample(((segmentation[c] == u).float() * scale_factor)[None, None], grid[None],
                                                  mode=self.mode_seg, padding_mode=self.border_mode_seg, align_corners=False)[0][0]
@@ -284,6 +279,11 @@ def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
                             if not torch.all(done_mask):
                                 result_seg[c][~done_mask] = labels[tmp[:, ~done_mask].argmax(0)]
                             del tmp
+
+                    # end_event.record()
+                    # torch.cuda.synchronize()
+                    # gpu_aug_time = start_event.elapsed_time(end_event) / 1000  # Convert ms to seconds
+                    # print(f"{gpu_aug_time:.3f}s for spatial augmentation")
             del grid
             return result_seg.contiguous()
 
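Etienne's inline question suggests a simpler formulation than the per-label loop with its done_mask bookkeeping: one-hot encode the label map once, warp all channels in a single grid_sample call, then take the argmax. A sketch of that alternative, untested here and memory-hungrier, since it materializes num_labels x patch_size voxels at once:

    import torch
    import torch.nn.functional as F

    def resample_labels_onehot(seg: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
        # seg: (H, W) integer labels; grid: (h, w, 2) normalized coordinates
        labels = torch.unique(seg)                             # sorted unique labels
        onehot = (seg[None] == labels[:, None, None]).float()  # (L, H, W)
        warped = F.grid_sample(onehot[None], grid[None], mode='bilinear',
                               padding_mode='zeros', align_corners=False)[0]
        return labels[warped.argmax(0)]                        # (h, w), label dtype

On GPU the one-hot tensor could be kept in float16 to halve the memory cost, mirroring the float16 tmp buffer above.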
@@ -341,8 +341,8 @@ def create_affine_matrix_2d(rotation_angle, scaling_factors):
 #     return grid
 
 
-def _create_centered_identity_grid2(size: Union[Tuple[int, ...], List[int]]) -> torch.Tensor:
-    space = [torch.linspace((1 - s) / 2, (s - 1) / 2, s) for s in size]
+def _create_centered_identity_grid2(size: Union[Tuple[int, ...], List[int]], device: torch.device) -> torch.Tensor:
+    space = [torch.linspace((1 - s) / 2, (s - 1) / 2, s).to(device) for s in size]
     grid = torch.meshgrid(space, indexing="ij")
     grid = torch.stack(grid, -1)
     return grid
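A quick usage check of the new signature; the values are exact, so this doubles as a sanity test:

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    grid = _create_centered_identity_grid2((4, 6), device)

    print(grid.shape)   # torch.Size([4, 6, 2]) -- one coordinate pair per voxel
    print(grid.device)  # matches `device`
    print(grid.mean())  # tensor(0.) -- the grid is centered on the origin
    print(grid[0, 0])   # tensor([-1.5000, -2.5000]) == ((1 - 4) / 2, (1 - 6) / 2)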