|
26 | 26 | from scipy.stats import truncnorm |
27 | 27 | from torchvision import tv_tensors |
28 | 28 | from torchvision._utils import sequence_to_str |
| 29 | +from torchvision.transforms.v2 import GaussianBlur, GaussianNoise |
29 | 30 | from torchvision.transforms.v2 import functional as F # noqa: N812 |
30 | 31 |
|
31 | 32 | from otx.data.entity.base import ( |
@@ -903,6 +904,58 @@ def __repr__(self) -> str: |
903 | 904 | return repr_str |
904 | 905 |
|
905 | 906 |
|
| 907 | +class RandomGaussianBlur(GaussianBlur): |
| 908 | + """Modified version of the torchvision GaussianBlur.""" |
| 909 | + |
| 910 | + def __init__( |
| 911 | + self, |
| 912 | + kernel_size: int | Sequence[int], |
| 913 | + sigma: int | tuple[float, float] = (0.1, 2.0), |
| 914 | + prob: float = 0.5, |
| 915 | + ) -> None: |
| 916 | + super().__init__(kernel_size=kernel_size, sigma=sigma) |
| 917 | + self.prob = prob |
| 918 | + |
| 919 | + def transform(self, inpt: torch.Tensor, params: dict[str, Any]) -> torch.Tensor: |
| 920 | + """Main transform function.""" |
| 921 | + if self.prob >= np.random.rand(): |
| 922 | + return super().transform(inpt, params) |
| 923 | + return inpt |
| 924 | + |
| 925 | + |
| 926 | +class RandomGaussianNoise(GaussianNoise): |
| 927 | + """Modified version of the torchvision GaussianNoise. |
| 928 | +
|
| 929 | + This augmentation allows to add gaussian noise to unscaled image. |
| 930 | + Only float32 images are supported for this augmentation. |
| 931 | + """ |
| 932 | + |
| 933 | + def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip: bool = True, prob: float = 0.5) -> None: |
| 934 | + super().__init__(mean=mean, sigma=sigma, clip=clip) |
| 935 | + self.prob = prob |
| 936 | + |
| 937 | + def _is_scaled(self, tensor: torch.Tensor) -> bool: |
| 938 | + return torch.max(tensor) <= 1 + 1e-5 |
| 939 | + |
| 940 | + def forward(self, *_inputs: OTXDataItem) -> OTXDataItem: |
| 941 | + """Main transform function.""" |
| 942 | + assert len(_inputs) == 1, "[tmp] Multiple entity is not supported yet." # noqa: S101 |
| 943 | + inputs = _inputs[0] |
| 944 | + if (img := getattr(inputs, "image", None)) is not None and self.prob >= np.random.rand(): |
| 945 | + scaled = self._is_scaled(img) |
| 946 | + sigma = self.sigma * 255 if not scaled else self.sigma |
| 947 | + mean = self.mean * 255 if not scaled else self.mean |
| 948 | + clip = False if not scaled else self.clip |
| 949 | + |
| 950 | + img = self._call_kernel(F.gaussian_noise, img, mean=mean, sigma=sigma, clip=clip) |
| 951 | + if not scaled: |
| 952 | + img = torch.clamp(img, 0, 255) |
| 953 | + |
| 954 | + inputs.image = img |
| 955 | + |
| 956 | + return inputs |
| 957 | + |
| 958 | + |
906 | 959 | class PhotoMetricDistortion(tvt_v2.Transform, NumpytoTVTensorMixin): |
907 | 960 | """Implementation of mmdet.datasets.transforms.PhotoMetricDistortion with torchvision format. |
908 | 961 |
|
|
0 commit comments