Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions batchgeneratorsv2/transforms/nnunet/random_binary_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
import torch
from fft_conv_pytorch import fft_conv
from torch_fftconv import fft_conv2d, fft_conv3d

from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
Expand All @@ -30,11 +30,12 @@ def binary_dilation_torch(input_tensor, structure_element):
raise ValueError("Input tensor must be 2D (X, Y) or 3D (X, Y, Z).")

# Perform the convolution
# if num_dims == 2: # 2D convolution
# output = F.conv2d(input_tensor.unsqueeze(0).unsqueeze(0), structure_element, padding='same')
# elif num_dims == 3: # 3D convolution
# output = F.conv3d(input_tensor.unsqueeze(0).unsqueeze(0), structure_element, padding='same')
output = torch.round(fft_conv(input_tensor.unsqueeze(0).unsqueeze(0), structure_element, padding='same'), decimals=0)
input_tensor = input_tensor.unsqueeze(0).unsqueeze(0)
if num_dims == 2: # 2D convolution
output = fft_conv2d(input_tensor, structure_element, padding='same')
elif num_dims == 3: # 3D convolution
output = fft_conv3d(input_tensor, structure_element, padding='same')
output = torch.round(output, decimals=0)

# Threshold to get binary output
output = output > 0
Expand Down
5 changes: 3 additions & 2 deletions batchgeneratorsv2/transforms/noise/gaussian_blur.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
from fft_conv_pytorch import fft_conv
from torch_fftconv import fft_conv1d, fft_conv2d, fft_conv3d


def blur_dimension(img: torch.Tensor, sigma: float, dim_to_blur: int, force_use_fft: bool = None, truncate: float = 6):
Expand All @@ -31,8 +31,9 @@ def blur_dimension(img: torch.Tensor, sigma: float, dim_to_blur: int, force_use_

# Dynamically set up padding, convolution operation, and kernel shape based on the number of spatial dimensions
conv_ops = {1: conv1d, 2: conv2d, 3: conv3d}
fft_conv_ops = {1: fft_conv1d, 2: fft_conv2d, 3: fft_conv3d}
if force_use_fft is not None:
conv_op = conv_ops[spatial_dims] if not force_use_fft else fft_conv
conv_op = conv_ops[spatial_dims] if not force_use_fft else fft_conv_ops[spatial_dims]
else:
conv_op = conv_ops[spatial_dims]

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ keywords = [
dependencies = [
"torch>=2.0.0",
"numpy",
"fft-conv-pytorch",
"torch-fftconv",
"batchgenerators>=0.25"
]

Expand Down