diff --git a/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py b/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py index 17a4368..f91e5cf 100644 --- a/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +++ b/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py @@ -3,7 +3,7 @@ import numpy as np import torch -from fft_conv_pytorch import fft_conv +from torch_fftconv import fft_conv2d, fft_conv3d from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform @@ -30,11 +30,12 @@ def binary_dilation_torch(input_tensor, structure_element): raise ValueError("Input tensor must be 2D (X, Y) or 3D (X, Y, Z).") # Perform the convolution - # if num_dims == 2: # 2D convolution - # output = F.conv2d(input_tensor.unsqueeze(0).unsqueeze(0), structure_element, padding='same') - # elif num_dims == 3: # 3D convolution - # output = F.conv3d(input_tensor.unsqueeze(0).unsqueeze(0), structure_element, padding='same') - output = torch.round(fft_conv(input_tensor.unsqueeze(0).unsqueeze(0), structure_element, padding='same'), decimals=0) + input_tensor = input_tensor.unsqueeze(0).unsqueeze(0) + if num_dims == 2: # 2D convolution + output = fft_conv2d(input_tensor, structure_element, padding='same') + elif num_dims == 3: # 3D convolution + output = fft_conv3d(input_tensor, structure_element, padding='same') + output = torch.round(output, decimals=0) # Threshold to get binary output output = output > 0 diff --git a/batchgeneratorsv2/transforms/noise/gaussian_blur.py b/batchgeneratorsv2/transforms/noise/gaussian_blur.py index 1a707f4..1cfea09 100644 --- a/batchgeneratorsv2/transforms/noise/gaussian_blur.py +++ b/batchgeneratorsv2/transforms/noise/gaussian_blur.py @@ -8,7 +8,7 @@ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform -from fft_conv_pytorch import fft_conv +from torch_fftconv import fft_conv1d, fft_conv2d, fft_conv3d def blur_dimension(img: torch.Tensor, sigma: float, dim_to_blur: int, force_use_fft: bool = None, truncate: float = 6): @@ -31,8 +31,9 @@ def blur_dimension(img: torch.Tensor, sigma: float, dim_to_blur: int, force_use_ # Dynamically set up padding, convolution operation, and kernel shape based on the number of spatial dimensions conv_ops = {1: conv1d, 2: conv2d, 3: conv3d} + fft_conv_ops = {1: fft_conv1d, 2: fft_conv2d, 3: fft_conv3d} if force_use_fft is not None: - conv_op = conv_ops[spatial_dims] if not force_use_fft else fft_conv + conv_op = conv_ops[spatial_dims] if not force_use_fft else fft_conv_ops[spatial_dims] else: conv_op = conv_ops[spatial_dims] diff --git a/pyproject.toml b/pyproject.toml index 00b578e..612021f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ keywords = [ dependencies = [ "torch>=2.0.0", "numpy", - "fft-conv-pytorch", + "torch-fftconv", "batchgenerators>=0.25" ]