Skip to content
This repository was archived by the owner on Jul 10, 2025. It is now read-only.

Commit 6292ce0

Browse files
authored
Speedup in fpixel.grayscale_to_multichannel (#2564)
1 parent 1cc07b7 commit 6292ce0

File tree

5 files changed

+48
-8
lines changed

5 files changed

+48
-8
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ repos:
7070
language: system
7171
files: setup.py
7272
- repo: https://github.com/astral-sh/ruff-pre-commit
73-
rev: v0.11.13
73+
rev: v0.12.0
7474
hooks:
7575
- id: ruff
7676
exclude: '__pycache__/'
@@ -98,7 +98,7 @@ repos:
9898
hooks:
9999
- id: pyproject-fmt
100100
- repo: https://github.com/pre-commit/mirrors-mypy
101-
rev: v1.16.0
101+
rev: v1.16.1
102102
hooks:
103103
- id: mypy
104104
files: ^albumentations/

albumentations/augmentations/geometric/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1465,7 +1465,7 @@ def get_params_dependent_on_data(
14651465
for tile in tiles
14661466
],
14671467
).reshape(
1468-
self.num_grid_xy[::-1] + (4,),
1468+
(*self.num_grid_xy[::-1], 4),
14691469
) # Reshape to (grid_height, grid_width, 4)
14701470

14711471
polygons = fgeometric.generate_distorted_grid_polygons(

albumentations/augmentations/pixel/functional.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,8 +1647,8 @@ def grayscale_to_multichannel(
16471647
return grayscale_image
16481648

16491649
squeezed = np.squeeze(grayscale_image)
1650-
# For multi-channel output, stack channels
1651-
return np.stack([squeezed] * num_output_channels, axis=-1)
1650+
# For multi-channel output, use tile for better performance
1651+
return np.tile(squeezed[..., np.newaxis], (1,) * squeezed.ndim + (num_output_channels,))
16521652

16531653

16541654
@preserve_channel_dim
@@ -2519,7 +2519,7 @@ def generate_noise(
25192519
height, width = shape[:2]
25202520
reduced_height = max(1, int(height * approximation))
25212521
reduced_width = max(1, int(width * approximation))
2522-
reduced_shape = (reduced_height, reduced_width) + shape[2:]
2522+
reduced_shape = (reduced_height, reduced_width, *shape[2:])
25232523

25242524
# Generate noise at reduced resolution
25252525
if spatial_mode == "shared":
@@ -3482,7 +3482,7 @@ def prepare_drop_values(
34823482
return np.full(array.shape, values[0], dtype=array.dtype)
34833483

34843484
# For multichannel input, broadcast values to full shape
3485-
return np.full(array.shape[:2] + (len(values),), values, dtype=array.dtype)
3485+
return np.full((*array.shape[:2], len(values)), values, dtype=array.dtype)
34863486

34873487

34883488
def get_mask_array(data: dict[str, Any]) -> np.ndarray | None:

albumentations/augmentations/pixel/transforms.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3678,14 +3678,53 @@ def apply(self, img: np.ndarray, **params: Any) -> np.ndarray:
36783678
warnings.warn("The image is already an RGB.", stacklevel=2)
36793679
return np.ascontiguousarray(img)
36803680
if not is_grayscale_image(img):
3681-
msg = "ToRGB transformation expects 2-dim images or 3-dim with the last dimension equal to 1."
3681+
msg = "ToRGB transformation expects images with the number of channels equal to 1."
36823682
raise TypeError(msg)
36833683

36843684
return fpixel.grayscale_to_multichannel(
36853685
img,
36863686
num_output_channels=self.num_output_channels,
36873687
)
36883688

3689+
def apply_to_images(self, images: np.ndarray, **params: Any) -> np.ndarray:
3690+
"""Apply ToRGB to a batch of images.
3691+
3692+
Args:
3693+
images (np.ndarray): Batch of images with shape (N, H, W, C) or (N, H, W).
3694+
**params (Any): Additional parameters.
3695+
3696+
Returns:
3697+
np.ndarray: Batch of RGB images.
3698+
3699+
"""
3700+
return self.apply(images, **params)
3701+
3702+
def apply_to_volume(self, volume: np.ndarray, **params: Any) -> np.ndarray:
3703+
"""Apply ToRGB to a single volume.
3704+
3705+
Args:
3706+
volume (np.ndarray): Volume with shape (D, H, W, C) or (D, H, W).
3707+
**params (Any): Additional parameters.
3708+
3709+
Returns:
3710+
np.ndarray: RGB volume.
3711+
3712+
"""
3713+
return self.apply(volume, **params)
3714+
3715+
def apply_to_volumes(self, volumes: np.ndarray, **params: Any) -> np.ndarray:
3716+
"""Apply ToRGB to a batch of volumes.
3717+
3718+
Args:
3719+
volumes (np.ndarray): Batch of volumes with shape (N, D, H, W, C) or (N, D, H, W).
3720+
**params (Any): Additional parameters.
3721+
3722+
Returns:
3723+
np.ndarray: Batch of RGB volumes.
3724+
3725+
"""
3726+
return self.apply(volumes, **params)
3727+
36893728

36903729
class ToSepia(ImageOnlyTransform):
36913730
"""Apply a sepia filter to the input image.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ lint.ignore = [
222222
"FBT002",
223223
"FBT003",
224224
"G004",
225+
"PLC0415",
225226
"PLR0911",
226227
"PLR0913",
227228
"PLR2004",

0 commit comments

Comments
 (0)