Feature Request: Native Support for Multiple Starred Wildcards with Flexible Dimensions in einops.rearrange
Problem
Currently, writing flexible tensor rearrangement code that seamlessly handles inputs with variable inner dimensions—such as both 2D images and 3D volumes—is cumbersome with einops.rearrange. Patterns require explicitly enumerating all dimension axes and sizes, leading to verbose and fragile code.
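For context, a minimal sketch of the status quo (the helper name `flatten_spatial` and the shapes are illustrative, not part of any existing API): handling 2D images and 3D volumes with plain rearrange typically means one pattern per dimensionality, dispatched on ndim.

```python
import numpy as np
from einops import rearrange

def flatten_spatial(tensor):
    # Status quo: one pattern per dimensionality, dispatched by ndim.
    if tensor.ndim == 4:    # 2D images: batch, channel, height, width
        return rearrange(tensor, "b c h w -> b c (h w)")
    elif tensor.ndim == 5:  # 3D volumes: batch, channel, depth, height, width
        return rearrange(tensor, "b c d h w -> b c (d h w)")
    raise ValueError(f"Unsupported number of dimensions: {tensor.ndim}")

images = np.random.randn(8, 3, 32, 32)
volumes = np.random.randn(8, 3, 16, 32, 32)
print(flatten_spatial(images).shape)   # (8, 3, 1024)
print(flatten_spatial(volumes).shape)  # (8, 3, 16384)
```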
Proposed Solution
Add native support for multiple starred wildcards such as (*x) and *y in einops.rearrange patterns, where each starred group's dimensions can be partly or fully flexible, with -1 marking inferable sizes.
- Parenthesized starred groups `(*x)` allow exactly one flexible dimension (`-1`).
- Non-parenthesized starred groups `*y` allow multiple flexible `-1` dimensions.
- Users pass shape hints via keyword arguments, e.g., `x=(-1, 3)`, `y=(-1, -1)`.
- The function automatically infers the sizes of flexible dimensions from the tensor.
Example
```python
import numpy as np

tensor = np.random.randn(5, 7, 12, 5, 6, 1, 13)
pattern = "b r (*x) *y ... -> (...) b *x r (*y)"
kwargs = dict(x=(-1, 3), y=(-1, -1))

# rearrange_star is the proposed function (reference implementation below).
result = rearrange_star(tensor, pattern, **kwargs)
print(result.shape)  # e.g., (13, 5, 4, 3, 7, 30)
```

This syntax greatly simplifies writing functions that operate on variable-dimension data, such as supporting both 2D images (batch, channel, height, width) and 3D volumes (batch, channel, depth, height, width), without pattern explosion or repeated code.
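As a concrete sketch of that claim (using the rearrange_star reference implementation below; the helper name `flatten_spatial_star` and the shapes are illustrative only), one pattern covers both 2D and 3D inputs, with only the length of the shape hint varying:

```python
import numpy as np

def flatten_spatial_star(tensor):
    # Same pattern for images and volumes; only the hint length changes.
    spatial_hint = (-1,) * (tensor.ndim - 2)
    return rearrange_star(tensor, "b c *s -> b c (*s)", s=spatial_hint)

images = np.random.randn(8, 3, 32, 32)        # b c h w
volumes = np.random.randn(8, 3, 16, 32, 32)   # b c d h w
print(flatten_spatial_star(images).shape)     # (8, 3, 1024)
print(flatten_spatial_star(volumes).shape)    # (8, 3, 16384)
```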
Benefits
- Concise expression of complex, flexible tensor structures.
- Eliminates boilerplate for dimension letter management and shape expansions.
- Easier to maintain and extend code for multi-modal data pipelines.
Reference Implementation
```python
import re

import numpy as np
from einops import rearrange


def rearrange_star(tensor, pattern, **kwargs):
    """
    Extended rearrange function that supports inner complex shapes with the * operator.

    Example:
        t = np.random.randn(5, 7, 12, 5, 6, 1, 13)
        r = rearrange_star(t, "b r (*x) *y ... -> (...) b *x r (*y)", x=(4, 3), y=(5, 6))
        # r.shape will be (13, 5, 4, 3, 7, 30)

    Args:
        tensor: Input tensor to rearrange
        pattern: Einops pattern with starred variables like (*x), *y, etc.
        **kwargs: Shape definitions for starred variables (e.g., x=(4, 3), y=(5, 6))

    Returns:
        Rearranged tensor
    """
    lhs_pattern = pattern.split("->")[0].strip()
    # Pattern to match starred variables: (*var) or *var
    star_pattern = r"\(\*(\w+)\)|\*(\w+)"
    matches = re.findall(star_pattern, lhs_pattern)
    starred_vars = {left_match or right_match for left_match, right_match in matches}
    grouped_starred_vars = {left_match for left_match, _ in matches if left_match}

    # Check that all starred variables have shape definitions
    missing_shapes = starred_vars - set(kwargs.keys())
    if missing_shapes:
        raise ValueError(f"Missing shape definitions for starred variables: {missing_shapes}")

    new_kwargs = dict(kwargs)
    transformed_pattern = pattern
    for var_name in starred_vars:
        shape = kwargs[var_name]
        # Not sure whether to allow this or remove it:
        # it also accepts a plain integer instead of a tuple for *var.
        if isinstance(shape, (int, np.integer)):
            shape = (shape,)
        # Expand the starred variable into named axes, e.g. x=(4, 3) -> "new_x_0 new_x_1".
        expanded_axes = {f"new_{var_name}_{i}": s for i, s in enumerate(shape)}
        # Only axes with a known size are passed on to einops; -1 axes are inferred.
        explicit_dims = {k: v for k, v in expanded_axes.items() if v != -1}
        shape_str = " ".join(expanded_axes.keys())
        num_inferable_dims = len(expanded_axes) - len(explicit_dims)
        if num_inferable_dims > 1 and var_name in grouped_starred_vars:
            raise ValueError(
                f"Starred variable '{var_name}' with parentheses can have at most one dimension as -1."
            )
        # Replace (*var) and *var with the expanded axis names on both sides of the pattern.
        transformed_pattern = re.sub(
            rf"\(\*{var_name}\)", f"({shape_str})", transformed_pattern
        )
        # \b avoids matching a longer variable name that merely starts with var_name.
        transformed_pattern = re.sub(rf"\*{var_name}\b", shape_str, transformed_pattern)
        new_kwargs.pop(var_name)
        new_kwargs.update(explicit_dims)

    return rearrange(tensor, transformed_pattern, **new_kwargs)
```
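A quick smoke test of the reference implementation with -1 hints, mirroring the example above (shapes chosen only for illustration):

```python
import numpy as np

t = np.random.randn(5, 7, 12, 5, 6, 1, 13)
# Flexible sizes: the 4 in (*x) and both axes of *y are inferred from the tensor.
r = rearrange_star(t, "b r (*x) *y ... -> (...) b *x r (*y)", x=(-1, 3), y=(-1, -1))
print(r.shape)  # expected: (13, 5, 4, 3, 7, 30)
```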
Additional Notes
The starred wildcard groups concept shares similarities with einops.unpack, where contiguous variable-length dimension groups are handled together. Using -1 to represent flexible, automatically inferred dimension sizes aligns well with common reshape operations across many frameworks.
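For reference, a small sketch of these two precedents (values chosen arbitrarily): einops.pack/unpack absorb a variable number of leading axes under *, and NumPy's reshape infers a single -1 dimension.

```python
import numpy as np
from einops import pack, unpack

image = np.random.randn(32, 32, 3)        # h w c
volume = np.random.randn(16, 32, 32, 3)   # d h w c

# '*' absorbs a variable number of leading axes; pack records them in ps.
packed, ps = pack([image, volume], "* c")
print(packed.shape)  # (32*32 + 16*32*32, 3)
print(ps)            # [(32, 32), (16, 32, 32)]
image2, volume2 = unpack(packed, ps, "* c")

# NumPy's reshape infers a single -1 dimension in the same spirit.
x = np.arange(24).reshape(2, 3, 4)
print(x.reshape(2, -1).shape)  # (2, 12)
```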
Looking forward, this flexible starred wildcard API could naturally extend beyond rearrange to other core einops operations like reduce, repeat and parse_shape. Providing a unified interface for expressing variable inner shapes across these functions would simplify flexible tensor manipulation and better support multi-modal data processing involving varying dimension counts, such as 2D images versus 3D volumes.
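As a rough illustration only, a hypothetical reduce_star wrapper could reuse the same star expansion (the name reduce_star, the simplified explicit-shape handling with no -1 inference, and the shapes below are assumptions, not part of einops or of the reference implementation above):

```python
import re
import numpy as np
from einops import reduce

def reduce_star(tensor, pattern, reduction, **star_shapes):
    # Simplified sketch: expand (*var) / *var with fully explicit shape tuples,
    # then delegate to einops.reduce with the expanded axis sizes.
    axis_sizes = {}
    for var_name, shape in star_shapes.items():
        axes = {f"star_{var_name}_{i}": s for i, s in enumerate(shape)}
        axis_sizes.update(axes)
        shape_str = " ".join(axes)
        pattern = re.sub(rf"\(\*{var_name}\)", f"({shape_str})", pattern)
        pattern = re.sub(rf"\*{var_name}\b", shape_str, pattern)
    return reduce(tensor, pattern, reduction, **axis_sizes)

t = np.random.randn(8, 3, 4, 32, 32)  # b, a flexible inner block, h, w
pooled = reduce_star(t, "b *x h w -> b *x", "mean", x=(3, 4))
print(pooled.shape)  # (8, 3, 4)
```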
Integrating this across the einops API would enhance expressiveness and user convenience for a wide range of tensor transformation tasks.