2121 MAX_VALUES_BY_DTYPE ,
2222 NUM_MULTI_CHANNEL_DIMENSIONS ,
2323 batch_transform ,
24+ get_image_data ,
2425 get_num_channels ,
2526 is_grayscale_image ,
2627 is_rgb_image ,
@@ -2631,6 +2632,39 @@ def apply(
26312632 """
26322633 return fpixel .add_noise (img , noise_map )
26332634
2635+ def apply_to_images (self , images : np .ndarray , noise_map : np .ndarray , ** params : Any ) -> np .ndarray :
2636+ """Apply the Gaussian noise to a batch of images.
2637+
2638+ Args:
2639+ images (np.ndarray): The batch of images to apply the Gaussian noise to.
2640+ noise_map (np.ndarray): The noise map to apply to the images.
2641+ **params (Any): Additional parameters (not used in this transform).
2642+
2643+ """
2644+ return fpixel .add_noise (images , noise_map )
2645+
2646+ def apply_to_volume (self , volume : np .ndarray , noise_map : np .ndarray , ** params : Any ) -> np .ndarray :
2647+ """Apply the Gaussian noise to a single volume.
2648+
2649+ Args:
2650+ volume (np.ndarray): The volume to apply the Gaussian noise to.
2651+ noise_map (np.ndarray): The noise map to apply to the volume.
2652+ **params (Any): Additional parameters (not used in this transform).
2653+
2654+ """
2655+ return fpixel .add_noise (volume , noise_map )
2656+
2657+ def apply_to_volumes (self , volumes : np .ndarray , noise_map : np .ndarray , ** params : Any ) -> np .ndarray :
2658+ """Apply the Gaussian noise to a batch of volumes.
2659+
2660+ Args:
2661+ volumes (np.ndarray): The batch of volumes to apply the Gaussian noise to.
2662+ noise_map (np.ndarray): The noise map to apply to the volumes.
2663+ **params (Any): Additional parameters (not used in this transform).
2664+
2665+ """
2666+ return fpixel .add_noise (volumes , noise_map )
2667+
26342668 def get_params_dependent_on_data (
26352669 self ,
26362670 params : dict [str , Any ],
@@ -2647,17 +2681,17 @@ def get_params_dependent_on_data(
26472681 - "noise_map" (np.ndarray): The noise map to apply to the image.
26482682
26492683 """
2650- image = data ["image" ] if "image" in data else data ["images" ][0 ]
2651- max_value = MAX_VALUES_BY_DTYPE [image .dtype ]
2684+ metadata = get_image_data (data )
2685+ max_value = MAX_VALUES_BY_DTYPE [metadata ["dtype" ]]
2686+ shape = (metadata ["height" ], metadata ["width" ], metadata ["num_channels" ])
26522687
26532688 sigma = self .py_random .uniform (* self .std_range )
2654-
26552689 mean = self .py_random .uniform (* self .mean_range )
26562690
26572691 noise_map = fpixel .generate_noise (
26582692 noise_type = "gaussian" ,
26592693 spatial_mode = "per_pixel" if self .per_channel else "shared" ,
2660- shape = image . shape ,
2694+ shape = shape ,
26612695 params = {"mean_range" : (mean , mean ), "std_range" : (sigma , sigma )},
26622696 max_value = max_value ,
26632697 approximation = self .noise_scale_factor ,
@@ -6036,14 +6070,14 @@ def get_params_dependent_on_data(
60366070 data (dict[str, Any]): The data to apply the transform to.
60376071
60386072 """
6039- image = data [ "image" ] if "image" in data else data [ "images" ][ 0 ]
6040-
6041- max_value = MAX_VALUES_BY_DTYPE [ image . dtype ]
6073+ metadata = get_image_data ( data )
6074+ max_value = MAX_VALUES_BY_DTYPE [ metadata [ "dtype" ]]
6075+ shape = ( metadata [ "height" ], metadata [ "width" ], metadata [ "num_channels" ])
60426076
60436077 noise_map = fpixel .generate_noise (
60446078 noise_type = self .noise_type ,
60456079 spatial_mode = self .spatial_mode ,
6046- shape = image . shape ,
6080+ shape = shape ,
60476081 params = self .noise_params ,
60486082 max_value = max_value ,
60496083 approximation = self .approximation ,
@@ -6180,7 +6214,7 @@ class SaltAndPepper(ImageOnlyTransform):
61806214 """Apply salt and pepper noise to the input image.
61816215
61826216 Salt and pepper noise is a form of impulse noise that randomly sets pixels to either maximum value (salt)
6183- or minimum value (pepper). The amount and proportion of salt vs pepper noise can be controlled.
6217+ or minimum value (pepper). The amount and proportion of salt vs pepper can be controlled.
61846218 The same noise mask is applied to all channels of the image to preserve color consistency.
61856219
61866220 Args:
@@ -6283,8 +6317,7 @@ def get_params_dependent_on_data(
62836317 data (dict[str, Any]): The data to apply the transform to.
62846318
62856319 """
6286- image = data ["image" ] if "image" in data else data ["images" ][0 ]
6287- height , width = image .shape [:2 ]
6320+ height , width = params ["shape" ][:2 ]
62886321
62896322 total_amount = self .py_random .uniform (* self .amount )
62906323 salt_ratio = self .py_random .uniform (* self .salt_vs_pepper )
0 commit comments