Skip to content
This repository was archived by the owner on Jul 10, 2025. It is now read-only.

Commit 3db139d

Browse files
authored
add apply_to_images (#2560)
* add apply_to_images
* test for generate_noise
* Refactor generate_noise tests to avoid conditionals using pytest.mark.parametrize
* Fix processing for batches
* optimized add_noise for batch
* removed unused tests
* removed unused tests
1 parent 76fef70 commit 3db139d

File tree

4 files changed

+50
-13
lines changed

4 files changed

+50
-13
lines changed

albumentations/augmentations/pixel/functional.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2301,7 +2301,10 @@ def add_noise(img: np.ndarray, noise: np.ndarray) -> np.ndarray:
23012301
np.ndarray: The noise added to the image.
23022302
23032303
"""
2304-
return add(img, noise, inplace=False)
2304+
n_tiles = np.prod(img.shape) // np.prod(noise.shape)
2305+
noise = np.tile(noise, (n_tiles,) + (1,) * noise.ndim).reshape(img.shape)
2306+
2307+
return add_array(img, noise, inplace=False)
23052308

23062309

23072310
def slic(

albumentations/augmentations/pixel/transforms.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
MAX_VALUES_BY_DTYPE,
2222
NUM_MULTI_CHANNEL_DIMENSIONS,
2323
batch_transform,
24+
get_image_data,
2425
get_num_channels,
2526
is_grayscale_image,
2627
is_rgb_image,
@@ -2631,6 +2632,39 @@ def apply(
26312632
"""
26322633
return fpixel.add_noise(img, noise_map)
26332634

2635+
def apply_to_images(self, images: np.ndarray, noise_map: np.ndarray, **params: Any) -> np.ndarray:
2636+
"""Apply the Gaussian noise to a batch of images.
2637+
2638+
Args:
2639+
images (np.ndarray): The batch of images to apply the Gaussian noise to.
2640+
noise_map (np.ndarray): The noise map to apply to the images.
2641+
**params (Any): Additional parameters (not used in this transform).
2642+
2643+
"""
2644+
return fpixel.add_noise(images, noise_map)
2645+
2646+
def apply_to_volume(self, volume: np.ndarray, noise_map: np.ndarray, **params: Any) -> np.ndarray:
2647+
"""Apply the Gaussian noise to a single volume.
2648+
2649+
Args:
2650+
volume (np.ndarray): The volume to apply the Gaussian noise to.
2651+
noise_map (np.ndarray): The noise map to apply to the volume.
2652+
**params (Any): Additional parameters (not used in this transform).
2653+
2654+
"""
2655+
return fpixel.add_noise(volume, noise_map)
2656+
2657+
def apply_to_volumes(self, volumes: np.ndarray, noise_map: np.ndarray, **params: Any) -> np.ndarray:
2658+
"""Apply the Gaussian noise to a batch of volumes.
2659+
2660+
Args:
2661+
volumes (np.ndarray): The batch of volumes to apply the Gaussian noise to.
2662+
noise_map (np.ndarray): The noise map to apply to the volumes.
2663+
**params (Any): Additional parameters (not used in this transform).
2664+
2665+
"""
2666+
return fpixel.add_noise(volumes, noise_map)
2667+
26342668
def get_params_dependent_on_data(
26352669
self,
26362670
params: dict[str, Any],
@@ -2647,17 +2681,17 @@ def get_params_dependent_on_data(
26472681
- "noise_map" (np.ndarray): The noise map to apply to the image.
26482682
26492683
"""
2650-
image = data["image"] if "image" in data else data["images"][0]
2651-
max_value = MAX_VALUES_BY_DTYPE[image.dtype]
2684+
metadata = get_image_data(data)
2685+
max_value = MAX_VALUES_BY_DTYPE[metadata["dtype"]]
2686+
shape = (metadata["height"], metadata["width"], metadata["num_channels"])
26522687

26532688
sigma = self.py_random.uniform(*self.std_range)
2654-
26552689
mean = self.py_random.uniform(*self.mean_range)
26562690

26572691
noise_map = fpixel.generate_noise(
26582692
noise_type="gaussian",
26592693
spatial_mode="per_pixel" if self.per_channel else "shared",
2660-
shape=image.shape,
2694+
shape=shape,
26612695
params={"mean_range": (mean, mean), "std_range": (sigma, sigma)},
26622696
max_value=max_value,
26632697
approximation=self.noise_scale_factor,
@@ -6036,14 +6070,14 @@ def get_params_dependent_on_data(
60366070
data (dict[str, Any]): The data to apply the transform to.
60376071
60386072
"""
6039-
image = data["image"] if "image" in data else data["images"][0]
6040-
6041-
max_value = MAX_VALUES_BY_DTYPE[image.dtype]
6073+
metadata = get_image_data(data)
6074+
max_value = MAX_VALUES_BY_DTYPE[metadata["dtype"]]
6075+
shape = (metadata["height"], metadata["width"], metadata["num_channels"])
60426076

60436077
noise_map = fpixel.generate_noise(
60446078
noise_type=self.noise_type,
60456079
spatial_mode=self.spatial_mode,
6046-
shape=image.shape,
6080+
shape=shape,
60476081
params=self.noise_params,
60486082
max_value=max_value,
60496083
approximation=self.approximation,
@@ -6180,7 +6214,7 @@ class SaltAndPepper(ImageOnlyTransform):
61806214
"""Apply salt and pepper noise to the input image.
61816215
61826216
Salt and pepper noise is a form of impulse noise that randomly sets pixels to either maximum value (salt)
6183-
or minimum value (pepper). The amount and proportion of salt vs pepper noise can be controlled.
6217+
or minimum value (pepper). The amount and proportion of salt vs pepper can be controlled.
61846218
The same noise mask is applied to all channels of the image to preserve color consistency.
61856219
61866220
Args:
@@ -6283,8 +6317,7 @@ def get_params_dependent_on_data(
62836317
data (dict[str, Any]): The data to apply the transform to.
62846318
62856319
"""
6286-
image = data["image"] if "image" in data else data["images"][0]
6287-
height, width = image.shape[:2]
6320+
height, width = params["shape"][:2]
62886321

62896322
total_amount = self.py_random.uniform(*self.amount)
62906323
salt_ratio = self.py_random.uniform(*self.salt_vs_pepper)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"PyYAML",
99
"typing-extensions>=4.9.0; python_version<'3.10'",
1010
"pydantic>=2.9.2",
11-
"albucore==0.0.26",
11+
"albucore==0.0.28",
1212
"eval-type-backport; python_version<'3.10'",
1313
]
1414

tests/functional/test_functional.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from tests.utils import convert_2d_to_target_format
2424
from copy import deepcopy
2525
from sklearn.decomposition import NMF
26+
from typing import Any
2627

2728

2829
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)