diff --git a/batchgenerators/augmentations/utils.py b/batchgenerators/augmentations/utils.py index f8aac77..3212b0b 100755 --- a/batchgenerators/augmentations/utils.py +++ b/batchgenerators/augmentations/utils.py @@ -591,11 +591,12 @@ def resize_segmentation(segmentation, new_shape, order=3): :return: ''' tpe = segmentation.dtype - unique_labels = np.unique(segmentation) + assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation" if order == 0: return resize(segmentation.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False).astype(tpe) else: + unique_labels = np.unique(segmentation) reshaped = np.zeros(new_shape, dtype=segmentation.dtype) for i, c in enumerate(unique_labels):