Skip to content

Commit 17d2efb

Browse files
authored
Fix gaussian noise augmentation and add random gaussian blur (#4508)
* reimplement Gaussian noise * add RandomGaussianBlur aug * minor fix| * fix unit tests * reply comments
1 parent 46a9395 commit 17d2efb

38 files changed

+194
-96
lines changed

src/otx/data/transform_libs/torchvision.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from scipy.stats import truncnorm
2727
from torchvision import tv_tensors
2828
from torchvision._utils import sequence_to_str
29+
from torchvision.transforms.v2 import GaussianBlur, GaussianNoise
2930
from torchvision.transforms.v2 import functional as F # noqa: N812
3031

3132
from otx.data.entity.base import (
@@ -903,6 +904,58 @@ def __repr__(self) -> str:
903904
return repr_str
904905

905906

907+
class RandomGaussianBlur(GaussianBlur):
908+
"""Modified version of the torchvision GaussianBlur."""
909+
910+
def __init__(
911+
self,
912+
kernel_size: int | Sequence[int],
913+
sigma: int | tuple[float, float] = (0.1, 2.0),
914+
prob: float = 0.5,
915+
) -> None:
916+
super().__init__(kernel_size=kernel_size, sigma=sigma)
917+
self.prob = prob
918+
919+
def transform(self, inpt: torch.Tensor, params: dict[str, Any]) -> torch.Tensor:
920+
"""Main transform function."""
921+
if self.prob >= np.random.rand():
922+
return super().transform(inpt, params)
923+
return inpt
924+
925+
926+
class RandomGaussianNoise(GaussianNoise):
927+
"""Modified version of the torchvision GaussianNoise.
928+
929+
This augmentation allows to add gaussian noise to unscaled image.
930+
Only float32 images are supported for this augmentation.
931+
"""
932+
933+
def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip: bool = True, prob: float = 0.5) -> None:
934+
super().__init__(mean=mean, sigma=sigma, clip=clip)
935+
self.prob = prob
936+
937+
def _is_scaled(self, tensor: torch.Tensor) -> bool:
938+
return torch.max(tensor) <= 1 + 1e-5
939+
940+
def forward(self, *_inputs: OTXDataItem) -> OTXDataItem:
941+
"""Main transform function."""
942+
assert len(_inputs) == 1, "[tmp] Multiple entity is not supported yet." # noqa: S101
943+
inputs = _inputs[0]
944+
if (img := getattr(inputs, "image", None)) is not None and self.prob >= np.random.rand():
945+
scaled = self._is_scaled(img)
946+
sigma = self.sigma * 255 if not scaled else self.sigma
947+
mean = self.mean * 255 if not scaled else self.mean
948+
clip = False if not scaled else self.clip
949+
950+
img = self._call_kernel(F.gaussian_noise, img, mean=mean, sigma=sigma, clip=clip)
951+
if not scaled:
952+
img = torch.clamp(img, 0, 255)
953+
954+
inputs.image = img
955+
956+
return inputs
957+
958+
906959
class PhotoMetricDistortion(tvt_v2.Transform, NumpytoTVTensorMixin):
907960
"""Implementation of mmdet.datasets.transforms.PhotoMetricDistortion with torchvision format.
908961

src/otx/recipe/_base_/data/classification.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,21 @@ train_subset:
2626
is_numpy_to_tvtensor: true
2727
- class_path: torchvision.transforms.v2.RandomVerticalFlip
2828
enable: false
29-
- class_path: torchvision.transforms.v2.GaussianBlur
29+
- class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur
3030
enable: false
3131
init_args:
3232
kernel_size: 5
3333
- class_path: torchvision.transforms.v2.ToDtype
3434
init_args:
3535
dtype: ${as_torch_dtype:torch.float32}
3636
scale: false
37+
- class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
38+
enable: false
3739
- class_path: torchvision.transforms.v2.Normalize
3840
init_args:
3941
mean: [123.675, 116.28, 103.53]
4042
std: [58.395, 57.12, 57.375]
41-
- class_path: torchvision.transforms.v2.GaussianNoise
42-
enable: false
43+
4344
sampler:
4445
class_path: otx.data.samplers.balanced_sampler.BalancedSampler
4546

src/otx/recipe/_base_/data/detection.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,20 @@ train_subset:
2727
is_numpy_to_tvtensor: true
2828
- class_path: torchvision.transforms.v2.RandomVerticalFlip
2929
enable: false
30-
- class_path: torchvision.transforms.v2.GaussianBlur
30+
- class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur
3131
enable: false
3232
init_args:
3333
kernel_size: 5
3434
- class_path: torchvision.transforms.v2.ToDtype
3535
init_args:
3636
dtype: ${as_torch_dtype:torch.float32}
37+
- class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
38+
enable: false
3739
- class_path: torchvision.transforms.v2.Normalize
3840
init_args:
3941
mean: [0.0, 0.0, 0.0]
4042
std: [255.0, 255.0, 255.0]
41-
- class_path: torchvision.transforms.v2.GaussianNoise
42-
enable: false
43+
4344
sampler:
4445
class_path: torch.utils.data.RandomSampler
4546

src/otx/recipe/_base_/data/detection_tile.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,20 @@ train_subset:
3030
is_numpy_to_tvtensor: true
3131
- class_path: torchvision.transforms.v2.RandomVerticalFlip
3232
enable: false
33-
- class_path: torchvision.transforms.v2.GaussianBlur
33+
- class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur
3434
enable: false
3535
init_args:
3636
kernel_size: 5
3737
- class_path: torchvision.transforms.v2.ToDtype
3838
init_args:
3939
dtype: ${as_torch_dtype:torch.float32}
40+
- class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
41+
enable: false
4042
- class_path: torchvision.transforms.v2.Normalize
4143
init_args:
4244
mean: [0.0, 0.0, 0.0]
4345
std: [255.0, 255.0, 255.0]
44-
- class_path: torchvision.transforms.v2.GaussianNoise
45-
enable: false
46+
4647
sampler:
4748
class_path: torch.utils.data.RandomSampler
4849

src/otx/recipe/_base_/data/instance_segmentation.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,20 @@ train_subset:
3333
is_numpy_to_tvtensor: true
3434
- class_path: torchvision.transforms.v2.RandomVerticalFlip
3535
enable: false
36-
- class_path: torchvision.transforms.v2.GaussianBlur
36+
- class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur
3737
enable: false
3838
init_args:
3939
kernel_size: 5
4040
- class_path: torchvision.transforms.v2.ToDtype
4141
init_args:
4242
dtype: ${as_torch_dtype:torch.float32}
43+
- class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
44+
enable: false
4345
- class_path: torchvision.transforms.v2.Normalize
4446
init_args:
4547
mean: [123.675, 116.28, 103.53]
4648
std: [58.395, 57.12, 57.375]
47-
- class_path: torchvision.transforms.v2.GaussianNoise
48-
enable: false
49+
4950
sampler:
5051
class_path: torch.utils.data.RandomSampler
5152

src/otx/recipe/_base_/data/keypoint_detection.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ train_subset:
1818
- class_path: torchvision.transforms.v2.ToDtype
1919
init_args:
2020
dtype: ${as_torch_dtype:torch.float32}
21+
- class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
22+
enable: false
2123
- class_path: torchvision.transforms.v2.Normalize
2224
init_args:
2325
mean: [123.675, 116.28, 103.53]

src/otx/recipe/_base_/data/semantic_segmentation.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ train_subset:
3232
- class_path: torchvision.transforms.v2.ToDtype
3333
init_args:
3434
dtype: ${as_torch_dtype:torch.float32}
35+
- class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
36+
enable: false
3537
- class_path: torchvision.transforms.v2.Normalize
3638
init_args:
3739
mean: [123.675, 116.28, 103.53]

src/otx/recipe/_base_/data/semantic_segmentation_tile.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ train_subset:
3434
- class_path: torchvision.transforms.v2.ToDtype
3535
init_args:
3636
dtype: ${as_torch_dtype:torch.float32}
37+
- class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
38+
enable: false
3739
- class_path: torchvision.transforms.v2.Normalize
3840
init_args:
3941
mean: [123.675, 116.28, 103.53]

src/otx/recipe/classification/h_label_cls/efficientnet_b0.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,16 @@ overrides:
6868
is_numpy_to_tvtensor: true
6969
- class_path: torchvision.transforms.v2.RandomVerticalFlip
7070
enable: false
71-
- class_path: torchvision.transforms.v2.GaussianBlur
71+
- class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur
7272
enable: false
7373
init_args:
7474
kernel_size: 5
75-
- class_path: torchvision.transforms.v2.GaussianNoise
76-
enable: false
7775
- class_path: torchvision.transforms.v2.ToDtype
7876
init_args:
7977
dtype: ${as_torch_dtype:torch.float32}
8078
scale: false
79+
- class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
80+
enable: false
8181
- class_path: torchvision.transforms.v2.Normalize
8282
init_args:
8383
mean: [123.675, 116.28, 103.53]

src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,16 @@ overrides:
7272
is_numpy_to_tvtensor: true
7373
- class_path: torchvision.transforms.v2.RandomVerticalFlip
7474
enable: false
75-
- class_path: torchvision.transforms.v2.GaussianBlur
75+
- class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur
7676
enable: false
7777
init_args:
7878
kernel_size: 5
79-
- class_path: torchvision.transforms.v2.GaussianNoise
80-
enable: false
8179
- class_path: torchvision.transforms.v2.ToDtype
8280
init_args:
8381
dtype: ${as_torch_dtype:torch.float32}
8482
scale: false
83+
- class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
84+
enable: false
8585
- class_path: torchvision.transforms.v2.Normalize
8686
init_args:
8787
mean: [123.675, 116.28, 103.53]

0 commit comments

Comments
 (0)