
[Feature suggestion] Native Support for Multiple Starred Wildcards with Flexible Dimensions in rearrange #387

@LeonhardFeiner

Feature Request: Native Support for Multiple Starred Wildcards with Flexible Dimensions in einops.rearrange

Problem

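Currently, writing tensor rearrangement code that handles inputs with a variable number of inner dimensions, such as both 2D images and 3D volumes, is cumbersome with einops.rearrange. Patterns must name every axis explicitly, and the only variable-length construct, the ellipsis (...), may appear at most once per side of a pattern. Code that needs more than one variable-length group therefore has to build a separate pattern for each input rank, which is verbose and fragile.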

Proposed Solution

Add native support for multiple starred wildcards, such as (*x) and *y, in einops.rearrange patterns. Each starred name stands for a group of axes described by a shape hint, and any entry of the hint can be -1 to mark a size that should be inferred from the tensor.

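  • Parenthesized starred groups (*x) allow at most one flexible dimension (-1), because the sizes inside a group can only be inferred from the length of the single merged axis.
  • Non-parenthesized starred groups *y allow any number of flexible -1 dimensions, because each of their axes maps to its own tensor dimension.
  • Users pass shape hints via keyword arguments, e.g., x=(-1, 3), y=(-1, -1); the length of each hint fixes how many axes the group spans.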
  • The function automatically infers the sizes of flexible dimensions from the tensor.
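As a rough illustration of the inference rules above, the plain-Python sketch below shows how each -1 could be resolved once a starred group has been matched to tensor axes. The helpers resolve_grouped and resolve_ungrouped are hypothetical names used only here; they are not part of einops, and the reference implementation below instead leaves this inference to einops itself.

```python
from math import prod


def resolve_grouped(axis_len, hint):
    """Hypothetical helper: resolve a parenthesized group (*x) matched to ONE axis.

    At most one entry of `hint` may be -1; its size is axis_len divided by
    the product of the explicit sizes.
    """
    unknown = [i for i, s in enumerate(hint) if s == -1]
    if len(unknown) > 1:
        raise ValueError("a parenthesized starred group allows at most one -1")
    known = prod(s for s in hint if s != -1)  # product of explicit sizes (1 if none)
    if not unknown:
        if known != axis_len:
            raise ValueError(f"group product {known} != axis length {axis_len}")
        return tuple(hint)
    if known == 0 or axis_len % known != 0:
        raise ValueError(f"axis length {axis_len} is not divisible by {known}")
    sizes = list(hint)
    sizes[unknown[0]] = axis_len // known
    return tuple(sizes)


def resolve_ungrouped(axis_lens, hint):
    """Hypothetical helper: resolve an ungrouped group *y matched to len(hint) axes.

    Every -1 is read directly from its tensor axis; explicit sizes are only checked.
    """
    if len(axis_lens) != len(hint):
        raise ValueError("hint length must equal the number of matched axes")
    for actual, expected in zip(axis_lens, hint):
        if expected != -1 and expected != actual:
            raise ValueError(f"expected size {expected}, got {actual}")
    return tuple(axis_lens)


# Hints from the example: x=(-1, 3) on an axis of length 12, y=(-1, -1) on axes (5, 6).
print(resolve_grouped(12, (-1, 3)))         # (4, 3)
print(resolve_ungrouped((5, 6), (-1, -1)))  # (5, 6)
```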

Example

import numpy as np

tensor = np.random.randn(5, 7, 12, 5, 6, 1, 13)
pattern = "b r (*x) *y ... -> (...) b *x r (*y)"
kwargs = dict(x=(-1, 3), y=(-1, -1))

result = rearrange_star(tensor, pattern, **kwargs)  # defined under Reference Implementation
print(result.shape)  # (13, 5, 4, 3, 7, 30)
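For readers less familiar with einops, here is a NumPy-only sketch of what the example pattern computes, assuming the hints resolve to x=(4, 3) and y=(5, 6) as the expected shape implies. It uses only reshape and transpose; with today's einops the same result needs every axis spelled out, e.g. rearrange(tensor, "b r (x0 x1) y0 y1 ... -> (...) b x0 x1 r (y0 y1)", x1=3).

```python
import numpy as np

tensor = np.random.randn(5, 7, 12, 5, 6, 1, 13)

# Split the third axis (12) into x = (4, 3); the other axes keep their sizes.
# Axis layout after reshape: b, r, x0, x1, y0, y1, e0, e1 (e* = axes matched by "...").
split = tensor.reshape(5, 7, 4, 3, 5, 6, 1, 13)

# Reorder to "(...) b *x r (*y)", i.e. e0, e1, b, x0, x1, r, y0, y1.
moved = split.transpose(6, 7, 0, 2, 3, 1, 4, 5)

# Merge the ellipsis axes into one axis and (y0, y1) into one axis.
result = moved.reshape(1 * 13, 5, 4, 3, 7, 5 * 6)
print(result.shape)  # (13, 5, 4, 3, 7, 30)
```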

This syntax greatly simplifies writing functions that operate on variable-dimension data, such as supporting both 2D images (batch, channel, height, width) and 3D volumes (batch, channel, depth, height, width) without pattern explosion or repeated code.

Benefits

  • Concise expression of complex, flexible tensor structures.
  • Eliminates boilerplate for dimension letter management and shape expansions.
  • Easier to maintain and extend code for multi-modal data pipelines.

Reference Implementation

import re
import numpy as np
from einops import rearrange

def rearrange_star(tensor, pattern, **kwargs):
    """
    Extended rearrange that accepts starred axis groups such as (*x) and *y.

    Example:
        t = np.random.randn(5, 7, 12, 5, 6, 1, 13)
        r = rearrange_star(t, "b r (*x) *y ... -> (...) b *x r (*y)", x=(4, 3), y=(5, 6))
        # r.shape == (13, 5, 4, 3, 7, 30)

    Args:
        tensor: Input tensor to rearrange.
        pattern: Einops pattern with starred variables like (*x), *y, etc.
        **kwargs: Shape hints for starred variables (e.g., x=(4, 3), y=(-1, -1)),
            where -1 marks a size to infer; other keyword arguments are passed
            to einops.rearrange as axis lengths.

    Returns:
        Rearranged tensor.
    """
    lhs_pattern = pattern.split("->")[0].strip()

    # Match starred variables: (*var) or *var
    star_pattern = r"\(\*(\w+)\)|\*(\w+)"
    matches = re.findall(star_pattern, lhs_pattern)

    starred_vars = {left_match or right_match for left_match, right_match in matches}
    grouped_starred_vars = {left_match for left_match, _ in matches if left_match}

    # Check that all starred variables have shape definitions
    missing_shapes = starred_vars - set(kwargs.keys())
    if missing_shapes:
        raise ValueError(f"Missing shape definitions for starred variables: {missing_shapes}")

    new_kwargs = dict(kwargs)
    transformed_pattern = pattern
    for var_name in starred_vars:
        shape = kwargs[var_name]

        # Open question: should a bare integer be accepted in place of a tuple?
        # For now it is treated as a one-element shape.
        if isinstance(shape, (int, np.integer)):
            shape = (shape,)

        # Expand *var into one named axis per hint entry; keep only explicit sizes.
        expanded_axes = {f"new_{var_name}_{i}": s for i, s in enumerate(shape)}
        explicit_dims = {k: v for k, v in expanded_axes.items() if v != -1}
        shape_str = " ".join(expanded_axes.keys())

        num_inferable_dims = len(expanded_axes) - len(explicit_dims)

        if num_inferable_dims > 1 and var_name in grouped_starred_vars:
            raise ValueError(
                f"Starred variable '{var_name}' with parentheses can have at most one dimension as -1."
            )

        # Replace the parenthesized form first, then the bare form;
        # \b keeps *x from also matching inside *xy.
        transformed_pattern = re.sub(
            rf"\(\*{var_name}\)", f"({shape_str})", transformed_pattern
        )
        transformed_pattern = re.sub(rf"\*{var_name}\b", shape_str, transformed_pattern)

        new_kwargs.pop(var_name)
        new_kwargs.update(explicit_dims)

    return rearrange(tensor, transformed_pattern, **new_kwargs)
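One possible refinement, offered as a sketch rather than as part of the proposal, is to separate the pattern rewrite from the einops call so it can be unit-tested without any tensor library. The hypothetical helper expand_starred_pattern below performs only the rewrite and returns the plain einops pattern together with the explicit axis lengths; it uses the same new_<name>_<i> axis naming as rearrange_star and a \b word boundary so that *x is never matched inside *xy.

```python
import numbers
import re


def expand_starred_pattern(pattern, **hints):
    """Rewrite (*name) / *name groups into plain einops axes (hypothetical helper).

    Returns (plain_pattern, axis_lengths); axis_lengths keeps non-starred
    keyword arguments plus every hint entry that is not -1.
    """
    lhs = pattern.split("->")[0]
    matches = re.findall(r"\(\*(\w+)\)|\*(\w+)", lhs)
    starred = {grouped or plain for grouped, plain in matches}
    grouped_vars = {grouped for grouped, _ in matches if grouped}

    missing = starred - set(hints)
    if missing:
        raise ValueError(f"missing shape hints for starred variables: {sorted(missing)}")

    # Non-starred keyword arguments pass through as ordinary axis lengths.
    axis_lengths = {k: v for k, v in hints.items() if k not in starred}
    for name in sorted(starred):
        hint = hints[name]
        hint = (hint,) if isinstance(hint, numbers.Integral) else tuple(hint)
        if name in grouped_vars and hint.count(-1) > 1:
            raise ValueError(f"'(*{name})' may contain at most one -1")
        axes = [f"new_{name}_{i}" for i in range(len(hint))]
        axis_lengths.update({a: s for a, s in zip(axes, hint) if s != -1})
        joined = " ".join(axes)
        # Parenthesized form first, then the bare form; \b stops *x matching *xy.
        pattern = re.sub(rf"\(\*{name}\)", f"({joined})", pattern)
        pattern = re.sub(rf"\*{name}\b", joined, pattern)
    return pattern, axis_lengths


plain, lengths = expand_starred_pattern(
    "b r (*x) *y ... -> (...) b *x r (*y)", x=(-1, 3), y=(-1, -1)
)
print(plain)    # b r (new_x_0 new_x_1) new_y_0 new_y_1 ... -> (...) b new_x_0 new_x_1 r (new_y_0 new_y_1)
print(lengths)  # {'new_x_1': 3}
```

If einops is installed, the returned pair can be forwarded unchanged, as in rearrange(tensor, plain, **lengths), which mirrors what rearrange_star does internally.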

Additional Notes

The starred wildcard groups concept shares similarities with einops.pack and einops.unpack, whose patterns use a single * to stand for a contiguous group of axes of any length. Using -1 to represent flexible, automatically inferred dimension sizes also matches the reshape convention used across many frameworks (e.g., numpy.reshape and torch.reshape).
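The reshape analogy can be checked directly in NumPy. The snippet below, an illustration only, splits an axis of length 12 with a -1 entry the way a hint x=(-1, 3) would split (*x), and shows that NumPy likewise rejects two unknown sizes in one reshape, mirroring the at-most-one -1 rule for parenthesized groups.

```python
import numpy as np

t = np.arange(5 * 12).reshape(5, 12)

# "b (*x) -> b *x" with x=(-1, 3) would split the 12-long axis into (4, 3);
# reshape infers the -1 the same way: 12 // 3 == 4.
split = t.reshape(5, -1, 3)
print(split.shape)  # (5, 4, 3)

# Two unknown sizes are ambiguous, so NumPy refuses them, like two -1 in (*x).
try:
    t.reshape(5, -1, -1)
except ValueError as err:
    print("rejected:", err)
```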

Looking forward, this flexible starred wildcard API could naturally extend beyond rearrange to other core einops operations like reduce, repeat and parse_shape. Providing a unified interface for expressing variable inner shapes across these functions would simplify flexible tensor manipulation and better support multi-modal data processing involving varying dimension counts, such as 2D images versus 3D volumes.
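To sketch how a starred group might carry over to parse_shape, here is a hypothetical, pure-Python parse_shape_star that works on plain shape tuples. It binds ordinary names to single axes and one *name to the run of remaining axes, returned as a tuple. This is an assumption-laden illustration: einops.parse_shape does not accept starred names today, and the sketch supports only one starred group without a hint, much as einops.pack allows a single * per pattern.

```python
def parse_shape_star(shape, pattern):
    """Map axis names to sizes; one `*name` captures a variable-length run of axes.

    Hypothetical extension of einops.parse_shape, shown on plain shape tuples.
    """
    tokens = pattern.split()
    starred = [i for i, tok in enumerate(tokens) if tok.startswith("*")]
    if len(starred) > 1:
        raise ValueError("this sketch supports a single starred group")
    if not starred:
        if len(tokens) != len(shape):
            raise ValueError(f"pattern has {len(tokens)} axes, shape has {len(shape)}")
        return dict(zip(tokens, shape))

    star = starred[0]
    n_after = len(tokens) - star - 1          # plain names after the starred group
    n_star = len(shape) - star - n_after      # axes left over for the starred group
    if n_star < 0:
        raise ValueError("shape has too few axes for this pattern")
    result = dict(zip(tokens[:star], shape[:star]))
    result[tokens[star][1:]] = tuple(shape[star:star + n_star])
    result.update(zip(tokens[star + 1:], shape[star + n_star:]))
    return result


# The same pattern covers 2D images and 3D volumes.
print(parse_shape_star((8, 3, 32, 32), "b c *s"))      # {'b': 8, 'c': 3, 's': (32, 32)}
print(parse_shape_star((8, 3, 16, 32, 32), "b c *s"))  # {'b': 8, 'c': 3, 's': (16, 32, 32)}
```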

Integrating this across the einops API would enhance expressiveness and user convenience for a wide range of tensor transformation tasks.
