diff --git a/batchgenerators/augmentations/color_augmentations.py b/batchgenerators/augmentations/color_augmentations.py index 05f483b..3e940dc 100644 --- a/batchgenerators/augmentations/color_augmentations.py +++ b/batchgenerators/augmentations/color_augmentations.py @@ -13,64 +13,66 @@ # See the License for the specific language governing permissions and # limitations under the License. -from builtins import range -from typing import Tuple, Union, Callable +from typing import Tuple, Callable, Union import numpy as np -from batchgenerators.augmentations.utils import general_cc_var_num_channels, illumination_jitter +from batchgenerators.augmentations.utils import general_cc_var_num_channels, illumination_jitter, get_broadcast_axes, \ + reverse_broadcast + + +def get_augment_contrast_factor(contrast_range: Union[Tuple[float, float], Callable[[], float]], + per_channel: bool, + size: int, + broadcast_size: int): + if per_channel: + factor = [] + for _ in range(size): + if callable(contrast_range): + factor.append(contrast_range()) + elif contrast_range[0] < 1 and np.random.random() < 0.5: + factor.append(np.random.uniform(contrast_range[0], 1)) + else: + factor.append(np.random.uniform(max(contrast_range[0], 1), contrast_range[1])) -def augment_contrast(data_sample: np.ndarray, - contrast_range: Union[Tuple[float, float], Callable[[], float]] = (0.75, 1.25), - preserve_range: bool = True, - per_channel: bool = True, - p_per_channel: float = 1) -> np.ndarray: - if not per_channel: + factor = reverse_broadcast(np.array(factor), get_broadcast_axes(broadcast_size)) + else: if callable(contrast_range): factor = contrast_range() + elif contrast_range[0] < 1 and np.random.random() < 0.5: + factor = np.random.uniform(contrast_range[0], 1) else: - if np.random.random() < 0.5 and contrast_range[0] < 1: - factor = np.random.uniform(contrast_range[0], 1) - else: - factor = np.random.uniform(max(contrast_range[0], 1), contrast_range[1]) + factor = np.random.uniform(max(contrast_range[0], 1), contrast_range[1]) - for c in range(data_sample.shape[0]): - if np.random.uniform() < p_per_channel: - mn = data_sample[c].mean() - if preserve_range: - minm = data_sample[c].min() - maxm = data_sample[c].max() + return factor - data_sample[c] = (data_sample[c] - mn) * factor + mn - if preserve_range: - data_sample[c][data_sample[c] < minm] = minm - data_sample[c][data_sample[c] > maxm] = maxm - else: - for c in range(data_sample.shape[0]): - if np.random.uniform() < p_per_channel: - if callable(contrast_range): - factor = contrast_range() - else: - if np.random.random() < 0.5 and contrast_range[0] < 1: - factor = np.random.uniform(contrast_range[0], 1) - else: - factor = np.random.uniform(max(contrast_range[0], 1), contrast_range[1]) - - mn = data_sample[c].mean() - if preserve_range: - minm = data_sample[c].min() - maxm = data_sample[c].max() - - data_sample[c] = (data_sample[c] - mn) * factor + mn - - if preserve_range: - data_sample[c][data_sample[c] < minm] = minm - data_sample[c][data_sample[c] > maxm] = maxm +def augment_contrast(data_sample: np.ndarray, + contrast_range: Union[Tuple[float, float], Callable[[], float]] = (0.75, 1.25), + preserve_range: bool = True, + per_channel: bool = True, + p_per_channel: float = 1, + batched=False) -> np.ndarray: + mask = np.random.uniform(size=data_sample.shape[:2] if batched else data_sample.shape[0]) < p_per_channel + if np.any(mask): + workon = data_sample[mask] + factor = get_augment_contrast_factor(contrast_range, per_channel, len(workon), workon.ndim) + axes = 
tuple(range(1, workon.ndim))
+        mean = workon.mean(axis=axes, keepdims=True)
+        if preserve_range:
+            minm = workon.min(axis=axes, keepdims=True)
+            maxm = workon.max(axis=axes, keepdims=True)
+
+        workon = workon * factor + mean * (1 - factor)
+        if preserve_range:
+            # clip before writing back: np.clip(..., out=data_sample[mask]) would write into
+            # the temporary copy created by the boolean index, not into data_sample itself
+            workon = np.clip(workon, minm, maxm)
+        data_sample[mask] = workon  # fancy-index assignment writes back into data_sample
+    return data_sample
 
 
-def augment_brightness_additive(data_sample, mu:float, sigma:float , per_channel:bool=True, p_per_channel:float=1.):
+def augment_brightness_additive(data_sample, mu: float, sigma: float, per_channel: bool = True,
+                                p_per_channel: float = 1.):
     """
     data_sample must have shape (c, x, y(, z))
     :param data_sample:
@@ -80,27 +82,29 @@ def augment_brightness_additive(data_sample, mu:float, sigma:float , per_channel
     :param p_per_channel:
     :return:
     """
-    if not per_channel:
-        rnd_nb = np.random.normal(mu, sigma)
-        for c in range(data_sample.shape[0]):
-            if np.random.uniform() <= p_per_channel:
-                data_sample[c] += rnd_nb
+    size = data_sample.shape[0]
+    if per_channel:
+        rnd_nb = np.random.normal(mu, sigma, size=size)
     else:
-        for c in range(data_sample.shape[0]):
-            if np.random.uniform() <= p_per_channel:
-                rnd_nb = np.random.normal(mu, sigma)
-                data_sample[c] += rnd_nb
+        rnd_nb = np.repeat(np.random.normal(mu, sigma), size)
+    rnd_nb[np.random.uniform(size=size) > p_per_channel] = 0.0
+    data_sample += reverse_broadcast(rnd_nb, get_broadcast_axes(data_sample.ndim))
     return data_sample
 
 
-def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), per_channel=True):
-    multiplier = np.random.uniform(multiplier_range[0], multiplier_range[1])
-    if not per_channel:
-        data_sample *= multiplier
-    else:
-        for c in range(data_sample.shape[0]):
-            multiplier = np.random.uniform(multiplier_range[0], multiplier_range[1])
-            data_sample[c] *= multiplier
+def setup_augment_brightness_multiplicative(per_channel: bool, batched: bool, shape: Tuple[int, ...]):
+    if per_channel:
+        if batched:
+            return shape[:2] + (1,) * (len(shape) - 2)
+        return (shape[0],) + (1,) * (len(shape) - 1)
+    if batched:
+        return (shape[0],) + (1,) * (len(shape) - 1)
+    return (1,) * len(shape)
+
+
+def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), per_channel=True, batched=False):
+    size = setup_augment_brightness_multiplicative(per_channel, batched, data_sample.shape)
+    data_sample *= np.random.uniform(multiplier_range[0], multiplier_range[1], size=size)
     return data_sample
 
 
@@ -110,38 +114,55 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon
         data_sample = - data_sample
 
     if not per_channel:
-        retain_stats_here = retain_stats() if callable(retain_stats) else retain_stats
-        if retain_stats_here:
+        retain_stats = retain_stats() if callable(retain_stats) else retain_stats
+        if retain_stats:
            mn = data_sample.mean()
            sd = data_sample.std()
-        if np.random.random() < 0.5 and gamma_range[0] < 1:
+        if gamma_range[0] < 1 and np.random.random() < 0.5:
            gamma = np.random.uniform(gamma_range[0], 1)
        else:
            gamma = np.random.uniform(max(gamma_range[0], 1), gamma_range[1])
        minm = data_sample.min()
        rnge = data_sample.max() - minm
        data_sample = np.power(((data_sample - minm) / float(rnge + epsilon)), gamma) * rnge + minm
-        if retain_stats_here:
-            data_sample = data_sample - data_sample.mean()
-            data_sample = data_sample / (data_sample.std() + 1e-8) * sd
-            data_sample = data_sample + mn
+        if retain_stats:
+            data_sample -= data_sample.mean()
+            data_sample *= sd / (data_sample.std() + 1e-8)
+            data_sample += mn
     else:
-        for c in range(data_sample.shape[0]):
-            retain_stats_here = retain_stats() if callable(retain_stats) else retain_stats
-            if retain_stats_here:
-                mn = data_sample[c].mean()
-                sd = data_sample[c].std()
-            if np.random.random() < 0.5 and gamma_range[0] < 1:
-                gamma = np.random.uniform(gamma_range[0], 1)
+        shape_0 = data_sample.shape[0]
+        gamma = []
+        gamma_l = max(gamma_range[0], 1)
+        for i in range(shape_0):
+            if gamma_range[0] < 1 and np.random.random() < 0.5:
+                gamma.append(np.random.uniform(gamma_range[0], 1))
             else:
-                gamma = np.random.uniform(max(gamma_range[0], 1), gamma_range[1])
-            minm = data_sample[c].min()
-            rnge = data_sample[c].max() - minm
-            data_sample[c] = np.power(((data_sample[c] - minm) / float(rnge + epsilon)), gamma) * float(rnge + epsilon) + minm
-            if retain_stats_here:
-                data_sample[c] = data_sample[c] - data_sample[c].mean()
-                data_sample[c] = data_sample[c] / (data_sample[c].std() + 1e-8) * sd
-                data_sample[c] = data_sample[c] + mn
+                gamma.append(np.random.uniform(gamma_l, gamma_range[1]))
+        gamma = np.array(gamma)
+
+        axes = tuple(range(1, data_sample.ndim))
+
+        if callable(retain_stats):
+            retain_stats = [retain_stats() for _ in range(shape_0)]
+        else:
+            retain_stats = [retain_stats] * shape_0
+        retain_stats_here = any(retain_stats)
+        if retain_stats_here:
+            mn = data_sample[retain_stats].mean(axis=axes, keepdims=True)
+            sd = data_sample[retain_stats].std(axis=axes, keepdims=True)
+
+        minm = data_sample.min(axis=axes, keepdims=True)
+        rnge = data_sample.max(axis=axes, keepdims=True) - minm + epsilon
+
+        broadcast_axes = get_broadcast_axes(data_sample.ndim)
+        gamma = reverse_broadcast(gamma, broadcast_axes)
+        data_sample = np.power((data_sample - minm) / rnge, gamma) * rnge + minm
+
+        if retain_stats_here:
+            data_sample[retain_stats] -= data_sample[retain_stats].mean(axis=axes, keepdims=True)
+            data_sample[retain_stats] *= sd / (data_sample[retain_stats].std(axis=axes, keepdims=True) + 1e-8)
+            data_sample[retain_stats] += mn
+
     if invert_image:
         data_sample = - data_sample
     return data_sample
diff --git a/batchgenerators/augmentations/crop_and_pad_augmentations.py b/batchgenerators/augmentations/crop_and_pad_augmentations.py
index 88a8cff..1ba4f0a 100644
--- a/batchgenerators/augmentations/crop_and_pad_augmentations.py
+++ b/batchgenerators/augmentations/crop_and_pad_augmentations.py
@@ -13,9 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from builtins import range
 import numpy as np
 from batchgenerators.augmentations.utils import pad_nd_image
+from typing import Union, Sequence
 
 
 def center_crop(data, crop_size, seg=None):
@@ -30,28 +30,24 @@ def get_lbs_for_random_crop(crop_size, data_shape, margins):
     :param margins:
     :return:
     """
-    lbs = []
-    for i in range(len(data_shape) - 2):
-        if data_shape[i+2] - crop_size[i] - margins[i] > margins[i]:
-            lbs.append(np.random.randint(margins[i], data_shape[i+2] - crop_size[i] - margins[i]))
-        else:
-            lbs.append((data_shape[i+2] - crop_size[i]) // 2)
-    return lbs
+    new_shape = data_shape - crop_size
+    mask = new_shape > 2 * margins
+    new_shape[mask] = np.random.randint(margins[mask], new_shape[mask] - margins[mask])
+    new_shape[~mask] //= 2
+    return new_shape
 
 
 def get_lbs_for_center_crop(crop_size, data_shape):
     """
     :param crop_size:
-    :param data_shape: (b,c,x,y(,z)) must be the whole thing!
+    :param data_shape: the spatial dimensions only, i.e. (x, y(, z)) without the batch and channel axes!
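+    Example (illustrative only, not part of the library's docs): for
+    data_shape = np.array([100, 120]) and crop_size = np.array([64, 64])
+    the returned lower bounds are [18, 28].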
:return: """ - lbs = [] - for i in range(len(data_shape) - 2): - lbs.append((data_shape[i + 2] - crop_size[i]) // 2) - return lbs + return (data_shape - crop_size) // 2 -def crop(data, seg=None, crop_size=128, margins=(0, 0, 0), crop_type="center", +def crop(data: Union[Sequence[np.ndarray], np.ndarray], seg: Union[Sequence[np.ndarray], np.ndarray] = None, + crop_size=128, margins=(0, 0, 0), crop_type="center", pad_mode='constant', pad_kwargs={'constant_values': 0}, pad_mode_seg='constant', pad_kwargs_seg={'constant_values': 0}): """ @@ -69,44 +65,39 @@ def crop(data, seg=None, crop_size=128, margins=(0, 0, 0), crop_type="center", :param crop_type: random or center :return: """ - if not isinstance(data, (list, tuple, np.ndarray)): - raise TypeError("data has to be either a numpy array or a list") - - data_shape = tuple([len(data)] + list(data[0].shape)) + data_shape = (len(data),) + data[0].shape data_dtype = data[0].dtype dim = len(data_shape) - 2 if seg is not None: - seg_shape = tuple([len(seg)] + list(seg[0].shape)) + seg_shape = (len(seg),) + seg[0].shape seg_dtype = seg[0].dtype - if not isinstance(seg, (list, tuple, np.ndarray)): - raise TypeError("data has to be either a numpy array or a list") - - assert all([i == j for i, j in zip(seg_shape[2:], data_shape[2:])]), "data and seg must have the same spatial " \ - "dimensions. Data: %s, seg: %s" % \ - (str(data_shape), str(seg_shape)) + assert np.array_equal(seg_shape[2:], data_shape[2:]), "data and seg must have the same spatial dimensions. " \ + f"Data: {data_shape}, seg: {seg_shape}" if type(crop_size) not in (tuple, list, np.ndarray): - crop_size = [crop_size] * dim + crop_size = (crop_size,) * dim else: - assert len(crop_size) == len( - data_shape) - 2, "If you provide a list/tuple as center crop make sure it has the same dimension as your " \ - "data (2d/3d)" + assert len(crop_size) == dim, ("If you provide a list/tuple as center crop make sure it has the same dimension " + "as your data (2d/3d)") + crop_size = np.asarray(crop_size) if not isinstance(margins, (np.ndarray, tuple, list)): - margins = [margins] * dim + margins = (margins,) * dim + margins = np.asarray(margins) - data_return = np.zeros([data_shape[0], data_shape[1]] + list(crop_size), dtype=data_dtype) + data_return = np.zeros((data_shape[0], data_shape[1], *crop_size), dtype=data_dtype) if seg is not None: - seg_return = np.zeros([seg_shape[0], seg_shape[1]] + list(crop_size), dtype=seg_dtype) + seg_return = np.zeros((seg_shape[0], seg_shape[1], *crop_size), dtype=seg_dtype) else: seg_return = None for b in range(data_shape[0]): - data_shape_here = [data_shape[0]] + list(data[b].shape) + data_first_dim = data[b].shape[0] + data_shape_here = np.array(data[b].shape[1:]) if seg is not None: - seg_shape_here = [seg_shape[0]] + list(seg[b].shape) + seg_first_dim = seg[b].shape[0] if crop_type == "center": lbs = get_lbs_for_center_crop(crop_size, data_shape_here) @@ -115,22 +106,25 @@ def crop(data, seg=None, crop_size=128, margins=(0, 0, 0), crop_type="center", else: raise NotImplementedError("crop_type must be either center or random") - need_to_pad = [[0, 0]] + [[abs(min(0, lbs[d])), - abs(min(0, data_shape_here[d + 2] - (lbs[d] + crop_size[d])))] - for d in range(dim)] + zero = np.zeros(dim, dtype=int) + temp1 = np.abs(np.minimum(lbs, zero)) + lbs_plus_crop_size = lbs + crop_size + temp2 = np.abs(np.minimum(zero, data_shape_here - lbs_plus_crop_size)) + need_to_pad = ((0, 0),) + tuple(zip(temp1, temp2)) + need_to_pad = np.array(need_to_pad) # we should crop first, 
then pad -> reduces i/o for memmaps, reduces RAM usage and improves speed - ubs = [min(lbs[d] + crop_size[d], data_shape_here[d+2]) for d in range(dim)] - lbs = [max(0, lbs[d]) for d in range(dim)] + ubs = np.minimum(data_shape_here, lbs_plus_crop_size) + lbs = np.maximum(zero, lbs) - slicer_data = [slice(0, data_shape_here[1])] + [slice(lbs[d], ubs[d]) for d in range(dim)] - data_cropped = data[b][tuple(slicer_data)] + slicer_data = (slice(0, data_first_dim), *[slice(lbs[d], ubs[d]) for d in range(dim)]) + data_cropped = data[b][slicer_data] if seg_return is not None: - slicer_seg = [slice(0, seg_shape_here[1])] + [slice(lbs[d], ubs[d]) for d in range(dim)] - seg_cropped = seg[b][tuple(slicer_seg)] + slicer_data = (slice(0, seg_first_dim),) + slicer_data[1:] + seg_cropped = seg[b][slicer_data] - if any([i > 0 for j in need_to_pad for i in j]): + if np.any(need_to_pad): data_return[b] = np.pad(data_cropped, need_to_pad, pad_mode, **pad_kwargs) if seg_return is not None: seg_return[b] = np.pad(seg_cropped, need_to_pad, pad_mode_seg, **pad_kwargs_seg) diff --git a/batchgenerators/augmentations/noise_augmentations.py b/batchgenerators/augmentations/noise_augmentations.py index c97e395..add8b1f 100644 --- a/batchgenerators/augmentations/noise_augmentations.py +++ b/batchgenerators/augmentations/noise_augmentations.py @@ -17,8 +17,7 @@ from typing import Tuple import numpy as np -from batchgenerators.augmentations.utils import get_range_val, mask_random_squares -from builtins import range +from batchgenerators.augmentations.utils import mask_random_squares, uniform from scipy.ndimage import gaussian_filter @@ -30,43 +29,55 @@ def augment_rician_noise(data_sample, noise_variance=(0, 0.1)): return data_sample -def augment_gaussian_noise(data_sample: np.ndarray, noise_variance: Tuple[float, float] = (0, 0.1), - p_per_channel: float = 1, per_channel: bool = False) -> np.ndarray: +def setup_augment_gaussian_noise(noise_variance: Tuple[float, float], per_channel: bool, size: int): if not per_channel: variance = noise_variance[0] if noise_variance[0] == noise_variance[1] else \ random.uniform(noise_variance[0], noise_variance[1]) + variance = np.array((variance,) * size) else: - variance = None - for c in range(data_sample.shape[0]): - if np.random.uniform() < p_per_channel: - # lol good luck reading this - variance_here = variance if variance is not None else \ - noise_variance[0] if noise_variance[0] == noise_variance[1] else \ - random.uniform(noise_variance[0], noise_variance[1]) - # bug fixed: https://github.com/MIC-DKFZ/batchgenerators/issues/86 - data_sample[c] = data_sample[c] + np.random.normal(0.0, variance_here, size=data_sample[c].shape) + variance = np.array((noise_variance[0],) * size) if noise_variance[0] == noise_variance[1] else \ + np.random.uniform(noise_variance[0], noise_variance[1], size=size) + return variance + + +def augment_gaussian_noise(data_sample: np.ndarray, noise_variance: Tuple[float, float] = (0, 0.1), + p_per_channel: float = 1, per_channel: bool = False, batched: bool = False) -> np.ndarray: + mask = np.random.uniform(size=data_sample.shape[:2] if batched else data_sample.shape[0]) < p_per_channel + size = np.count_nonzero(mask) + if size: + variance = setup_augment_gaussian_noise(noise_variance, per_channel, size) + data_sample[mask] += np.random.normal(0.0, variance, data_sample[mask].T.shape).T + return data_sample def augment_gaussian_blur(data_sample: np.ndarray, sigma_range: Tuple[float, float], per_channel: bool = True, p_per_channel: float = 1, 
                          different_sigma_per_axis: bool = False, p_isotropic: float = 0) -> np.ndarray:
+    # TODO: Vectorize per channel (gaussian_filter accepts axes)
     if not per_channel:
-        # Godzilla Had a Stroke Trying to Read This and F***ing Died
-        # https://i.kym-cdn.com/entries/icons/original/000/034/623/Untitled-3.png
-        sigma = get_range_val(sigma_range) if ((not different_sigma_per_axis) or
-                                               ((np.random.uniform() < p_isotropic) and
-                                                different_sigma_per_axis)) \
-            else [get_range_val(sigma_range) for _ in data_sample.shape[1:]]
+        if not different_sigma_per_axis or np.random.uniform() < p_isotropic:
+            sigma = uniform(sigma_range[0], sigma_range[1])
+        else:
+            sigma = [uniform(sigma_range[0], sigma_range[1]) for _ in data_sample.shape[1:]]
     else:
         sigma = None
     for c in range(data_sample.shape[0]):
         if np.random.uniform() <= p_per_channel:
             if per_channel:
-                sigma = get_range_val(sigma_range) if ((not different_sigma_per_axis) or
-                                                       ((np.random.uniform() < p_isotropic) and
-                                                        different_sigma_per_axis)) \
-                    else [get_range_val(sigma_range) for _ in data_sample.shape[1:]]
+                if not different_sigma_per_axis or np.random.uniform() < p_isotropic:
+                    sigma = uniform(sigma_range[0], sigma_range[1])
+                else:
+                    sigma = [uniform(sigma_range[0], sigma_range[1]) for _ in data_sample.shape[1:]]
+
             data_sample[c] = gaussian_filter(data_sample[c], sigma, order=0)
     return data_sample
 
 
@@ -74,8 +85,8 @@ def augment_gaussian_blur(data_sample: np.ndarray, sigma_range: Tuple[float, flo
 def augment_blank_square_noise(data_sample, square_size, n_squares, noise_val=(0, 0), channel_wise_n_val=False,
                                square_pos=None):
     # rnd_n_val = get_range_val(noise_val)
-    rnd_square_size = get_range_val(square_size)
-    rnd_n_squares = get_range_val(n_squares)
+    # int casts: the square size is used in slice arithmetic and the count in range()
+    rnd_square_size = int(uniform(square_size[0], square_size[1]))
+    rnd_n_squares = int(uniform(n_squares[0], n_squares[1]))
 
     data_sample = mask_random_squares(data_sample, square_size=rnd_square_size, n_squares=rnd_n_squares,
                                       n_val=noise_val, channel_wise_n_val=channel_wise_n_val,
diff --git a/batchgenerators/augmentations/normalizations.py b/batchgenerators/augmentations/normalizations.py
index 20a6d65..2f56a99 100644
--- a/batchgenerators/augmentations/normalizations.py
+++ b/batchgenerators/augmentations/normalizations.py
@@ -17,81 +17,70 @@
 
 
 def range_normalization(data, rnge=(0, 1), per_channel=True, eps=1e-8):
-    data_normalized = np.zeros(data.shape, dtype=data.dtype)
-    for b in range(data.shape[0]):
-        if per_channel:
-            for c in range(data.shape[1]):
-                data_normalized[b, c] = min_max_normalization(data[b, c], eps)
-        else:
-            data_normalized[b] = min_max_normalization(data[b], eps)
+    if per_channel:
+        axes = tuple(range(2, data.ndim))
+    else:
+        axes = tuple(range(1, data.ndim))
+    data_normalized = min_max_normalization_batched(data, eps, axes)
 
     data_normalized *= (rnge[1] - rnge[0])
     data_normalized += rnge[0]
     return data_normalized
 
 
+def min_max_normalization_batched(data, eps, axes):
+    mn = data.min(axis=axes, keepdims=True)
+    mx = data.max(axis=axes, keepdims=True)
+    old_range = mx - mn + eps
+
+    data_normalized = (data - mn) / old_range
+    return data_normalized
+
+
 def min_max_normalization(data, eps):
     mn = data.min()
     mx = data.max()
-    data_normalized = data - mn
     old_range = mx - mn + eps
-    data_normalized /= old_range
-
+    data_normalized = (data - mn) / old_range
     return
data_normalized + def zero_mean_unit_variance_normalization(data, per_channel=True, epsilon=1e-8): - data_normalized = np.zeros(data.shape, dtype=data.dtype) - for b in range(data.shape[0]): - if per_channel: - for c in range(data.shape[1]): - mean = data[b, c].mean() - std = data[b, c].std() + epsilon - data_normalized[b, c] = (data[b, c] - mean) / std - else: - mean = data[b].mean() - std = data[b].std() + epsilon - data_normalized[b] = (data[b] - mean) / std + if per_channel: + axes = tuple(range(2, data.ndim)) + else: + axes = tuple(range(1, data.ndim)) + + mean = np.mean(data, axis=axes, keepdims=True) + std = np.std(data, axis=axes, keepdims=True) + epsilon + data_normalized = (data - mean) / std return data_normalized def mean_std_normalization(data, mean, std, per_channel=True): - data_normalized = np.zeros(data.shape, dtype=data.dtype) - if isinstance(data, np.ndarray): - data_shape = tuple(list(data.shape)) - elif isinstance(data, (list, tuple)): - assert len(data) > 0 and isinstance(data[0], np.ndarray) - data_shape = [len(data)] + list(data[0].shape) - else: - raise TypeError("Data has to be either a numpy array or a list") - - if per_channel and isinstance(mean, float) and isinstance(std, float): - mean = [mean] * data_shape[1] - std = [std] * data_shape[1] - elif per_channel and isinstance(mean, (tuple, list, np.ndarray)): - assert len(mean) == data_shape[1] - elif per_channel and isinstance(std, (tuple, list, np.ndarray)): - assert len(std) == data_shape[1] - - for b in range(data_shape[0]): - if per_channel: - for c in range(data_shape[1]): - data_normalized[b][c] = (data[b][c] - mean[c]) / std[c] + if per_channel: + channel_dimension = data[0].shape[0] + if isinstance(mean, float) and isinstance(std, float): + mean = (mean,) * channel_dimension + std = (std,) * channel_dimension else: - data_normalized[b] = (data[b] - mean) / std + assert len(mean) == channel_dimension + assert len(std) == channel_dimension + + broadcast_axes = tuple(range(2, data.ndim)) + mean = np.expand_dims(np.broadcast_to(mean, (len(data), len(mean))), broadcast_axes) + std = np.expand_dims(np.broadcast_to(std, (len(data), len(std))), broadcast_axes) + + data_normalized = (data - mean) / std return data_normalized def cut_off_outliers(data, percentile_lower=0.2, percentile_upper=99.8, per_channel=False): - for b in range(len(data)): - if not per_channel: - cut_off_lower = np.percentile(data[b], percentile_lower) - cut_off_upper = np.percentile(data[b], percentile_upper) - data[b][data[b] < cut_off_lower] = cut_off_lower - data[b][data[b] > cut_off_upper] = cut_off_upper - else: - for c in range(data.shape[1]): - cut_off_lower = np.percentile(data[b, c], percentile_lower) - cut_off_upper = np.percentile(data[b, c], percentile_upper) - data[b, c][data[b, c] < cut_off_lower] = cut_off_lower - data[b, c][data[b, c] > cut_off_upper] = cut_off_upper + if per_channel: + axes = tuple(range(2, data.ndim)) + else: + axes = tuple(range(1, data.ndim)) + + cut_off_lower, cut_off_upper = np.percentile(data, (percentile_lower, percentile_upper), axis=axes, keepdims=True) + np.clip(data, cut_off_lower, cut_off_upper, out=data) return data diff --git a/batchgenerators/augmentations/resample_augmentations.py b/batchgenerators/augmentations/resample_augmentations.py index 4b0d8a0..955e43d 100644 --- a/batchgenerators/augmentations/resample_augmentations.py +++ b/batchgenerators/augmentations/resample_augmentations.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations 
under the License. -from builtins import range import numpy as np -import random from skimage.transform import resize from batchgenerators.augmentations.utils import uniform @@ -50,9 +48,6 @@ def augment_linear_downsampling_scipy(data_sample, zoom_range=(0.5, 1), per_chan ignore_axes: tuple/list ''' - if not isinstance(zoom_range, (list, tuple, np.ndarray)): - zoom_range = [zoom_range] - shp = np.array(data_sample.shape[1:]) dim = len(shp) @@ -63,14 +58,14 @@ def augment_linear_downsampling_scipy(data_sample, zoom_range=(0.5, 1), per_chan else: zoom = uniform(zoom_range[0], zoom_range[1]) - target_shape = np.round(shp * zoom).astype(int) + target_shape = np.rint(shp * zoom).astype(int) if ignore_axes is not None: for i in ignore_axes: target_shape[i] = shp[i] if channels is None: - channels = list(range(data_sample.shape[0])) + channels = range(data_sample.shape[0]) for c in channels: if np.random.uniform() < p_per_channel: @@ -81,15 +76,14 @@ def augment_linear_downsampling_scipy(data_sample, zoom_range=(0.5, 1), per_chan else: zoom = uniform(zoom_range[0], zoom_range[1]) - target_shape = np.round(shp * zoom).astype(int) + target_shape = np.rint(shp * zoom).astype(int) if ignore_axes is not None: for i in ignore_axes: target_shape[i] = shp[i] - downsampled = resize(data_sample[c].astype(float), target_shape, order=order_downsample, mode='edge', - anti_aliasing=False) - data_sample[c] = resize(downsampled, shp, order=order_upsample, mode='edge', - anti_aliasing=False) + downsampled = resize(data_sample[c].astype(float, copy=False), target_shape, order=order_downsample, + mode='edge', anti_aliasing=False) + data_sample[c] = resize(downsampled, shp, order=order_upsample, mode='edge', anti_aliasing=False) return data_sample diff --git a/batchgenerators/augmentations/spatial_transformations.py b/batchgenerators/augmentations/spatial_transformations.py index 7902f55..be97738 100644 --- a/batchgenerators/augmentations/spatial_transformations.py +++ b/batchgenerators/augmentations/spatial_transformations.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from builtins import range - import numpy as np from scipy.ndimage import map_coordinates from batchgenerators.augmentations.utils import create_zero_centered_coordinate_mesh, elastic_deform_coordinates, \ @@ -38,7 +36,7 @@ def augment_rot90(sample_data, sample_seg, num_rot=(1, 2, 3), axes=(0, 1, 2)): num_rot = np.random.choice(num_rot) axes = np.random.choice(axes, size=2, replace=False) axes.sort() - axes = [i + 1 for i in axes] + axes += 1 sample_data = np.rot90(sample_data, num_rot, axes) if sample_seg is not None: sample_seg = np.rot90(sample_seg, num_rot, axes) @@ -58,18 +56,18 @@ def augment_resize(sample_data, sample_seg, target_size, order=3, order_seg=1): np.ndarray (just like data). Must also be (c, x, y(, z)) :return: """ - dimensionality = len(sample_data.shape) - 1 + dimensionality = sample_data.ndim - 1 if not isinstance(target_size, (list, tuple)): - target_size_here = [target_size] * dimensionality + target_size_here = (target_size,) * dimensionality else: assert len(target_size) == dimensionality, "If you give a tuple/list as target size, make sure it has " \ "the same dimensionality as data!" 
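+        # a tuple rather than a list, so it concatenates cleanly with the shape tuples used below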
- target_size_here = list(target_size) + target_size_here = tuple(target_size) sample_data = resize_multichannel_image(sample_data, target_size_here, order) if sample_seg is not None: - target_seg = np.ones([sample_seg.shape[0]] + target_size_here) + target_seg = np.ones((sample_seg.shape[0],) + target_size_here) for c in range(sample_seg.shape[0]): target_seg[c] = resize_segmentation(sample_seg[c], target_size_here, order_seg) else: @@ -92,7 +90,7 @@ def augment_zoom(sample_data, sample_seg, zoom_factors, order=3, order_seg=1): :return: """ - dimensionality = len(sample_data.shape) - 1 + dimensionality = sample_data.ndim - 1 shape = np.array(sample_data.shape[1:]) if not isinstance(zoom_factors, (list, tuple)): zoom_factors_here = np.array([zoom_factors] * dimensionality) @@ -100,25 +98,48 @@ def augment_zoom(sample_data, sample_seg, zoom_factors, order=3, order_seg=1): assert len(zoom_factors) == dimensionality, "If you give a tuple/list as target size, make sure it has " \ "the same dimensionality as data!" zoom_factors_here = np.array(zoom_factors) - target_shape_here = list(np.round(shape * zoom_factors_here).astype(int)) + target_shape_here = tuple(np.rint(shape * zoom_factors_here).astype(int, copy=False)) sample_data = resize_multichannel_image(sample_data, target_shape_here, order) if sample_seg is not None: - target_seg = np.ones([sample_seg.shape[0]] + target_shape_here) - for c in range(sample_seg.shape[0]): - target_seg[c] = resize_segmentation(sample_seg[c], target_shape_here, order_seg) + target_seg = np.array([ + resize_segmentation(sample_seg[c], target_shape_here, order_seg) for c in range(sample_seg.shape[0])]) else: target_seg = None return sample_data, target_seg +def augment_mirroring_batched(sample_data, sample_seg=None, axes=(0, 1, 2)): + assert sample_data.ndim == 5 or sample_data.ndim == 4, \ + "Invalid dimension for sample_data and sample_seg. sample_data and sample_seg should be either " \ + "[batch, channels, x, y] or [batch, channels, x, y, z]" + size = len(sample_data) + has_sample_seg = sample_seg is not None + if 0 in axes: + mask = np.random.uniform(size=size) < 0.5 + sample_data[mask] = np.flip(sample_data[mask], 2) + if has_sample_seg: + sample_seg[mask] = np.flip(sample_seg[mask], 2) + if 1 in axes: + mask = np.random.uniform(size=size) < 0.5 + sample_data[mask] = np.flip(sample_data[mask], 3) + if has_sample_seg: + sample_seg[mask] = np.flip(sample_seg[mask], 3) + if 2 in axes and sample_data.ndim == 5: + mask = np.random.uniform(size=size) < 0.5 + sample_data[mask] = np.flip(sample_data[mask], 4) + if has_sample_seg: + sample_seg[mask] = np.flip(sample_seg[mask], 4) + return sample_data, sample_seg + + def augment_mirroring(sample_data, sample_seg=None, axes=(0, 1, 2)): - if (len(sample_data.shape) != 3) and (len(sample_data.shape) != 4): - raise Exception( - "Invalid dimension for sample_data and sample_seg. 
sample_data and sample_seg should be either "
-            "[channels, x, y] or [channels, x, y, z]")
+    sample_data = np.expand_dims(sample_data, 0)
+    if sample_seg is not None:
+        sample_seg = np.expand_dims(sample_seg, 0)
+    sample_data, sample_seg = augment_mirroring_batched(sample_data, sample_seg, axes)
+    # strip the batch axis again so the (c, x, y(, z)) contract of this function is kept
+    return sample_data[0], (sample_seg[0] if sample_seg is not None else None)
     if 0 in axes and np.random.uniform() < 0.5:
         sample_data[:, :] = sample_data[:, ::-1]
         if sample_seg is not None:
@@ -127,7 +148,7 @@
             sample_data[:, :, :] = sample_data[:, :, ::-1]
             if sample_seg is not None:
                 sample_seg[:, :, :] = sample_seg[:, :, ::-1]
-    if 2 in axes and len(sample_data.shape) == 4:
+    if 2 in axes and sample_data.ndim == 4:
         if np.random.uniform() < 0.5:
             sample_data[:, :, :, :] = sample_data[:, :, :, ::-1]
             if sample_seg is not None:
@@ -196,20 +217,13 @@ def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30,
     dim = len(patch_size)
     seg_result = None
     if seg is not None:
-        if dim == 2:
-            seg_result = np.zeros((seg.shape[0], seg.shape[1], patch_size[0], patch_size[1]), dtype=np.float32)
-        else:
-            seg_result = np.zeros((seg.shape[0], seg.shape[1], patch_size[0], patch_size[1], patch_size[2]),
-                                  dtype=np.float32)
+        seg_result = np.zeros((seg.shape[0], seg.shape[1], *patch_size), dtype=np.float32)
 
-    if dim == 2:
-        data_result = np.zeros((data.shape[0], data.shape[1], patch_size[0], patch_size[1]), dtype=np.float32)
-    else:
-        data_result = np.zeros((data.shape[0], data.shape[1], patch_size[0], patch_size[1], patch_size[2]),
-                               dtype=np.float32)
+    data_result = np.zeros((data.shape[0], data.shape[1], *patch_size), dtype=np.float32)
 
     if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
-        patch_center_dist_from_border = dim * [patch_center_dist_from_border]
+        patch_center_dist_from_border = (patch_center_dist_from_border,) * dim
+    patch_center_dist_from_border = np.asarray(patch_center_dist_from_border)
 
     for sample_id in range(data.shape[0]):
         coords = create_zero_centered_coordinate_mesh(patch_size)
@@ -247,13 +261,14 @@
         if do_scale and np.random.uniform() < p_scale_per_sample:
             if independent_scale_for_each_axis and np.random.uniform() < p_independent_scale_per_axis:
                 sc = []
+                scale_l = max(scale[0], 1)
                 for _ in range(dim):
-                    if np.random.random() < 0.5 and scale[0] < 1:
+                    if scale[0] < 1 and np.random.random() < 0.5:
                         sc.append(np.random.uniform(scale[0], 1))
                     else:
-                        sc.append(np.random.uniform(max(scale[0], 1), scale[1]))
+                        sc.append(np.random.uniform(scale_l, scale[1]))
             else:
-                if np.random.random() < 0.5 and scale[0] < 1:
+                if scale[0] < 1 and np.random.random() < 0.5:
                     sc = np.random.uniform(scale[0], 1)
                 else:
                     sc = np.random.uniform(max(scale[0], 1), scale[1])
@@ -263,13 +278,17 @@
 
         # now find a nice center location
         if modified_coords:
+            data_shape_here = np.array(data.shape[2:])
+            if random_crop:
+                ctr = np.random.uniform(patch_center_dist_from_border, data_shape_here - patch_center_dist_from_border)
+            else:
+                ctr = data_shape_here / 2. - 0.5
+
             for d in range(dim):
-                if random_crop:
-                    ctr = np.random.uniform(patch_center_dist_from_border[d],
-                                            data.shape[d + 2] - patch_center_dist_from_border[d])
-                else:
-                    ctr = data.shape[d + 2] / 2.
- 0.5 - coords[d] += ctr + coords[d] += ctr[d] + # vectorized version, seems a bit slower + # coords += reverse_broadcast(ctr, get_broadcast_axes(coords.ndim)) + for channel_id in range(data.shape[1]): data_result[sample_id, channel_id] = interpolate_img(data[sample_id, channel_id], coords, order_data, border_mode_data, cval=border_cval_data) @@ -284,7 +303,7 @@ def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30, else: s = seg[sample_id:sample_id + 1] if random_crop: - margin = [patch_center_dist_from_border[d] - patch_size[d] // 2 for d in range(dim)] + margin = patch_center_dist_from_border - np.asarray(patch_size) // 2 d, s = random_crop_aug(data[sample_id:sample_id + 1], s, patch_size, margin) else: d, s = center_crop_aug(data[sample_id:sample_id + 1], patch_size, s) @@ -362,7 +381,7 @@ def augment_spatial_2(data, seg, patch_size, patch_center_dist_from_border=30, # one scale per case, scale is in percent of patch_size def_scale = np.random.uniform(deformation_scale[0], deformation_scale[1]) - for d in range(len(data[sample_id].shape) - 1): + for d in range(data[sample_id].ndim - 1): # transform relative def_scale in pixels sigmas.append(def_scale * patch_size[d]) @@ -427,7 +446,7 @@ def augment_spatial_2(data, seg, patch_size, patch_center_dist_from_border=30, # now find a nice center location if modified_coords: # recenter coordinates - coords_mean = coords.mean(axis=tuple(range(1, len(coords.shape))), keepdims=True) + coords_mean = coords.mean(axis=tuple(range(1, coords.ndim)), keepdims=True) coords -= coords_mean for d in range(dim): @@ -471,8 +490,8 @@ def augment_transpose_axes(data_sample, seg_sample, axes=(0, 1, 2)): """ axes = list(np.array(axes) + 1) # need list to allow shuffle; +1 to accomodate for color channel - assert np.max(axes) <= len(data_sample.shape), "axes must only contain valid axis ids" - static_axes = list(range(len(data_sample.shape))) + assert np.max(axes) <= data_sample.ndim, "axes must only contain valid axis ids" + static_axes = list(range(data_sample.ndim)) for i in axes: static_axes[i] = -1 np.random.shuffle(axes) @@ -502,7 +521,7 @@ def augment_anatomy_informed(data, seg, t, u, v = get_organ_gradient_field(seg == organ_idx + 2, spacing_ratio=spacing_ratio, blur=blur) - + # TODO: if directions_of_trans[organ_idx][0]: coords[0, :, :, :] = coords[0, :, :, :] + t * dil_magnitude * spacing_ratio if directions_of_trans[organ_idx][1]: diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index f8aac77..b50f3e4 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -14,7 +14,11 @@ # limitations under the License. import random +from functools import lru_cache +from typing import Tuple + import numpy as np +import pandas as pd from copy import deepcopy from scipy.ndimage import map_coordinates, fourier_gaussian from scipy.ndimage.filters import gaussian_filter, gaussian_gradient_magnitude @@ -34,11 +38,31 @@ def generate_elastic_transform_coordinates(shape, alpha, sigma): return indices +def get_broadcast_axes(n: int) -> Tuple[int]: + """ + Args: + n: len(array.shape), where array is the array for which we want to broadcast to. + Returns: broadcast axes, (0, 1, ...) + """ + return tuple(range(n - 1)) + + +def reverse_broadcast(a: np.ndarray, axes: Tuple[int]) -> np.ndarray: + """ + Args: + a: array which we want to broadcast for batched operations + axes: (0, 1, ...) + Returns: array of shape (len(a), 1, 1, ...) 
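+    Example (illustrative only): reverse_broadcast(np.array([1., 2.]), (0, 1))
+    returns an array of shape (2, 1, 1) that multiplies cleanly into a
+    (2, x, y) array.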
+    """
+    return np.expand_dims(a, axes).T
+
+
+@lru_cache(maxsize=None)  # only a handful of distinct shapes will ever be seen; maxsize=None removes locking and checks
+def _create_zero_centered_coordinate_mesh_cached(shape):
+    coords = np.array(np.meshgrid(*(np.arange(i) for i in shape), indexing='ij'), dtype=float)
+    to_add = (np.array(shape, dtype=float) - 1) / 2.
+    for d in range(len(shape)):
+        coords[d] -= to_add[d]
+    return coords
+
+
 def create_zero_centered_coordinate_mesh(shape):
-    tmp = tuple([np.arange(i) for i in shape])
-    coords = np.array(np.meshgrid(*tmp, indexing='ij')).astype(float)
-    for d in range(len(shape)):
-        coords[d] -= ((np.array(shape).astype(float) - 1) / 2.)[d]
-    return coords
+    # tuple() keeps list inputs hashable for the cache; the .copy() is essential because callers
+    # (scale_coords, augment_spatial) modify the returned mesh in place, which would otherwise
+    # corrupt the cached array
+    return _create_zero_centered_coordinate_mesh_cached(tuple(shape)).copy()
 
 
@@ -47,10 +71,11 @@ def convert_seg_image_to_one_hot_encoding(image, classes=None):
     image must be either (x, y, z) or (x, y)
     Takes as input an nd array of a label map (any dimension). Outputs a one hot encoding of the label map.
     Example (3D): if input is of shape (x, y, z), the output will be of shape (n_classes, x, y, z)
+    Prefer convert_seg_image_to_one_hot_encoding_batched.
     '''
     if classes is None:
-        classes = np.unique(image)
-    out_image = np.zeros([len(classes)]+list(image.shape), dtype=image.dtype)
+        classes = np.sort(pd.unique(image.reshape(-1)))
+    out_image = np.zeros((len(classes), *image.shape), dtype=image.dtype)
     for i, c in enumerate(classes):
         out_image[i][image == c] = 1
     return out_image
@@ -61,19 +86,16 @@ def convert_seg_image_to_one_hot_encoding_batched(image, classes=None):
     '''
     same as convert_seg_image_to_one_hot_encoding, but expects image to be (b, x, y, z) or (b, x, y)
     '''
     if classes is None:
-        classes = np.unique(image)
-    output_shape = [image.shape[0]] + [len(classes)] + list(image.shape[1:])
-    out_image = np.zeros(output_shape, dtype=image.dtype)
-    for b in range(image.shape[0]):
-        for i, c in enumerate(classes):
-            out_image[b, i][image[b] == c] = 1
+        classes = np.sort(pd.unique(image.reshape(-1)))
+    out_image = np.zeros((image.shape[0], len(classes), *image.shape[1:]), dtype=image.dtype)
+    for i, c in enumerate(classes):
+        out_image[:, i][image == c] = 1
     return out_image
 
 
 def elastic_deform_coordinates(coordinates, alpha, sigma):
-    n_dim = len(coordinates)
     offsets = []
-    for _ in range(n_dim):
+    for _ in range(len(coordinates)):
         offsets.append(
             gaussian_filter((np.random.random(coordinates.shape[1:]) * 2 - 1), sigma, mode="constant", cval=0) * alpha)
     offsets = np.array(offsets)
@@ -100,9 +122,10 @@
         random_values_ = np.fft.fftn(random_values)
         deformation_field = fourier_gaussian(random_values_, sigmas)
         deformation_field = np.fft.ifftn(deformation_field).real
+        mx = np.max(np.abs(deformation_field))
+        deformation_field *= (magnitudes[d] + 1e-8) / mx
         offsets.append(deformation_field)
-        mx = np.max(np.abs(offsets[-1]))
-        offsets[-1] = offsets[-1] / (mx / (magnitudes[d] + 1e-8))
+    offsets = np.array(offsets)
     indices = offsets + coordinates
     return indices
@@ -123,34 +146,32 @@
     return coords
 
 
-def scale_coords(coords, scale):
+def scale_coords(coords: np.ndarray, scale):
     if isinstance(scale, (tuple, list, np.ndarray)):
-        assert len(scale) == len(coords)
-        for i in range(len(scale)):
-            coords[i] *= scale[i]
-    else:
-        coords *= scale
+        scale = reverse_broadcast(scale, get_broadcast_axes(coords.ndim))
+    coords *= scale
     return coords
 
 
 def uncenter_coords(coords):
-    shp = coords.shape[1:]
+    shp = (np.array(coords.shape[1:]) - 1) / 2.  # np.array: a bare tuple would not support the arithmetic
     coords = deepcopy(coords)
     for d in range(coords.shape[0]):
-        coords[d] += (shp[d] - 1) / 2.
+ coords[d] += shp[d] return coords def interpolate_img(img, coords, order=3, mode='nearest', cval=0.0, is_seg=False): if is_seg and order != 0: - unique_labels = np.unique(img) + unique_labels = pd.unique(img.reshape(-1)) # does not need sorting result = np.zeros(coords.shape[1:], img.dtype) - for i, c in enumerate(unique_labels): + for c in unique_labels: res_new = map_coordinates((img == c).astype(float), coords, order=order, mode=mode, cval=cval) result[res_new >= 0.5] = c return result else: - return map_coordinates(img.astype(float), coords, order=order, mode=mode, cval=cval).astype(img.dtype) + return map_coordinates( + img.astype(float, copy=False), coords, order=order, mode=mode, cval=cval).astype(img.dtype, copy=False) def generate_noise(shape, alpha, sigma): @@ -160,21 +181,20 @@ def generate_noise(shape, alpha, sigma): def find_entries_in_array(entries, myarray): - entries = np.array(entries) - values = np.arange(np.max(myarray) + 1) - lut = np.zeros(len(values), 'bool') - lut[entries.astype("int")] = True - return np.take(lut, myarray.astype(int)) + entries = np.array(entries, dtype=int) + lut = np.zeros(np.max(myarray) + 1, 'bool') + lut[entries] = True + return np.take(lut, myarray.astype(int, copy=False)) def center_crop_3D_image(img, crop_size): center = np.array(img.shape) / 2. if type(crop_size) not in (tuple, list): - center_crop = [int(crop_size)] * len(img.shape) + center_crop = [int(crop_size)] * img.ndim else: center_crop = crop_size - assert len(center_crop) == len( - img.shape), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" + assert len( + center_crop) == img.ndim, "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" return img[int(center[0] - center_crop[0] / 2.):int(center[0] + center_crop[0] / 2.), int(center[1] - center_crop[1] / 2.):int(center[1] + center_crop[1] / 2.), int(center[2] - center_crop[2] / 2.):int(center[2] + center_crop[2] / 2.)] @@ -184,11 +204,11 @@ def center_crop_3D_image_batched(img, crop_size): # dim 0 is batch, dim 1 is channel, dim 2, 3 and 4 are x y z center = np.array(img.shape[2:]) / 2. if type(crop_size) not in (tuple, list): - center_crop = [int(crop_size)] * (len(img.shape) - 2) + center_crop = [int(crop_size)] * (img.ndim - 2) else: center_crop = crop_size - assert len(center_crop) == (len( - img.shape) - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" + assert len(center_crop) == ( + img.ndim - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" return img[:, :, int(center[0] - center_crop[0] / 2.):int(center[0] + center_crop[0] / 2.), int(center[1] - center_crop[1] / 2.):int(center[1] + center_crop[1] / 2.), int(center[2] - center_crop[2] / 2.):int(center[2] + center_crop[2] / 2.)] @@ -197,11 +217,11 @@ def center_crop_3D_image_batched(img, crop_size): def center_crop_2D_image(img, crop_size): center = np.array(img.shape) / 2. 
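+    # note: the center may be fractional; the int() casts in the slice bounds below truncate it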
if type(crop_size) not in (tuple, list): - center_crop = [int(crop_size)] * len(img.shape) + center_crop = [int(crop_size)] * img.ndim else: center_crop = crop_size - assert len(center_crop) == len( - img.shape), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" + assert len( + center_crop) == img.ndim, "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" return img[int(center[0] - center_crop[0] / 2.):int(center[0] + center_crop[0] / 2.), int(center[1] - center_crop[1] / 2.):int(center[1] + center_crop[1] / 2.)] @@ -210,21 +230,21 @@ def center_crop_2D_image_batched(img, crop_size): # dim 0 is batch, dim 1 is channel, dim 2 and 3 are x y center = np.array(img.shape[2:]) / 2. if type(crop_size) not in (tuple, list): - center_crop = [int(crop_size)] * (len(img.shape) - 2) + center_crop = [int(crop_size)] * (img.ndim - 2) else: center_crop = crop_size - assert len(center_crop) == (len( - img.shape) - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" + assert len(center_crop) == ( + img.ndim - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" return img[:, :, int(center[0] - center_crop[0] / 2.):int(center[0] + center_crop[0] / 2.), int(center[1] - center_crop[1] / 2.):int(center[1] + center_crop[1] / 2.)] def random_crop_3D_image(img, crop_size): if type(crop_size) not in (tuple, list): - crop_size = [crop_size] * len(img.shape) + crop_size = [crop_size] * img.ndim else: - assert len(crop_size) == len( - img.shape), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" + assert len( + crop_size) == img.ndim, "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" if crop_size[0] < img.shape[0]: lb_x = np.random.randint(0, img.shape[0] - crop_size[0]) @@ -252,10 +272,10 @@ def random_crop_3D_image(img, crop_size): def random_crop_3D_image_batched(img, crop_size): if type(crop_size) not in (tuple, list): - crop_size = [crop_size] * (len(img.shape) - 2) + crop_size = [crop_size] * (img.ndim - 2) else: - assert len(crop_size) == (len( - img.shape) - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" + assert len(crop_size) == ( + img.ndim - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" if crop_size[0] < img.shape[2]: lb_x = np.random.randint(0, img.shape[2] - crop_size[0]) @@ -283,10 +303,10 @@ def random_crop_3D_image_batched(img, crop_size): def random_crop_2D_image(img, crop_size): if type(crop_size) not in (tuple, list): - crop_size = [crop_size] * len(img.shape) + crop_size = [crop_size] * img.ndim else: - assert len(crop_size) == len( - img.shape), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" + assert len( + crop_size) == img.ndim, "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" if crop_size[0] < img.shape[0]: lb_x = np.random.randint(0, img.shape[0] - crop_size[0]) @@ -307,10 +327,10 @@ def random_crop_2D_image(img, crop_size): def random_crop_2D_image_batched(img, crop_size): if type(crop_size) not in (tuple, list): - crop_size = [crop_size] * (len(img.shape) - 2) + crop_size = [crop_size] * (img.ndim - 2) else: - assert len(crop_size) == (len( - img.shape) - 2), 
"If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" + assert len(crop_size) == ( + img.ndim - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" if crop_size[0] < img.shape[2]: lb_x = np.random.randint(0, img.shape[2] - crop_size[0]) @@ -330,8 +350,8 @@ def random_crop_2D_image_batched(img, crop_size): def resize_image_by_padding(image, new_shape, pad_value=None): - shape = tuple(list(image.shape)) - new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2, len(shape))), axis=0)) + shape = image.shape + new_shape = np.maximum(shape, new_shape) if pad_value is None: if len(shape) == 2: pad_value = image[0, 0] @@ -339,8 +359,8 @@ def resize_image_by_padding(image, new_shape, pad_value=None): pad_value = image[0, 0, 0] else: raise ValueError("Image must be either 2 or 3 dimensional") - res = np.ones(list(new_shape), dtype=image.dtype) * pad_value - start = np.array(new_shape) / 2. - np.array(shape) / 2. + res = np.ones(new_shape, dtype=image.dtype) * pad_value + start = new_shape / 2. - np.array(shape) / 2. if len(shape) == 2: res[int(start[0]):int(start[0]) + int(shape[0]), int(start[1]):int(start[1]) + int(shape[1])] = image elif len(shape) == 3: @@ -350,8 +370,8 @@ def resize_image_by_padding(image, new_shape, pad_value=None): def resize_image_by_padding_batched(image, new_shape, pad_value=None): - shape = tuple(list(image.shape[2:])) - new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2, len(shape))), axis=0)) + shape = image.shape[1:] + new_shape = np.maximum(shape, new_shape) if pad_value is None: if len(shape) == 2: pad_value = image[0, 0] @@ -359,7 +379,7 @@ def resize_image_by_padding_batched(image, new_shape, pad_value=None): pad_value = image[0, 0, 0] else: raise ValueError("Image must be either 2 or 3 dimensional") - start = np.array(new_shape) / 2. - np.array(shape) / 2. + start = new_shape / 2. - np.array(shape) / 2. 
if len(shape) == 2: res = np.ones((image.shape[0], image.shape[1], new_shape[0], new_shape[1]), dtype=image.dtype) * pad_value res[:, :, int(start[0]):int(start[0]) + int(shape[0]), int(start[1]):int(start[1]) + int(shape[1])] = image[:, @@ -374,39 +394,38 @@ def resize_image_by_padding_batched(image, new_shape, pad_value=None): return res -def create_matrix_rotation_x_3d(angle, matrix=None): - rotation_x = np.array([[1, 0, 0], - [0, np.cos(angle), -np.sin(angle)], - [0, np.sin(angle), np.cos(angle)]]) - if matrix is None: - return rotation_x - +def create_matrix_rotation_x_3d(angle, matrix: np.ndarray): + cos_a = np.cos(angle) + sin_a = np.sin(angle) + rotation_x = np.array(((1, 0, 0), + (0, cos_a, -sin_a), + (0, sin_a, cos_a))) return np.dot(matrix, rotation_x) -def create_matrix_rotation_y_3d(angle, matrix=None): - rotation_y = np.array([[np.cos(angle), 0, np.sin(angle)], - [0, 1, 0], - [-np.sin(angle), 0, np.cos(angle)]]) - if matrix is None: - return rotation_y - +def create_matrix_rotation_y_3d(angle, matrix: np.ndarray): + cos_a = np.cos(angle) + sin_a = np.sin(angle) + rotation_y = np.array(((cos_a, 0, sin_a), + (0, 1, 0), + (-sin_a, 0, cos_a))) return np.dot(matrix, rotation_y) -def create_matrix_rotation_z_3d(angle, matrix=None): - rotation_z = np.array([[np.cos(angle), -np.sin(angle), 0], - [np.sin(angle), np.cos(angle), 0], - [0, 0, 1]]) - if matrix is None: - return rotation_z - +def create_matrix_rotation_z_3d(angle, matrix: np.ndarray): + cos_a = np.cos(angle) + sin_a = np.sin(angle) + rotation_z = np.array(((cos_a, -sin_a, 0), + (sin_a, cos_a, 0), + (0, 0, 1))) return np.dot(matrix, rotation_z) def create_matrix_rotation_2d(angle, matrix=None): - rotation = np.array([[np.cos(angle), -np.sin(angle)], - [np.sin(angle), np.cos(angle)]]) + cos_a = np.cos(angle) + sin_a = np.sin(angle) + rotation = np.array(((cos_a, -sin_a), + (sin_a, cos_a))) if matrix is None: return rotation @@ -476,111 +495,111 @@ def general_cc_var_num_channels(img, diff_order=0, mink_norm=1, sigma=1, mask_im for c in range(img_internal.shape[0]): white_colors.append(np.max(img_internal[c][mask_im != 1])) - som = np.sqrt(np.sum([i ** 2 for i in white_colors])) + white_colors = np.array(white_colors) + som = np.sqrt(np.sum(np.power(white_colors, 2))) - white_colors = [i / som for i in white_colors] + white_colors *= np.sqrt(3.) / som for c in range(output_img.shape[0]): - output_img[c] /= (white_colors[c] * np.sqrt(3.)) + output_img[c] /= white_colors[c] if clip_range: - output_img[output_img < minm] = minm - output_img[output_img > maxm] = maxm + np.clip(output_img, minm, maxm, out=output_img) + return white_colors, output_img -def convert_seg_to_bounding_box_coordinates(data_dict, dim, get_rois_from_seg_flag=False, class_specific_seg_flag=False): - - ''' - This function generates bounding box annotations from given pixel-wise annotations. - :param data_dict: Input data dictionary as returned by the batch generator. - :param dim: Dimension in which the model operates (2 or 3). - :param get_rois_from_seg: Flag specifying one of the following scenarios: - 1. A label map with individual ROIs identified by increasing label values, accompanied by a vector containing - in each position the class target for the lesion with the corresponding label (set flag to False) - 2. A binary label map. There is only one foreground class and single lesions are not identified. - All lesions have the same class target (foreground). 
In this case the Dataloader runs a Connected Component - Labelling algorithm to create processable lesion - class target pairs on the fly (set flag to True). - :param class_specific_seg_flag: if True, returns the pixelwise-annotations in class specific manner, - e.g. a multi-class label map. If False, returns a binary annotation map (only foreground vs. background). - :return: data_dict: same as input, with additional keys: - - 'bb_target': bounding box coordinates (b, n_boxes, (y1, x1, y2, x2, (z1), (z2))) - - 'roi_labels': corresponding class labels for each box (b, n_boxes, class_label) - - 'roi_masks': corresponding binary segmentation mask for each lesion (box). Only used in Mask RCNN. (b, n_boxes, y, x, (z)) - - 'seg': now label map (see class_specific_seg_flag) - ''' - - bb_target = [] - roi_masks = [] - roi_labels = [] - out_seg = np.copy(data_dict['seg']) - for b in range(data_dict['seg'].shape[0]): - - p_coords_list = [] - p_roi_masks_list = [] - p_roi_labels_list = [] - - if np.sum(data_dict['seg'][b]!=0) > 0: - if get_rois_from_seg_flag: - clusters, n_cands = lb(data_dict['seg'][b]) - data_dict['class_target'][b] = [data_dict['class_target'][b]] * n_cands - else: - n_cands = int(np.max(data_dict['seg'][b])) - clusters = data_dict['seg'][b] - - rois = np.array([(clusters == ii) * 1 for ii in range(1, n_cands + 1)]) # separate clusters and concat - for rix, r in enumerate(rois): - if np.sum(r !=0) > 0: #check if the lesion survived data augmentation - seg_ixs = np.argwhere(r != 0) - coord_list = [np.min(seg_ixs[:, 1])-1, np.min(seg_ixs[:, 2])-1, np.max(seg_ixs[:, 1])+1, - np.max(seg_ixs[:, 2])+1] - if dim == 3: - - coord_list.extend([np.min(seg_ixs[:, 3])-1, np.max(seg_ixs[:, 3])+1]) - - p_coords_list.append(coord_list) - p_roi_masks_list.append(r) - # add background class = 0. rix is a patient wide index of lesions. since 'class_target' is - # also patient wide, this assignment is not dependent on patch occurrances. - p_roi_labels_list.append(data_dict['class_target'][b][rix] + 1) - - if class_specific_seg_flag: - out_seg[b][data_dict['seg'][b] == rix + 1] = data_dict['class_target'][b][rix] + 1 - - if not class_specific_seg_flag: - out_seg[b][data_dict['seg'][b] > 0] = 1 - - bb_target.append(np.array(p_coords_list)) - roi_masks.append(np.array(p_roi_masks_list).astype('uint8')) - roi_labels.append(np.array(p_roi_labels_list)) +def convert_seg_to_bounding_box_coordinates(data_dict, dim, get_rois_from_seg_flag=False, + class_specific_seg_flag=False): + ''' + This function generates bounding box annotations from given pixel-wise annotations. + :param data_dict: Input data dictionary as returned by the batch generator. + :param dim: Dimension in which the model operates (2 or 3). + :param get_rois_from_seg: Flag specifying one of the following scenarios: + 1. A label map with individual ROIs identified by increasing label values, accompanied by a vector containing + in each position the class target for the lesion with the corresponding label (set flag to False) + 2. A binary label map. There is only one foreground class and single lesions are not identified. + All lesions have the same class target (foreground). In this case the Dataloader runs a Connected Component + Labelling algorithm to create processable lesion - class target pairs on the fly (set flag to True). + :param class_specific_seg_flag: if True, returns the pixelwise-annotations in class specific manner, + e.g. a multi-class label map. If False, returns a binary annotation map (only foreground vs. background). 
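+    Example (illustrative only): a single 2D lesion covering rows 10..20 and columns 30..40 of a
+    (c, y, x) patch yields bb_target [[9, 29, 21, 41]], i.e. (min - 1, ..., max + 1) per axis.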
+ :return: data_dict: same as input, with additional keys: + - 'bb_target': bounding box coordinates (b, n_boxes, (y1, x1, y2, x2, (z1), (z2))) + - 'roi_labels': corresponding class labels for each box (b, n_boxes, class_label) + - 'roi_masks': corresponding binary segmentation mask for each lesion (box). Only used in Mask RCNN. (b, n_boxes, y, x, (z)) + - 'seg': now label map (see class_specific_seg_flag) + ''' + + bb_target = [] + roi_masks = [] + roi_labels = [] + out_seg = np.copy(data_dict['seg']) + for b in range(data_dict['seg'].shape[0]): + p_coords_list = [] + p_roi_masks_list = [] + p_roi_labels_list = [] + if np.sum(data_dict['seg'][b] != 0) > 0: + if get_rois_from_seg_flag: + clusters, n_cands = lb(data_dict['seg'][b]) + data_dict['class_target'][b] = [data_dict['class_target'][b]] * n_cands else: - bb_target.append([]) - roi_masks.append(np.zeros_like(data_dict['seg'][b])[None]) - roi_labels.append(np.array([-1])) + n_cands = int(np.max(data_dict['seg'][b])) + clusters = data_dict['seg'][b] + + rois = np.array([(clusters == ii) * 1 for ii in range(1, n_cands + 1)]) # separate clusters and concat + for rix, r in enumerate(rois): + if np.sum(r != 0) > 0: # check if the lesion survived data augmentation + seg_ixs = np.argwhere(r != 0) + coord_list = [np.min(seg_ixs[:, 1]) - 1, np.min(seg_ixs[:, 2]) - 1, np.max(seg_ixs[:, 1]) + 1, + np.max(seg_ixs[:, 2]) + 1] + if dim == 3: + coord_list.extend([np.min(seg_ixs[:, 3]) - 1, np.max(seg_ixs[:, 3]) + 1]) + + p_coords_list.append(coord_list) + p_roi_masks_list.append(r) + # add background class = 0. rix is a patient wide index of lesions. since 'class_target' is + # also patient wide, this assignment is not dependent on patch occurrances. + p_roi_labels_list.append(data_dict['class_target'][b][rix] + 1) - if get_rois_from_seg_flag: - data_dict.pop('class_target', None) + if class_specific_seg_flag: + out_seg[b][data_dict['seg'][b] == rix + 1] = data_dict['class_target'][b][rix] + 1 - data_dict['bb_target'] = np.array(bb_target) - data_dict['roi_masks'] = np.array(roi_masks) - data_dict['class_target'] = np.array(roi_labels) - data_dict['seg'] = out_seg + if not class_specific_seg_flag: + out_seg[b][data_dict['seg'][b] > 0] = 1 - return data_dict + bb_target.append(np.array(p_coords_list)) + roi_masks.append(np.array(p_roi_masks_list, dtype=np.uint8)) + roi_labels.append(np.array(p_roi_labels_list)) + + + else: + bb_target.append([]) + roi_masks.append(np.zeros_like(data_dict['seg'][b])[None]) + roi_labels.append(np.array([-1])) + + if get_rois_from_seg_flag: + data_dict.pop('class_target', None) + + data_dict['bb_target'] = np.array(bb_target) + data_dict['roi_masks'] = np.array(roi_masks) + data_dict['class_target'] = np.array(roi_labels) + data_dict['seg'] = out_seg + + return data_dict def transpose_channels(batch): - if len(batch.shape) == 4: + if batch.ndim == 4: return np.transpose(batch, axes=[0, 2, 3, 1]) - elif len(batch.shape) == 5: + elif batch.ndim == 5: return np.transpose(batch, axes=[0, 4, 2, 3, 1]) else: raise ValueError("wrong dimensions in transpose_channel generator!") -def resize_segmentation(segmentation, new_shape, order=3): +def resize_segmentation(segmentation, new_shape: tuple, order=3): ''' Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to one hot encoding which is resized and transformed back to a segmentation map. 
@@ -591,21 +610,23 @@ def resize_segmentation(segmentation, new_shape, order=3): :return: ''' tpe = segmentation.dtype - unique_labels = np.unique(segmentation) - assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation" + assert segmentation.ndim == len(new_shape), "new shape must have same dimensionality as segmentation" if order == 0: - return resize(segmentation.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False).astype(tpe) + return resize(segmentation.astype(np.float64, copy=False), new_shape, order, mode="edge", clip=True, + anti_aliasing=False).astype(tpe, copy=False) else: - reshaped = np.zeros(new_shape, dtype=segmentation.dtype) + unique_labels = pd.unique(segmentation.reshape(-1)) # does not need sorting + reshaped = np.zeros(new_shape, dtype=tpe) - for i, c in enumerate(unique_labels): + for c in unique_labels: mask = segmentation == c - reshaped_multihot = resize(mask.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False) + reshaped_multihot = resize(mask.astype(np.float64, copy=False), new_shape, order, mode="edge", clip=True, + anti_aliasing=False) reshaped[reshaped_multihot >= 0.5] = c return reshaped -def resize_multichannel_image(multichannel_image, new_shape, order=3): +def resize_multichannel_image(multichannel_image, new_shape: tuple, order=3): ''' Resizes multichannel_image. Resizes each channel in c separately and fuses results back together @@ -614,12 +635,12 @@ def resize_multichannel_image(multichannel_image, new_shape, order=3): :param order: :return: ''' - tpe = multichannel_image.dtype - new_shp = [multichannel_image.shape[0]] + list(new_shape) + new_shp = (multichannel_image.shape[0],) + new_shape result = np.zeros(new_shp, dtype=multichannel_image.dtype) for i in range(multichannel_image.shape[0]): - result[i] = resize(multichannel_image[i].astype(float), new_shape, order, clip=True, anti_aliasing=False) - return result.astype(tpe) + result[i] = resize(multichannel_image[i].astype(float, copy=False), new_shape, order, clip=True, + anti_aliasing=False) + return result def get_range_val(value, rnd_type="uniform"): @@ -654,12 +675,13 @@ def uniform(low, high, size=None): if size is None: return low else: - return np.ones(size) * low + return np.full(size, low) else: return np.random.uniform(low, high, size) -def pad_nd_image(image, new_shape=None, mode="constant", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None): +def pad_nd_image(image, new_shape=None, mode="constant", kwargs=None, return_slicer=False, + shape_must_be_divisible_by=None): """ one padder to pad them all. Documentation? Well okay. 
A little bit @@ -690,12 +712,8 @@ def pad_nd_image(image, new_shape=None, mode="constant", kwargs=None, return_sli new_shape = image.shape[-len(shape_must_be_divisible_by):] old_shape = new_shape - num_axes_nopad = len(image.shape) - len(new_shape) - - new_shape = [max(new_shape[i], old_shape[i]) for i in range(len(new_shape))] - - if not isinstance(new_shape, np.ndarray): - new_shape = np.array(new_shape) + num_axes_nopad = image.ndim - len(new_shape) + new_shape = np.maximum(new_shape, old_shape) if shape_must_be_divisible_by is not None: if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)): @@ -704,17 +722,16 @@ def pad_nd_image(image, new_shape=None, mode="constant", kwargs=None, return_sli assert len(shape_must_be_divisible_by) == len(new_shape) for i in range(len(new_shape)): - if new_shape[i] % shape_must_be_divisible_by[i] == 0: - new_shape[i] -= shape_must_be_divisible_by[i] - - new_shape = np.array([new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] % shape_must_be_divisible_by[i] for i in range(len(new_shape))]) + modulo = new_shape[i] % shape_must_be_divisible_by[i] + if modulo != 0: + new_shape[i] += shape_must_be_divisible_by[i] - modulo difference = new_shape - old_shape pad_below = difference // 2 - pad_above = difference // 2 + difference % 2 - pad_list = [[0, 0]]*num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)]) + pad_above = pad_below + difference % 2 + pad_list = [[0, 0]] * num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)]) - if not ((all([i == 0 for i in pad_below])) and (all([i == 0 for i in pad_above]))): + if np.any(pad_below) or np.any(pad_above): res = np.pad(image, pad_list, mode, **kwargs) else: res = image @@ -724,7 +741,7 @@ def pad_nd_image(image, new_shape=None, mode="constant", kwargs=None, return_sli else: pad_list = np.array(pad_list) pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1] - slicer = list(slice(*i) for i in pad_list) + slicer = [slice(*i) for i in pad_list] return res, slicer @@ -745,23 +762,23 @@ def mask_random_square(img, square_size, n_val, channel_wise_n_val=False, square h_start = pos_wh[1] if img.ndim == 2: - rnd_n_val = get_range_val(n_val) + rnd_n_val = uniform(n_val[0], n_val[1]) img[h_start:(h_start + square_size), w_start:(w_start + square_size)] = rnd_n_val elif img.ndim == 3: if channel_wise_n_val: for i in range(img.shape[0]): - rnd_n_val = get_range_val(n_val) + rnd_n_val = uniform(n_val[0], n_val[1]) img[i, h_start:(h_start + square_size), w_start:(w_start + square_size)] = rnd_n_val else: - rnd_n_val = get_range_val(n_val) + rnd_n_val = uniform(n_val[0], n_val[1]) img[:, h_start:(h_start + square_size), w_start:(w_start + square_size)] = rnd_n_val elif img.ndim == 4: if channel_wise_n_val: for i in range(img.shape[0]): - rnd_n_val = get_range_val(n_val) + rnd_n_val = uniform(n_val[0], n_val[1]) img[:, i, h_start:(h_start + square_size), w_start:(w_start + square_size)] = rnd_n_val else: - rnd_n_val = get_range_val(n_val) + rnd_n_val = uniform(n_val[0], n_val[1]) img[:, :, h_start:(h_start + square_size), w_start:(w_start + square_size)] = rnd_n_val return img @@ -774,7 +791,8 @@ def mask_random_squares(img, square_size, n_squares, n_val, channel_wise_n_val=F square_pos=square_pos) return img -def get_organ_gradient_field(organ, spacing_ratio=0.3125/3.0, blur=32): + +def get_organ_gradient_field(organ, spacing_ratio=0.3125 / 3.0, blur=32): """ Calculates the gradient field around the organ segmentations for the anatomy-informed augmentation @@ -782,7 +800,7 @@ 
def get_organ_gradient_field(organ, spacing_ratio=0.3125/3.0, blur=32):
     :param spacing_ratio: ratio of the axial spacing and the slice thickness, needed for the right vector field calculation
     :param blur: kernel constant
     """
-    organ_blurred = gaussian_filter(organ.astype(float),
+    organ_blurred = gaussian_filter(organ.astype(float, copy=False),
                                     sigma=(blur * spacing_ratio, blur, blur),
                                     order=0,
                                     mode='nearest')
@@ -792,6 +810,7 @@ def get_organ_gradient_field(organ, spacing_ratio=0.3125/3.0, blur=32):
 
     return t, u, v
 
 
+
 def ignore_anatomy(segm, max_annotation_value=1, replace_value=0):
     segm[segm > max_annotation_value] = replace_value
     return segm
diff --git a/batchgenerators/dataloading/data_loader.py b/batchgenerators/dataloading/data_loader.py
index 28d0df1..581ecd4 100644
--- a/batchgenerators/dataloading/data_loader.py
+++ b/batchgenerators/dataloading/data_loader.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 from abc import ABCMeta, abstractmethod
-from builtins import object
 import warnings
 from collections import OrderedDict
 from warnings import warn
@@ -169,6 +168,10 @@ def __init__(self, data, batch_size, num_threads_in_multithreaded=1, seed_for_sh
         # when you derive, make sure to set this! We can't set it here because we don't know what data will be like
         self.indices = None
 
+        if self.infinite:
+            # use the dedicated sampling method for infinite data loading
+            self.get_indices = self.get_indices_infinite
+
     def reset(self):
         assert self.indices is not None
 
@@ -182,11 +185,10 @@ def reset(self):
 
         self.last_reached = False
 
-    def get_indices(self):
-        # if self.infinite, this is easy
-        if self.infinite:
-            return np.random.choice(self.indices, self.batch_size, replace=True, p=self.sampling_probabilities)
+    def get_indices_infinite(self):
+        return np.random.choice(self.indices, self.batch_size, replace=True, p=self.sampling_probabilities)
 
+    def get_indices(self):
         if self.last_reached:
             self.reset()
             raise StopIteration
@@ -199,7 +201,7 @@ def get_indices(self):
         for b in range(self.batch_size):
             if self.current_position < len(self.indices):
                 indices.append(self.indices[self.current_position])
                 self.current_position += 1
             else:
                 self.last_reached = True
@@ -230,11 +231,11 @@ def default_collate(batch):
     if isinstance(batch[0], np.ndarray):
         return np.vstack(batch)
     elif isinstance(batch[0], (int, np.int64)):
-        return np.array(batch).astype(np.int32)
+        return np.array(batch, dtype=np.int32)
     elif isinstance(batch[0], (float, np.float32)):
-        return np.array(batch).astype(np.float32)
+        return np.array(batch, dtype=np.float32)
     elif isinstance(batch[0], (np.float64,)):
-        return np.array(batch).astype(np.float64)
+        return np.array(batch, dtype=np.float64)
     elif isinstance(batch[0], (dict, OrderedDict)):
         return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
     elif isinstance(batch[0], (tuple, list)):
diff --git a/batchgenerators/dataloading/multi_threaded_augmenter.py b/batchgenerators/dataloading/multi_threaded_augmenter.py
index 6006fbe..8fde7b9 100755
--- a/batchgenerators/dataloading/multi_threaded_augmenter.py
+++ b/batchgenerators/dataloading/multi_threaded_augmenter.py
@@ -61,7 +61,7 @@ def producer(queue, data_loader, transform, thread_id, seed, abort_event, wait_t
             abort_event.set()
             return
         except Exception as e:
-            print("Exception in background worker %d:\n" % thread_id, e)
+            print(f"Exception in background worker {thread_id}:\n", e)
             traceback.print_exc()
             abort_event.set()
             return
@@ -216,7 +216,7 @@ def __next__(self):
             return item
 
         except KeyboardInterrupt:
-            logging.error("MultiThreadedGenerator: caught 
exception: {}".format(sys.exc_info())) + logging.error(f"MultiThreadedGenerator: caught exception: {sys.exc_info()}") self.abort_event.set() self._finish() raise KeyboardInterrupt diff --git a/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py b/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py index 530eba1..d02074b 100755 --- a/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py +++ b/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py @@ -17,7 +17,6 @@ from copy import deepcopy from typing import List, Union import threading -from builtins import range from multiprocessing import Process from multiprocessing import Queue from queue import Queue as thrQueue @@ -68,7 +67,7 @@ def producer(queue: Queue, data_loader, transform, thread_id: int, seed, return except Exception as e: - print("Exception in background worker %d:\n" % thread_id, e) + print(f"Exception in background worker {thread_id}:\n", e) traceback.print_exc() abort_event.set() return diff --git a/batchgenerators/dataloading/single_threaded_augmenter.py b/batchgenerators/dataloading/single_threaded_augmenter.py index 5637c0e..f27cce2 100755 --- a/batchgenerators/dataloading/single_threaded_augmenter.py +++ b/batchgenerators/dataloading/single_threaded_augmenter.py @@ -40,3 +40,6 @@ def __next__(self): def next(self): return self.__next__() + + def _finish(self): + pass diff --git a/batchgenerators/transforms/abstract_transforms.py b/batchgenerators/transforms/abstract_transforms.py index 2a535cc..69902f5 100644 --- a/batchgenerators/transforms/abstract_transforms.py +++ b/batchgenerators/transforms/abstract_transforms.py @@ -28,7 +28,7 @@ def __call__(self, **data_dict): def __repr__(self): ret_str = str(type(self).__name__) + "( " + ", ".join( - [key + " = " + repr(val) for key, val in self.__dict__.items()]) + " )" + [f"{key} = {repr(val)}" for key, val in self.__dict__.items()]) + " )" return ret_str @@ -89,4 +89,4 @@ def __call__(self, **data_dict): return data_dict def __repr__(self): - return str(type(self).__name__) + " ( " + repr(self.transforms) + " )" + return f"{str(type(self).__name__)} ( {repr(self.transforms)} )" diff --git a/batchgenerators/transforms/channel_selection_transforms.py b/batchgenerators/transforms/channel_selection_transforms.py index ec89cfe..5bd1ae1 100644 --- a/batchgenerators/transforms/channel_selection_transforms.py +++ b/batchgenerators/transforms/channel_selection_transforms.py @@ -15,6 +15,9 @@ import numpy as np from warnings import warn + +import pandas as pd + from batchgenerators.transforms.abstract_transforms import AbstractTransform @@ -118,7 +121,7 @@ def __call__(self, **data_dict): random_number = np.random.rand() if random_number < self.swap_probability: seg[:, [self.axis1, self.axis2]] = seg[:, [self.axis2, self.axis1]] - data_dict[self.label_key] = seg + data_dict[self.label_key] = seg return data_dict @@ -167,6 +170,7 @@ def __init__(self, label, label_key="seg"): self.label = [label] else: self.label = sorted(label) + self.label = set(self.label) def __call__(self, **data_dict): seg = data_dict.get(self.label_key) @@ -175,7 +179,8 @@ def __call__(self, **data_dict): warn("You used SegLabelSelectionBinarizeTransform but there is no 'seg' key in your data_dict, returning " "data_dict unmodified", Warning) else: - discard_labels = set(np.unique(seg)) - set(self.label) - set([0]) + + discard_labels = set(pd.unique(seg.reshape(-1))) - self.label - {0} for label in discard_labels: seg[seg == label] = 0 for label in self.label: diff --git 
a/batchgenerators/transforms/color_transforms.py b/batchgenerators/transforms/color_transforms.py index e5b735d..3d294ab 100644 --- a/batchgenerators/transforms/color_transforms.py +++ b/batchgenerators/transforms/color_transforms.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, Tuple, Callable +from typing import Tuple import numpy as np @@ -24,7 +24,7 @@ class ContrastAugmentationTransform(AbstractTransform): def __init__(self, - contrast_range: Union[Tuple[float, float], Callable[[], float]] = (0.75, 1.25), + contrast_range: Tuple[float, float] = (0.75, 1.25), preserve_range: bool = True, per_channel: bool = True, data_key: str = "data", @@ -36,7 +36,6 @@ def __init__(self, (float, float): range from which to sample a random contrast that is applied to the data. If one value is smaller and one is larger than 1, half of the contrast modifiers will be >1 and the other half <1 (in the inverval that was specified) - callable : must be contrast_range() -> float :param preserve_range: if True then the intensity values after contrast augmentation will be cropped to min and max values of the data before augmentation. :param per_channel: whether to use the same contrast modifier for all color channels or a separate one for each @@ -52,13 +51,14 @@ def __init__(self, self.p_per_channel = p_per_channel def __call__(self, **data_dict): - for b in range(len(data_dict[self.data_key])): - if np.random.uniform() < self.p_per_sample: - data_dict[self.data_key][b] = augment_contrast(data_dict[self.data_key][b], - contrast_range=self.contrast_range, - preserve_range=self.preserve_range, - per_channel=self.per_channel, - p_per_channel=self.p_per_channel) + mask = np.random.uniform(size=len(data_dict[self.data_key])) < self.p_per_sample + if np.any(mask): + data_dict[self.data_key][mask] = augment_contrast(data_dict[self.data_key][mask], + contrast_range=self.contrast_range, + preserve_range=self.preserve_range, + per_channel=self.per_channel, + p_per_channel=self.p_per_channel, + batched=True) return data_dict @@ -121,17 +121,19 @@ def __init__(self, multiplier_range=(0.5, 2), per_channel=True, data_key="data", self.per_channel = per_channel def __call__(self, **data_dict): - for b in range(len(data_dict[self.data_key])): - if np.random.uniform() < self.p_per_sample: - data_dict[self.data_key][b] = augment_brightness_multiplicative(data_dict[self.data_key][b], - self.multiplier_range, - self.per_channel) + data = data_dict[self.data_key] + mask = np.random.uniform(size=len(data)) < self.p_per_sample + if np.any(mask): + data_dict[self.data_key][mask] = augment_brightness_multiplicative(data[mask], + self.multiplier_range, + self.per_channel, + batched=True) return data_dict class GammaTransform(AbstractTransform): def __init__(self, gamma_range=(0.5, 2), invert_image=False, per_channel=False, data_key="data", - retain_stats: Union[bool, Callable[[], bool]] = False, p_per_sample=1): + retain_stats: bool = False, p_per_sample=1): """ Augments by changing 'gamma' of the image (same as gamma correction in photos or computer monitors @@ -143,8 +145,7 @@ def __init__(self, gamma_range=(0.5, 2), invert_image=False, per_channel=False, :param per_channel: :param data_key: :param retain_stats: Gamma transformation will alter the mean and std of the data in the patch. If retain_stats=True, - the data will be transformed to match the mean and standard deviation before gamma augmentation. 
retain_stats - can also be callable (signature retain_stats() -> bool) + the data will be transformed to match the mean and standard deviation before gamma augmentation. :param p_per_sample: """ self.p_per_sample = p_per_sample @@ -155,6 +156,7 @@ def __init__(self, gamma_range=(0.5, 2), invert_image=False, per_channel=False, self.invert_image = invert_image def __call__(self, **data_dict): + # TODO: augment_gamma can be vectorized twice (per channel and per sample) for b in range(len(data_dict[self.data_key])): if np.random.uniform() < self.p_per_sample: data_dict[self.data_key][b] = augment_gamma(data_dict[self.data_key][b], self.gamma_range, @@ -203,5 +205,5 @@ def __init__(self, min=None, max=None, data_key="data"): self.max = max def __call__(self, **data_dict): - data_dict[self.data_key] = np.clip(data_dict[self.data_key], self.min, self.max) + np.clip(data_dict[self.data_key], self.min, self.max, out=data_dict[self.data_key]) return data_dict diff --git a/batchgenerators/transforms/crop_and_pad_transforms.py b/batchgenerators/transforms/crop_and_pad_transforms.py index 9735774..87b20cc 100644 --- a/batchgenerators/transforms/crop_and_pad_transforms.py +++ b/batchgenerators/transforms/crop_and_pad_transforms.py @@ -128,7 +128,7 @@ def __call__(self, **data_dict): data = data_dict.get(self.data_key) seg = data_dict.get(self.label_key) - assert len(self.new_size) + 2 == len(data.shape), "new size must be a tuple/list/np.ndarray with shape " \ + assert len(self.new_size) + 2 == data.ndim, "new size must be a tuple/list/np.ndarray with shape " \ "(x, y(, z))" data, seg = pad_nd_image_and_seg(data, seg, self.new_size, None, np_pad_kwargs_data=self.np_pad_kwargs_data, @@ -180,8 +180,8 @@ def __call__(self, **data_dict): for c in range(workon.shape[1]): if np.random.uniform(0, 1) < self.p_per_channel: shift_here = [] - for d in range(len(workon.shape) - 2): - shift_here.append(int(np.round(np.random.normal( + for d in range(workon.ndim - 2): + shift_here.append(int(np.rint(np.random.normal( self.shift_mu[d] if isinstance(self.shift_mu, (list, tuple)) else self.shift_mu, self.shift_sigma[d] if isinstance(self.shift_sigma, (list, tuple)) else self.shift_sigma, diff --git a/batchgenerators/transforms/local_transforms.py b/batchgenerators/transforms/local_transforms.py index 8cb9ca4..9c69dff 100644 --- a/batchgenerators/transforms/local_transforms.py +++ b/batchgenerators/transforms/local_transforms.py @@ -61,8 +61,8 @@ def _generate_kernel(self, img_shp: Tuple[int, ...]) -> np.ndarray: kernel_image = kernel_2d # normalize to [0, 1] - kernel_image = kernel_image - kernel_image.min() - kernel_image = kernel_image / max(1e-8, kernel_image.max()) + kernel_image -= kernel_image.min() + kernel_image /= max(1e-8, kernel_image.max()) return kernel_image def _generate_multiple_kernel_image(self, img_shp: Tuple[int, ...], num_kernels: int) -> np.ndarray: @@ -150,7 +150,7 @@ def __init__(self, def __call__(self, **data_dict): data = data_dict.get(self.data_key) - assert data is not None, "Could not find data key '%s'" % self.data_key + assert data is not None, f"Could not find data key '{self.data_key}'" b, c, *img_shape = data.shape for bi in range(b): if np.random.uniform() < self.p_per_sample: @@ -167,7 +167,7 @@ def __call__(self, **data_dict): # now rescale so that the maximum value of the kernel is max_strength strength = sample_scalar(self.max_strength, data[bi, ci], kernel) if callable( self.max_strength) else strength - kernel_scaled = np.copy(kernel) / mx * strength + kernel_scaled = kernel / 
mx * strength data[bi, ci] += kernel_scaled else: for ci in range(c): @@ -177,7 +177,7 @@ def __call__(self, **data_dict): kernel -= kernel.mean() mx = max(np.max(np.abs(kernel)), 1e-8) strength = sample_scalar(self.max_strength, data[bi, ci], kernel) - kernel = kernel / mx * strength + kernel *= strength / mx data[bi, ci] += kernel return data_dict @@ -235,7 +235,7 @@ def __init__(self, def __call__(self, **data_dict): data = data_dict.get(self.data_key) - assert data is not None, "Could not find data key '%s'" % self.data_key + assert data is not None, f"Could not find data key '{self.data_key}'" b, c, *img_shape = data.shape for bi in range(b): if np.random.uniform() < self.p_per_sample: @@ -255,14 +255,15 @@ def __call__(self, **data_dict): def _apply_gamma_gradient(self, img: np.ndarray, kernel: np.ndarray) -> np.ndarray: # store keep original image range mn, mx = img.min(), img.max() + rng = mx - mn # rescale tp [0, 1] - img = (img - mn) / (max(mx - mn, 1e-8)) + img = (img - mn) / (max(rng, 1e-8)) gamma = sample_scalar(self.gamma) img_modified = np.power(img, gamma) - return self.run_interpolation(img, img_modified, kernel) * (mx - mn) + mn + return self.run_interpolation(img, img_modified, kernel) * rng + mn class LocalSmoothingTransform(LocalTransform): @@ -301,7 +302,7 @@ def __init__(self, def __call__(self, **data_dict): data = data_dict.get(self.data_key) - assert data is not None, "Could not find data key '%s'" % self.data_key + assert data is not None, f"Could not find data key '{self.data_key}'" b, c, *img_shape = data.shape for bi in range(b): if np.random.uniform() < self.p_per_sample: @@ -323,7 +324,7 @@ def _apply_local_smoothing(self, img: np.ndarray, kernel: np.ndarray) -> np.ndar kernel = np.copy(kernel) smoothing = sample_scalar(self.smoothing_strength) - assert 0 <= smoothing <= 1, 'smoothing_strength must be between 0 and 1, is %f' % smoothing + assert 0 <= smoothing <= 1, f'smoothing_strength must be between 0 and 1, is {smoothing}' # prepare kernel by rescaling it to gamma_range # kernel is already [0, 1] @@ -353,7 +354,7 @@ def __init__(self, def __call__(self, **data_dict): data = data_dict.get(self.data_key) - assert data is not None, "Could not find data key '%s'" % self.data_key + assert data is not None, f"Could not find data key '{self.data_key}'" b, c, *img_shape = data.shape for bi in range(b): if np.random.uniform() < self.p_per_sample: diff --git a/batchgenerators/transforms/noise_transforms.py b/batchgenerators/transforms/noise_transforms.py index ba67efe..e2e40c3 100644 --- a/batchgenerators/transforms/noise_transforms.py +++ b/batchgenerators/transforms/noise_transforms.py @@ -20,7 +20,6 @@ import numpy as np from typing import Union, Tuple -from scipy import ndimage from scipy.ndimage import median_filter from scipy.signal import convolve @@ -71,10 +70,10 @@ def __init__(self, noise_variance=(0, 0.1), p_per_sample=1, p_per_channel: float self.per_channel = per_channel def __call__(self, **data_dict): - for b in range(len(data_dict[self.data_key])): - if np.random.uniform() < self.p_per_sample: - data_dict[self.data_key][b] = augment_gaussian_noise(data_dict[self.data_key][b], self.noise_variance, - self.p_per_channel, self.per_channel) + mask = np.random.uniform(size=len(data_dict[self.data_key])) < self.p_per_sample + if np.any(mask): + data_dict[self.data_key][mask] = augment_gaussian_noise(data_dict[self.data_key][mask], self.noise_variance, + self.p_per_channel, self.per_channel, batched=True) return data_dict @@ -102,6 +101,7 @@ def 
__init__(self, blur_sigma: Tuple[float, float] = (1, 5), different_sigma_per self.p_isotropic = p_isotropic def __call__(self, **data_dict): + # TODO: Do batched gaussian blur for b in range(len(data_dict[self.data_key])): if np.random.uniform() < self.p_per_sample: data_dict[self.data_key][b] = augment_gaussian_blur(data_dict[self.data_key][b], self.blur_sigma, @@ -138,6 +138,7 @@ def __init__(self, rectangle_value): self.rectangle_value = rectangle_value def __call__(self, x): + # TODO: Change this if np.isscalar(self.rectangle_value): return self.rectangle_value elif callable(self.rectangle_value): @@ -148,6 +149,8 @@ def __call__(self, x): raise RuntimeError("unrecognized format for rectangle_value") + + class BlankRectangleTransform(AbstractTransform): def __init__(self, rectangle_size, rectangle_value, num_rectangles, force_square=False, p_per_sample=0.5, p_per_channel=0.5, apply_to_keys=('data',)): @@ -319,7 +322,7 @@ def __call__(self, **data_dict): mn, mx = data[b].min(), data[b].max() strength_here = self.strength if isinstance(self.strength, float) else np.random.uniform( *self.strength) - if len(data.shape) == 4: + if data.ndim == 4: filter_here = self.filter_2d * strength_here filter_here[1, 1] += 1 else: @@ -331,14 +334,14 @@ def __call__(self, **data_dict): filter_here, mode='same' ) - data[b, c] = np.clip(data[b, c], mn, mx) + np.clip(data[b, c], mn, mx, out=data[b, c]) else: for c in range(data.shape[1]): if np.random.uniform() < self.p_per_channel: mn, mx = data[b, c].min(), data[b, c].max() strength_here = self.strength if isinstance(self.strength, float) else np.random.uniform( *self.strength) - if len(data.shape) == 4: + if data.ndim == 4: filter_here = self.filter_2d * strength_here filter_here[1, 1] += 1 else: @@ -348,7 +351,7 @@ def __call__(self, **data_dict): filter_here, mode='same' ) - data[b, c] = np.clip(data[b, c], mn, mx) + np.clip(data[b, c], mn, mx, out=data[b, c]) return data_dict diff --git a/batchgenerators/transforms/resample_transforms.py b/batchgenerators/transforms/resample_transforms.py index 5b00c0c..bc6cf09 100644 --- a/batchgenerators/transforms/resample_transforms.py +++ b/batchgenerators/transforms/resample_transforms.py @@ -57,6 +57,9 @@ def __init__(self, zoom_range=(0.5, 1), per_channel=False, p_per_channel=1, self.p_per_channel = p_per_channel self.p_per_sample = p_per_sample self.data_key = data_key + assert isinstance(zoom_range, (tuple, list, np.ndarray)) + assert (len(zoom_range) == 2 or isinstance(zoom_range[0], (tuple, list, np.ndarray)) and + all(len(zoom) == 2 for zoom in zoom_range)) self.zoom_range = zoom_range self.ignore_axes = ignore_axes diff --git a/batchgenerators/transforms/spatial_transforms.py b/batchgenerators/transforms/spatial_transforms.py index 0d766ae..86c95a8 100644 --- a/batchgenerators/transforms/spatial_transforms.py +++ b/batchgenerators/transforms/spatial_transforms.py @@ -13,13 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from batchgenerators.transforms.abstract_transforms import AbstractTransform -from batchgenerators.augmentations.spatial_transformations import augment_spatial, augment_spatial_2, \ - augment_channel_translation, \ - augment_mirroring, augment_transpose_axes, augment_zoom, augment_resize, augment_rot90, \ - augment_anatomy_informed, augment_misalign import numpy as np -from batchgenerators.augmentations.utils import get_organ_gradient_field + +from batchgenerators.augmentations.spatial_transformations import augment_spatial, augment_spatial_2, \ + augment_channel_translation, augment_mirroring_batched, augment_transpose_axes, augment_zoom, augment_resize, \ + augment_rot90, augment_anatomy_informed, augment_misalign +from batchgenerators.transforms.abstract_transforms import AbstractTransform class Rot90Transform(AbstractTransform): @@ -202,22 +201,17 @@ def __init__(self, axes=(0, 1, 2), data_key="data", label_key="seg", p_per_sampl "is now axes=(0, 1, 2). Please adapt your scripts accordingly.") def __call__(self, **data_dict): - data = data_dict.get(self.data_key) + data = data_dict[self.data_key] seg = data_dict.get(self.label_key) - for b in range(len(data)): - if np.random.uniform() < self.p_per_sample: - sample_seg = None - if seg is not None: - sample_seg = seg[b] - ret_val = augment_mirroring(data[b], sample_seg, axes=self.axes) - data[b] = ret_val[0] - if seg is not None: - seg[b] = ret_val[1] - - data_dict[self.data_key] = data - if seg is not None: - data_dict[self.label_key] = seg + mask = np.random.uniform(size=len(data)) < self.p_per_sample + if np.any(mask): + if seg is None: + data[mask], _ = augment_mirroring_batched(data[mask], None, self.axes) + else: + data[mask], seg[mask] = augment_mirroring_batched(data[mask], seg[mask], self.axes) + data_dict[self.label_key] = seg + data_dict[self.data_key] = data return data_dict @@ -303,14 +297,15 @@ def __init__(self, patch_size, patch_center_dist_from_border=30, do_scale=True, scale=(0.75, 1.25), border_mode_data='nearest', border_cval_data=0, order_data=3, border_mode_seg='constant', border_cval_seg=0, order_seg=0, random_crop=True, data_key="data", label_key="seg", p_el_per_sample=1, p_scale_per_sample=1, p_rot_per_sample=1, - independent_scale_for_each_axis=False, p_rot_per_axis:float=1, p_independent_scale_per_axis: int=1): + independent_scale_for_each_axis=False, p_rot_per_axis: float = 1, + p_independent_scale_per_axis: int = 1): self.independent_scale_for_each_axis = independent_scale_for_each_axis self.p_rot_per_sample = p_rot_per_sample self.p_scale_per_sample = p_scale_per_sample self.p_el_per_sample = p_el_per_sample self.data_key = data_key self.label_key = label_key - self.patch_size = patch_size + self.patch_size = tuple(patch_size) self.patch_center_dist_from_border = patch_center_dist_from_border self.do_elastic_deform = do_elastic_deform self.alpha = alpha @@ -332,14 +327,14 @@ def __init__(self, patch_size, patch_center_dist_from_border=30, self.p_independent_scale_per_axis = p_independent_scale_per_axis def __call__(self, **data_dict): - data = data_dict.get(self.data_key) + data = data_dict[self.data_key] seg = data_dict.get(self.label_key) if self.patch_size is None: - if len(data.shape) == 4: - patch_size = (data.shape[2], data.shape[3]) - elif len(data.shape) == 5: - patch_size = (data.shape[2], data.shape[3], data.shape[4]) + if data.ndim == 4: + patch_size = data.shape[2:4] + elif data.ndim == 5: + patch_size = data.shape[2:5] else: raise ValueError("only support 2D/3D batch data.") else: @@ -357,7 
+352,7 @@ def __call__(self, **data_dict): p_el_per_sample=self.p_el_per_sample, p_scale_per_sample=self.p_scale_per_sample, p_rot_per_sample=self.p_rot_per_sample, independent_scale_for_each_axis=self.independent_scale_for_each_axis, - p_rot_per_axis=self.p_rot_per_axis, + p_rot_per_axis=self.p_rot_per_axis, p_independent_scale_per_axis=self.p_independent_scale_per_axis) data_dict[self.data_key] = ret_val[0] if seg is not None: @@ -419,7 +414,8 @@ def __init__(self, patch_size, patch_center_dist_from_border=30, do_scale=True, scale=(0.75, 1.25), border_mode_data='nearest', border_cval_data=0, order_data=3, border_mode_seg='constant', border_cval_seg=0, order_seg=0, random_crop=True, data_key="data", label_key="seg", p_el_per_sample=1, p_scale_per_sample=1, p_rot_per_sample=1, - independent_scale_for_each_axis=False, p_rot_per_axis:float=1, p_independent_scale_per_axis: float=1): + independent_scale_for_each_axis=False, p_rot_per_axis: float = 1, + p_independent_scale_per_axis: float = 1): self.p_rot_per_sample = p_rot_per_sample self.p_scale_per_sample = p_scale_per_sample self.p_el_per_sample = p_el_per_sample @@ -451,9 +447,9 @@ def __call__(self, **data_dict): seg = data_dict.get(self.label_key) if self.patch_size is None: - if len(data.shape) == 4: + if data.ndim == 4: patch_size = (data.shape[2], data.shape[3]) - elif len(data.shape) == 5: + elif data.ndim == 5: patch_size = (data.shape[2], data.shape[3], data.shape[4]) else: raise ValueError("only support 2D/3D batch data.") @@ -471,9 +467,9 @@ def __call__(self, **data_dict): order_seg=self.order_seg, random_crop=self.random_crop, p_el_per_sample=self.p_el_per_sample, p_scale_per_sample=self.p_scale_per_sample, p_rot_per_sample=self.p_rot_per_sample, - independent_scale_for_each_axis=self.independent_scale_for_each_axis, - p_rot_per_axis=self.p_rot_per_axis, - p_independent_scale_per_axis=self.p_independent_scale_per_axis) + independent_scale_for_each_axis=self.independent_scale_for_each_axis, + p_rot_per_axis=self.p_rot_per_axis, + p_independent_scale_per_axis=self.p_independent_scale_per_axis) data_dict[self.data_key] = ret_val[0] if seg is not None: @@ -526,6 +522,7 @@ def __call__(self, **data_dict): data_dict[self.label_key] = seg return data_dict + class AnatomyInformedTransform(AbstractTransform): """ The data augmentation is presented at MICCAI 2023 in the proceedings of 'Anatomy-informed Data Augmentation for enhanced Prostate Cancer Detection'. 
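
Most of the transform changes in this patch (ContrastAugmentationTransform, BrightnessMultiplicativeTransform,
GaussianNoiseTransform, and the MirrorTransform above) follow one vectorization pattern: draw a single uniform
vector, build a boolean per-sample mask, and hand the masked sub-batch to a batched augmentation in one call
instead of looping over samples. A minimal sketch of the pattern (augment_batched and apply_with_p_per_sample
are hypothetical names, not batchgenerators API):

    import numpy as np

    def augment_batched(batch: np.ndarray) -> np.ndarray:
        # stand-in for any augmentation that accepts a whole (b, c, ...) array;
        # here: mirror the last spatial axis
        return batch[..., ::-1]

    def apply_with_p_per_sample(data: np.ndarray, p_per_sample: float) -> np.ndarray:
        # one uniform draw per sample instead of one draw per loop iteration
        mask = np.random.uniform(size=len(data)) < p_per_sample
        if np.any(mask):
            # boolean fancy indexing returns a copy, so the augmented
            # sub-batch must be written back explicitly
            data[mask] = augment_batched(data[mask])
        return data

    batch = np.random.random((9, 2, 16, 16))  # (b, c, x, y)
    batch = apply_with_p_per_sample(batch, p_per_sample=0.3)

The explicit write-back is why the transforms above assign via data_dict[self.data_key][mask] = ... rather than
relying on in-place mutation of the masked selection.
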
@@ -545,8 +542,9 @@ class AnatomyInformedTransform(AbstractTransform): `max_annotation_value`: the value that should be still relevant for the main task `replace_value`: segmentation values larger than the `max_annotation_value` will be replaced with """ + def __init__(self, dil_ranges, modalities, directions_of_trans, p_per_sample, - spacing_ratio=0.3125/3.0, blur=32, anisotropy_safety= True, + spacing_ratio=0.3125 / 3.0, blur=32, anisotropy_safety=True, max_annotation_value=1, replace_value=0): self.dil_ranges = dil_ranges self.modalities = modalities @@ -569,6 +567,7 @@ def __call__(self, **data_dict): self.dim = 3 active_organs = [] + # TODO: Optimize this for prob in self.p_per_sample: if np.random.uniform() < prob: active_organs.append(1) @@ -576,17 +575,18 @@ def __call__(self, **data_dict): active_organs.append(0) for b in range(data_shape[0]): - data_dict['data'][b, :, :, :, :], data_dict['seg'][b, 0, :, :, :] = augment_anatomy_informed(data=data_dict['data'][b, :, :, :, :], - seg=data_dict['seg'][b, 0, :, :, :], - active_organs=active_organs, - dilation_ranges=self.dil_ranges, - directions_of_trans=self.directions_of_trans, - modalities=self.modalities, - spacing_ratio=self.spacing_ratio, - blur=self.blur, - anisotropy_safety=self.anisotropy_safety, - max_annotation_value=self.max_annotation_value, - replace_value=self.replace_value) + data_dict['data'][b, :, :, :, :], data_dict['seg'][b, 0, :, :, :] = augment_anatomy_informed( + data=data_dict['data'][b, :, :, :, :], + seg=data_dict['seg'][b, 0, :, :, :], + active_organs=active_organs, + dilation_ranges=self.dil_ranges, + directions_of_trans=self.directions_of_trans, + modalities=self.modalities, + spacing_ratio=self.spacing_ratio, + blur=self.blur, + anisotropy_safety=self.anisotropy_safety, + max_annotation_value=self.max_annotation_value, + replace_value=self.replace_value) return data_dict @@ -670,9 +670,10 @@ def __init__(self, data_key="data", label_key="seg", self.border_cval_seg = border_cval_seg def __call__(self, **data_dict): - data = data_dict.get(self.data_key) + data = data_dict[self.data_key] seg = data_dict.get(self.label_key) + # TODO if data.shape[1] < 2: raise ValueError("only support multi-modal images") else: diff --git a/batchgenerators/transforms/utility_transforms.py b/batchgenerators/transforms/utility_transforms.py index 0f04be0..e6392be 100644 --- a/batchgenerators/transforms/utility_transforms.py +++ b/batchgenerators/transforms/utility_transforms.py @@ -14,12 +14,13 @@ # limitations under the License. 
import copy -from typing import List, Type, Union, Tuple +from typing import List, Union, Tuple import numpy as np +import torch -from batchgenerators.augmentations.utils import convert_seg_image_to_one_hot_encoding, \ - convert_seg_to_bounding_box_coordinates, transpose_channels, \ +from batchgenerators.augmentations.utils import convert_seg_image_to_one_hot_encoding_batched +from batchgenerators.augmentations.utils import convert_seg_to_bounding_box_coordinates, transpose_channels, \ ignore_anatomy from batchgenerators.transforms.abstract_transforms import AbstractTransform @@ -35,42 +36,61 @@ def __init__(self, keys=None, cast_to=None): if keys is not None and not isinstance(keys, (list, tuple)): keys = [keys] self.keys = keys - self.cast_to = cast_to - - def cast(self, tensor): - if self.cast_to is not None: - if self.cast_to == 'half': - tensor = tensor.half() - elif self.cast_to == 'float': - tensor = tensor.float() - elif self.cast_to == 'long': - tensor = tensor.long() - elif self.cast_to == 'bool': - tensor = tensor.bool() - else: - raise ValueError('Unknown value for cast_to: %s' % self.cast_to) - return tensor - def __call__(self, **data_dict): - import torch + if cast_to is None: + self.cast = self.no_cast + elif cast_to == 'half': + self.cast = self.half_cast + elif cast_to == 'float': + self.cast = self.float_cast + elif cast_to == 'long': + self.cast = self.long_cast + elif cast_to == 'bool': + self.cast = self.bool_cast + else: + raise ValueError(f'Unknown value for cast_to: {cast_to}') + + def cast(self, x: np.ndarray) -> torch.Tensor: + pass + + @staticmethod + def no_cast(x: np.ndarray) -> torch.Tensor: + return torch.from_numpy(x).contiguous() + + @staticmethod + def float_cast(x: np.ndarray) -> torch.Tensor: + return torch.from_numpy(x).to(torch.float, memory_format=torch.contiguous_format) + @staticmethod + def long_cast(x: np.ndarray) -> torch.Tensor: + return torch.from_numpy(x).to(torch.long, memory_format=torch.contiguous_format) + + @staticmethod + def bool_cast(x: np.ndarray) -> torch.Tensor: + return torch.from_numpy(x).to(torch.bool, memory_format=torch.contiguous_format) + + @staticmethod + def half_cast(x: np.ndarray) -> torch.Tensor: + return torch.from_numpy(x).to(torch.half, memory_format=torch.contiguous_format) + + def __call__(self, **data_dict): if self.keys is None: for key, val in data_dict.items(): if isinstance(val, np.ndarray): - data_dict[key] = self.cast(torch.from_numpy(val)).contiguous() + data_dict[key] = self.cast(val) elif isinstance(val, (list, tuple)) and all([isinstance(i, np.ndarray) for i in val]): - data_dict[key] = [self.cast(torch.from_numpy(i)).contiguous() for i in val] + data_dict[key] = [self.cast(i) for i in val] else: for key in self.keys: if isinstance(data_dict[key], np.ndarray): - data_dict[key] = self.cast(torch.from_numpy(data_dict[key])).contiguous() - elif isinstance(data_dict[key], (list, tuple)) and all([isinstance(i, np.ndarray) for i in data_dict[key]]): - data_dict[key] = [self.cast(torch.from_numpy(i)).contiguous() for i in data_dict[key]] + data_dict[key] = self.cast(data_dict[key]) + elif isinstance(data_dict[key], (list, tuple)) and all( + [isinstance(i, np.ndarray) for i in data_dict[key]]): + data_dict[key] = [self.cast(i) for i in data_dict[key]] return data_dict - class ListToNumpy(AbstractTransform): """Utility function for pytorch. 
Converts data (and seg) numpy ndarrays to pytorch tensors """ @@ -119,9 +139,7 @@ def __init__(self, classes, seg_channel=0, output_key="seg"): def __call__(self, **data_dict): seg = data_dict.get("seg") if seg is not None: - new_seg = np.zeros([seg.shape[0], len(self.classes)] + list(seg.shape[2:]), dtype=seg.dtype) - for b in range(seg.shape[0]): - new_seg[b] = convert_seg_image_to_one_hot_encoding(seg[b, self.seg_channel], self.classes) + new_seg = convert_seg_image_to_one_hot_encoding_batched(seg[:, self.seg_channel], self.classes) data_dict[self.output_key] = new_seg else: from warnings import warn @@ -140,9 +158,9 @@ def __call__(self, **data_dict): seg = data_dict.get("seg") if seg is not None: new_seg = np.zeros([seg.shape[0], len(self.classes) * seg.shape[1]] + list(seg.shape[2:]), dtype=seg.dtype) - for b in range(seg.shape[0]): - for c in range(seg.shape[1]): - new_seg[b, c*len(self.classes):(c+1)*len(self.classes)] = convert_seg_image_to_one_hot_encoding(seg[b, c], self.classes) + for c in range(seg.shape[1]): + new_seg[:, c * len(self.classes):(c + 1) * len(self.classes)] = \ + convert_seg_image_to_one_hot_encoding_batched(seg[:, c], self.classes) data_dict["seg"] = new_seg else: from warnings import warn @@ -201,13 +219,14 @@ def __call__(self, **data_dict): if seg is not None: if not seg.shape[1] % self.output_channels == 0: from warnings import warn - warn("Calling ConvertMultiSegToArgmaxTransform but number of input channels {} cannot be divided into {} output channels.".format(seg.shape[1], self.output_channels)) + warn( + f"Calling ConvertMultiSegToArgmaxTransform but number of input channels {seg.shape[1]} cannot be divided into {self.output_channels} output channels.") n_labels = seg.shape[1] // self.output_channels target_size = list(seg.shape) target_size[1] = self.output_channels output = np.zeros(target_size, dtype=seg.dtype) for i in range(self.output_channels): - output[:, i] = np.argmax(seg[:, i*n_labels:(i+1)*n_labels], 1) + output[:, i] = np.argmax(seg[:, i * n_labels:(i + 1) * n_labels], 1) if self.labels is not None: if list(self.labels) != list(range(n_labels)): for index, value in enumerate(reversed(self.labels)): @@ -231,7 +250,8 @@ def __init__(self, dim, get_rois_from_seg_flag=False, class_specific_seg_flag=Fa self.class_specific_seg_flag = class_specific_seg_flag def __call__(self, **data_dict): - data_dict = convert_seg_to_bounding_box_coordinates(data_dict, self.dim, self.get_rois_from_seg_flag, class_specific_seg_flag=self.class_specific_seg_flag) + data_dict = convert_seg_to_bounding_box_coordinates(data_dict, self.dim, self.get_rois_from_seg_flag, + class_specific_seg_flag=self.class_specific_seg_flag) return data_dict @@ -239,6 +259,7 @@ class MoveSegToDataChannel(AbstractTransform): """ concatenates data_dict['seg'] to data_dict['data'] """ + def __call__(self, **data_dict): data_dict['data'] = np.concatenate((data_dict['data'], data_dict['seg']), axis=1) return data_dict @@ -349,7 +370,7 @@ def __call__(self, **data_dict): return new_dict def __repr__(self): - return str(type(self).__name__) + " ( " + repr(self.transforms) + " )" + return f"{str(type(self).__name__)} ( {repr(self.transforms)} )" class ReshapeTransform(AbstractTransform): @@ -416,12 +437,12 @@ def __call__(self, **data_dict): inp = data_dict.get(self.input_key) outp = data_dict.get(self.output_key) - assert inp is not None, "input_key %s is not present in data_dict" % self.input_key + assert inp is not None, f"input_key {self.input_key} is not present in data_dict" 
selected_channels = inp[:, self.channel_indexes] if outp is None: - #warn("output key %s is not present in dict, it will be created" % self.output_key) + # warn("output key %s is not present in dict, it will be created" % self.output_key) outp = selected_channels data_dict[self.output_key] = outp else: @@ -449,13 +470,13 @@ def __call__(self, **data_dict): if data is None: print("WARNING in ConvertToChannelLastTransform: data_dict has no key named", k) else: - if len(data.shape) == 4: + if data.ndim == 4: new_ordering = (0, 2, 3, 1) - elif len(data.shape) == 5: + elif data.ndim == 5: new_ordering = (0, 2, 3, 4, 1) else: raise RuntimeError("unsupported dimensionality for ConvertToChannelLastTransform:", - len(data.shape), + data.ndim, ". Only 2d (b, c, x, y) and 3d (b, c, x, y, z) are supported for now.") assert isinstance(data, np.ndarray), "data_dict[k] must be a numpy array" data = data.transpose(new_ordering) @@ -498,7 +519,7 @@ def __call__(self, **data_dict): # expected to have the same length some_value = data_dict.get(self.relevant_keys[0]) for b in range(len(some_value)): - new_dict = {i: data_dict[i][b:b+1] for i in self.relevant_keys} + new_dict = {i: data_dict[i][b:b + 1] for i in self.relevant_keys} random_transform = np.random.choice(len(self.list_of_transforms), p=self.p) ret = self.list_of_transforms[random_transform](**new_dict) for i in self.relevant_keys: diff --git a/batchgenerators/utilities/custom_types.py b/batchgenerators/utilities/custom_types.py index 2a4b9b1..0098dfe 100644 --- a/batchgenerators/utilities/custom_types.py +++ b/batchgenerators/utilities/custom_types.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, Tuple, Any, Callable -import numpy as np +from typing import Union, Tuple, Callable +import numpy as np ScalarType = Union[int, float, Tuple[float, float], Callable[..., Union[int, float]]] diff --git a/requirements.txt b/requirements.txt index f4e40a2..741b18d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,4 @@ numpy>=1.10.2 scipy scikit-image scikit-learn -unittest2 \ No newline at end of file +pandas \ No newline at end of file diff --git a/setup.py b/setup.py index a62dc17..bb9abc4 100755 --- a/setup.py +++ b/setup.py @@ -10,14 +10,12 @@ license='Apache License Version 2.0, January 2004', packages=find_packages(exclude=["tests"]), install_requires=[ - "pillow>=7.1.2", "numpy>=1.10.2", "scipy", "scikit-image", "scikit-learn", - "future", - "unittest2", - "threadpoolctl" + "threadpoolctl", + "pandas" ], keywords=['data augmentation', 'deep learning', 'image segmentation', 'image classification', 'medical image analysis', 'medical image segmentation'], diff --git a/tests/test_DataLoader.py b/tests/test_DataLoader.py index 1fa59e7..21efebe 100644 --- a/tests/test_DataLoader.py +++ b/tests/test_DataLoader.py @@ -200,6 +200,10 @@ def test_return_incomplete_multi_threaded(self): self.assertTrue(len(np.unique(all_return)) == len(data)) def test_thoroughly(self): + really_test_this = False + if not really_test_this: + print("This test takes too much time. Run me if you really want to test me.") + return data_list = [list(range(123)), list(range(1243)), list(range(1)), diff --git a/tests/test_axis_mirroring.py b/tests/test_axis_mirroring.py index 78839e4..8710de7 100644 --- a/tests/test_axis_mirroring.py +++ b/tests/test_axis_mirroring.py @@ -14,7 +14,6 @@ # limitations under the License. 
import unittest -import unittest2 import numpy as np from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter from skimage import data @@ -23,7 +22,7 @@ from batchgenerators.transforms.spatial_transforms import MirrorTransform -class TestMirrorAxis(unittest2.TestCase): +class TestMirrorAxis(unittest.TestCase): def setUp(self): self.seed = 1234 diff --git a/tests/test_color_augmentations.py b/tests/test_color_augmentations.py index 5ea5802..19000e7 100644 --- a/tests/test_color_augmentations.py +++ b/tests/test_color_augmentations.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from batchgenerators.augmentations.color_augmentations import augment_contrast, augment_brightness_additive,\ +from batchgenerators.augmentations.color_augmentations import augment_contrast, augment_brightness_additive, \ augment_brightness_multiplicative, augment_gamma @@ -26,12 +26,39 @@ def setUp(self): self.data_3D = np.random.random((2, 64, 56, 48)) self.data_2D = np.random.random((2, 64, 56)) self.factor = (0.75, 1.25) + self.data_4D = np.random.random((9, 2, 64, 56, 48)) + self.d_4D = augment_contrast(self.data_4D, contrast_range=self.factor, preserve_range=True, per_channel=True, + batched=True) + self.d_4D = augment_contrast(self.data_4D, contrast_range=self.factor, preserve_range=False, per_channel=False, + batched=True) + self.d_3D = augment_contrast(self.data_3D, contrast_range=self.factor, preserve_range=True, per_channel=True) self.d_3D = augment_contrast(self.data_3D, contrast_range=self.factor, preserve_range=False, per_channel=False) + self.d_2D = augment_contrast(self.data_2D, contrast_range=self.factor, preserve_range=True, per_channel=True) self.d_2D = augment_contrast(self.data_2D, contrast_range=self.factor, preserve_range=False, per_channel=False) - def test_augment_contrast_3D(self): + def test_augment_contrast_4D(self): + data = self.data_4D[0] + mean = np.mean(data) + + idx0 = np.where(data < mean) # where the data is lower than mean value + idx1 = np.where(data > mean) # where the data is greater than mean value + + contrast_lower_limit_0 = self.factor[1] * (data[idx0] - mean) + mean + contrast_lower_limit_1 = self.factor[0] * (data[idx1] - mean) + mean + contrast_upper_limit_0 = self.factor[0] * (data[idx0] - mean) + mean + contrast_upper_limit_1 = self.factor[1] * (data[idx1] - mean) + mean + + # augmented values lower than mean should be lower than lower limit and greater than upper limit + self.assertTrue(np.all(np.logical_and(self.d_4D[0][idx0] >= contrast_lower_limit_0, + self.d_4D[0][idx0] <= contrast_upper_limit_0)), + "Augmented contrast below mean value not within range") + # augmented values greater than mean should be lower than upper limit and greater than lower limit + self.assertTrue(np.all(np.logical_and(self.d_4D[0][idx1] >= contrast_lower_limit_1, + self.d_4D[0][idx1] <= contrast_upper_limit_1)), + "Augmented contrast above mean not within range") + def test_augment_contrast_3D(self): mean = np.mean(self.data_3D) idx0 = np.where(self.data_3D < mean) # where the data is lower than mean value @@ -52,7 +79,6 @@ def test_augment_contrast_3D(self): "Augmented contrast above mean not within range") def test_augment_contrast_2D(self): - mean = np.mean(self.data_2D) idx0 = np.where(self.data_2D < mean) # where the data is lower than mean value @@ -80,7 +106,7 @@ def setUp(self): self.data_input_3D = np.random.random((2, 64, 56, 48)) self.data_input_2D = np.random.random((2, 64, 56)) self.factor = (0.75, 1.25) - self.multiplier_range = 
[2,4] + self.multiplier_range = [2, 4] self.d_3D_per_channel = augment_brightness_additive(np.copy(self.data_input_3D), mu=100, sigma=10, per_channel=True) @@ -103,8 +129,8 @@ def setUp(self): multiplier_range=self.multiplier_range, per_channel=False) def test_augment_brightness_additive_3D(self): - add_factor = self.d_3D-self.data_input_3D - self.assertTrue(len(np.unique(add_factor.round(decimals=8)))==1, + add_factor = self.d_3D - self.data_input_3D + self.assertTrue(len(np.unique(add_factor.round(decimals=8))) == 1, "Added brightness factor is not equal for all channels") add_factor = self.d_3D_per_channel - self.data_input_3D @@ -112,8 +138,8 @@ def test_augment_brightness_additive_3D(self): "Added brightness factor is not different for each channels") def test_augment_brightness_additive_2D(self): - add_factor = self.d_2D-self.data_input_2D - self.assertTrue(len(np.unique(add_factor.round(decimals=8)))==1, + add_factor = self.d_2D - self.data_input_2D + self.assertTrue(len(np.unique(add_factor.round(decimals=8))) == 1, "Added brightness factor is not equal for all channels") add_factor = self.d_2D_per_channel - self.data_input_2D @@ -121,23 +147,41 @@ def test_augment_brightness_additive_2D(self): "Added brightness factor is not different for each channels") def test_augment_brightness_multiplicative_3D(self): - mult_factor = self.d_3D_mult/self.data_input_3D - self.assertTrue(len(np.unique(mult_factor.round(decimals=6)))==1, + mult_factor = self.d_3D_mult / self.data_input_3D + self.assertTrue(len(np.unique(mult_factor.round(decimals=6))) == 1, "Multiplied brightness factor is not equal for all channels") - mult_factor = self.d_3D_per_channel_mult/self.data_input_3D + mult_factor = self.d_3D_per_channel_mult / self.data_input_3D self.assertTrue(len(np.unique(mult_factor.round(decimals=6))) == self.data_input_3D.shape[0], "Multiplied brightness factor is not different for each channels") def test_augment_brightness_multiplicative_2D(self): - mult_factor = self.d_2D_mult/self.data_input_2D - self.assertTrue(len(np.unique(mult_factor.round(decimals=6)))==1, + mult_factor = self.d_2D_mult / self.data_input_2D + self.assertTrue(len(np.unique(mult_factor.round(decimals=6))) == 1, "Multiplied brightness factor is not equal for all channels") - mult_factor = self.d_2D_per_channel_mult/self.data_input_2D + mult_factor = self.d_2D_per_channel_mult / self.data_input_2D self.assertTrue(len(np.unique(mult_factor.round(decimals=6))) == self.data_input_2D.shape[0], "Multiplied brightness factor is not different for each channels") + def test_batched_augment_brightness_multiplicative(self): + data = np.random.random((9, 2, 64, 56, 48)) + result_1 = augment_brightness_multiplicative(np.copy(data), + multiplier_range=self.multiplier_range, + per_channel=False, + batched=True) + result_2 = augment_brightness_multiplicative(np.copy(data), + multiplier_range=self.multiplier_range, + per_channel=True, + batched=True) + + mult_factor = result_1 / data + self.assertEqual(len(np.unique(mult_factor.round(decimals=6))), data.shape[0], + "Multiplied brightness factor per sample is not equal for all channels") + + mult_factor = result_2 / data + self.assertEqual(len(np.unique(mult_factor.round(decimals=6))), data.shape[0] * data.shape[1], + "Multiplied brightness factor per sample is not different for all channels") class TestAugmentGamma(unittest.TestCase): @@ -146,7 +190,10 @@ def setUp(self): self.data_input_3D = np.random.random((2, 64, 56, 48)) self.data_input_2D = np.random.random((2, 64, 56)) - self.d_3D 
= augment_gamma(np.copy(self.data_input_2D), gamma_range=(0.2, 1.2), per_channel=False) + self.d_3D = augment_gamma(np.copy(self.data_input_3D), gamma_range=(0.2, 1.2), per_channel=True, + retain_stats=True) + self.d_3D = augment_gamma(np.copy(self.data_input_3D), gamma_range=(0.2, 1.2), per_channel=False, + retain_stats=False) def test_augment_gamma_3D(self): self.assertTrue(self.d_3D.min().round(decimals=3) == self.data_input_3D.min().round(decimals=3) and diff --git a/tests/test_crop.py b/tests/test_crop.py index 3933a96..d6bbe2e 100644 --- a/tests/test_crop.py +++ b/tests/test_crop.py @@ -268,7 +268,7 @@ def test_pad_nd_image_and_seg_2D(self): print('Zero padding with new_shape.shape smaller than data.shape. [DONE]') print('Zero padding with new_shape.shape bigger than data.shape. [START]') - self.assertRaises(IndexError, pad_nd_image_and_seg, data, seg, new_shape=new_shape6) + self.assertRaises(ValueError, pad_nd_image_and_seg, data, seg, new_shape=new_shape6) print('Zero padding with new_shape.shape bigger than data.shape. [DONE]') print('Padding to bigger output shape in all dimensions with constant_value=1 for segmentation padding . [START]') @@ -352,7 +352,7 @@ def test_pad_nd_image_and_seg_3D(self): print('Zero padding with new_shape.shape smaller than data.shape. [DONE]') print('Zero padding with new_shape.shape bigger than data.shape. [START]') - self.assertRaises(IndexError, pad_nd_image_and_seg, data, seg, new_shape=new_shape6) + self.assertRaises(ValueError, pad_nd_image_and_seg, data, seg, new_shape=new_shape6) print('Zero padding with new_shape.shape bigger than data.shape. [DONE]') print('Padding to bigger output shape in all dimensions with constant_value=1 for segmentation padding . [START]') diff --git a/tests/test_multithreaded_augmenter.py b/tests/test_multithreaded_augmenter.py index 35aaf14..09cb5a5 100644 --- a/tests/test_multithreaded_augmenter.py +++ b/tests/test_multithreaded_augmenter.py @@ -187,7 +187,8 @@ def test_image_pipeline_and_pin_memory(self): res = mt.next() assert isinstance(res['data'], torch.Tensor) - assert res['data'].is_pinned() + if torch.cuda.is_available(): + assert res['data'].is_pinned() # let mt finish caching, otherwise it's going to print an error (which is not a problem and will not prevent # the success of the test but it does not look pretty) diff --git a/tests/test_normalizations.py b/tests/test_normalizations.py index b8d1a1a..ae4896b 100644 --- a/tests/test_normalizations.py +++ b/tests/test_normalizations.py @@ -17,7 +17,7 @@ import numpy as np from batchgenerators.augmentations.normalizations import range_normalization, zero_mean_unit_variance_normalization, \ - cut_off_outliers + cut_off_outliers, mean_std_normalization class TestNormalization(unittest.TestCase): @@ -230,6 +230,33 @@ def test_cut_off_outliers_whole_image(self): print('Test test_cut_off_outliers_whole_image. [START]') + def test_mean_std_normalization_per_channel(self): + print('Test test_mean_std_normalization_per_channel. [START]') + data = np.random.random((32, 4, 64, 56, 48)) + + mean = [np.mean(data[:, i]) for i in range(4)] + std = [np.std(data[:, i]) for i in range(4)] + data_normalized = mean_std_normalization(data, mean, std, per_channel=True) + + for i in range(4): + self.assertAlmostEqual(data_normalized[:, i].mean(), 0.0) + self.assertAlmostEqual(data_normalized[:, i].std(), 1.0) + + print('Test test_mean_std_normalization_per_channel. 
[DONE]') + + def test_mean_std_normalization_whole_image(self): + print('Test test_mean_std_normalization_whole_image. [START]') + data = np.random.random((32, 4, 64, 56, 48)) + + mean = np.mean(data) + std = np.std(data) + data_normalized = mean_std_normalization(data, mean, std, per_channel=False) + + self.assertAlmostEqual(data_normalized.mean(), 0.0) + self.assertAlmostEqual(data_normalized.std(), 1.0) + + print('Test test_mean_std_normalization_whole_image. [DONE]') + if __name__ == '__main__': unittest.main()
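
The two tests above exercise mean_std_normalization, whose implementation is not part of this diff. A minimal
implementation consistent with these tests could look like the following sketch (an assumption for illustration,
not the library's actual code):

    import numpy as np

    def mean_std_normalization(data, mean, std, per_channel=True):
        # data: (b, c, x, y(, z)); mean/std are scalars, or one value per
        # channel when per_channel=True; the per-channel path works in place
        if per_channel:
            for c in range(data.shape[1]):
                data[:, c] = (data[:, c] - mean[c]) / std[c]
        else:
            data = (data - mean) / std
        return data

    data = np.random.random((32, 4, 16, 16))
    mean = [np.mean(data[:, i]) for i in range(4)]
    std = [np.std(data[:, i]) for i in range(4)]
    normalized = mean_std_normalization(data, mean, std, per_channel=True)
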