23 commits
4fb9071
Avoid no-op DimShuffle
ricardoV94 Jun 20, 2025
5b39df6
Use DimShuffle instead of Reshape in `ix_`
ricardoV94 May 22, 2025
7a7db6f
Extract ViewOp functionality into a base TypeCastOp
ricardoV94 Jun 20, 2025
024136e
Implement basic labeled tensor functionality
ricardoV94 Aug 2, 2023
162b50a
Implement stack for XTensorVariables
ricardoV94 Jun 6, 2025
9676b4e
Implement Elemwise and Blockwise operations for XTensorVariables
ricardoV94 May 26, 2025
1a0226c
Implement cast for XTensorVariables
ricardoV94 Jun 6, 2025
b1b5fde
Implement reduction operations for XTensorVariables
ricardoV94 May 25, 2025
3e4f7ae
Implement concat for XTensorVariables
ricardoV94 May 26, 2025
62d410b
Implement transpose for XTensorVariables
AllenDowney May 28, 2025
bd29b1b
Implement unstack for XTensorVariables
OriolAbril May 22, 2025
e6da6c5
Implement index for XTensorVariables
ricardoV94 May 21, 2025
e2a87db
Implement index update for XTensorVariables
ricardoV94 Jun 2, 2025
fc5f668
Implement diff for XTensorVariables
ricardoV94 May 26, 2025
4235d63
Implement squeeze for XTensorVariables
AllenDowney Jun 6, 2025
7a9db22
Implement expand_dims for XTensorVariables (#1449)
AllenDowney Jun 13, 2025
cc28cb0
Implement dot for XTensorVariables (#1475)
AllenDowney Jun 19, 2025
0be448e
Implement XTensorVariable version of RandomVariables
ricardoV94 Jun 20, 2025
2a91f58
Add implementation of broadcast for xtensor
AllenDowney Jun 20, 2025
4f02b39
Add xtensor broadcast
AllenDowney Jun 20, 2025
ea04c9e
Handling symbolic dims
AllenDowney Jun 21, 2025
9a255b6
Adding broadcast_like
AllenDowney Jun 21, 2025
b793a9a
Retreat to basic implementation of broadcast
AllenDowney Jun 29, 2025
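
A minimal sketch of the labeled-tensor API the commits above build up, based on the names used in this diff. The import paths are assumptions; the public entry points may differ.

```python
from pytensor.xtensor import xtensor               # assumed constructor location
from pytensor.xtensor.shape import broadcast       # defined in the shape.py diff below

# Variables carry named dims alongside (possibly symbolic) shapes.
x = xtensor("x", dims=("a", "b"), shape=(3, 4))
y = xtensor("y", dims=("c", "d"), shape=(5, 6))

# Broadcast any number of variables against each other; exclude keeps some dims out.
xb, yb = broadcast(x, y)
xb_partial, yb_partial = broadcast(x, y, exclude="b")
```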
114 changes: 73 additions & 41 deletions pytensor/xtensor/rewriting/shape.py
@@ -20,6 +20,7 @@
Transpose,
UnStack,
XBroadcast,
XBroadcastLike,
)


@@ -170,51 +171,53 @@ def lower_expand_dims(fgraph, node):
return [result]


def build_dimension_shapes(nodes):
"""Precompute the best available shape for each dimension."""
dim_shapes = {}
all_dims = set().union(*(x.type.dims for x in nodes))

for dim in all_dims:
concrete_shape = None
runtime_shape = None

for x in nodes:
if dim in x.type.dims:
idx = x.type.dims.index(dim)
if x.type.shape[idx] is not None:
concrete_shape = as_tensor(x.type.shape[idx], dtype="int64")
break # Found concrete shape, stop looking
elif runtime_shape is None:
# Remember first runtime shape as fallback
x_tensor = tensor_from_xtensor(x)
runtime_shape = shape(x_tensor)[idx]

# Store the best available shape
if concrete_shape is not None:
dim_shapes[dim] = concrete_shape
elif runtime_shape is not None:
dim_shapes[dim] = runtime_shape
else:
dim_shapes[dim] = pt.constant(1, dtype="int64")

return dim_shapes


def get_broadcast_shape_for_dim(dim_shapes, dim, out_shape):
"""Get the broadcast shape for a single dimension."""
if out_shape is not None:
broadcast_shape = as_tensor(out_shape, dtype="int64")
else:
broadcast_shape = dim_shapes[dim]
return broadcast_shape


@register_lower_xtensor
@node_rewriter(tracks=[XBroadcast])
def lower_broadcast(fgraph, node):
"""Rewrite XBroadcast to tensor operations with symbolic shape support."""

def build_dimension_shapes():
"""Precompute the best available shape for each dimension."""
dim_shapes = {}
all_dims = set().union(*(x.type.dims for x in node.inputs))

for dim in all_dims:
concrete_shape = None
runtime_shape = None

for x in node.inputs:
if dim in x.type.dims:
idx = x.type.dims.index(dim)
if x.type.shape[idx] is not None:
concrete_shape = as_tensor(x.type.shape[idx], dtype="int64")
break # Found concrete shape, stop looking
elif runtime_shape is None:
# Remember first runtime shape as fallback
x_tensor = tensor_from_xtensor(x)
runtime_shape = shape(x_tensor)[idx]

# Store the best available shape
if concrete_shape is not None:
dim_shapes[dim] = concrete_shape
elif runtime_shape is not None:
dim_shapes[dim] = runtime_shape
else:
dim_shapes[dim] = pt.constant(1, dtype="int64")

return dim_shapes

# Precompute all dimension shapes once
dim_shapes = build_dimension_shapes()

def get_broadcast_shape_for_dim(dim, out_shape):
"""Get the broadcast shape for a single dimension."""
if out_shape is not None:
broadcast_shape = as_tensor(out_shape, dtype="int64")
else:
broadcast_shape = dim_shapes[dim]
return broadcast_shape
# Precompute all dimension shapes
dim_shapes = build_dimension_shapes(node.inputs)

result_tensors = []
for x, out in zip(node.inputs, node.outputs):
@@ -227,7 +230,7 @@ def get_broadcast_shape_for_dim(dim, out_shape):

# Get the broadcast shape for each dimension
broadcast_shape = [
get_broadcast_shape_for_dim(dim, out_shape)
get_broadcast_shape_for_dim(dim_shapes, dim, out_shape)
for dim, out_shape in zip(out.type.dims, out.type.shape)
]

@@ -238,3 +241,32 @@ def get_broadcast_shape_for_dim(dim, out_shape):
result_tensors.append(result)

return result_tensors


@register_lower_xtensor
@node_rewriter(tracks=[XBroadcastLike])
def lower_broadcast_like(fgraph, node):
"""Rewrite XBroadcastLike to tensor operations with symbolic shape support."""

# Precompute all dimension shapes
dim_shapes = build_dimension_shapes(node.inputs)

[x, _] = node.inputs
[out] = node.outputs
x_tensor = tensor_from_xtensor(x)

# Dimshuffle the tensor to the output dimensions
x_dim_to_idx = {d: i for i, d in enumerate(x.type.dims)}
shuffle_pattern = [x_dim_to_idx.get(d, "x") for d in out.type.dims]
x_tensor = x_tensor.dimshuffle(shuffle_pattern)

# Get the broadcast shape for each dimension
broadcast_shape = [
get_broadcast_shape_for_dim(dim_shapes, dim, out_shape)
for dim, out_shape in zip(out.type.dims, out.type.shape)
]

x_tensor = broadcast_to(x_tensor, broadcast_shape)

result = xtensor_from_tensor(x_tensor, dims=out.type.dims)
return [result]
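
The lowering above maps labeled broadcasting onto plain tensor operations: a DimShuffle pattern inserts length-1 axes for dims the input lacks, and `broadcast_to` expands them to the resolved per-dimension shape (concrete static shape if known, otherwise a runtime shape, otherwise 1). A rough numpy sketch of the same computation; the dim names and shapes here are illustrative, not taken from the PR.

```python
import numpy as np

x = np.arange(12).reshape(3, 4)          # an input with dims ("a", "b")
x_dims = ("a", "b")
out_dims = ("c", "d", "a", "b")          # target dims, e.g. from `other`
out_shape = (5, 6, 3, 4)

# DimShuffle pattern: existing dims keep their axis index, missing dims get "x"
# (a new length-1 axis). Here None stands in for "x".
x_dim_to_idx = {d: i for i, d in enumerate(x_dims)}
pattern = [x_dim_to_idx.get(d, None) for d in out_dims]      # [None, None, 0, 1]

# Emulate the DimShuffle: reorder the existing axes, then insert the new ones.
existing = [p for p in pattern if p is not None]
shuffled = x.transpose(existing)
shuffled = shuffled.reshape(tuple(1 if p is None else x.shape[p] for p in pattern))

# Finally broadcast to the resolved per-dimension shape.
result = np.broadcast_to(shuffled, out_shape)
print(result.shape)                       # (5, 6, 3, 4)
```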
60 changes: 53 additions & 7 deletions pytensor/xtensor/shape.py
@@ -574,15 +574,61 @@ def make_node(self, *inputs):
return Apply(self, inputs, outputs)


def broadcast(*args, exclude: str | Sequence[str] | None = None):
"""Broadcast any number of XTensorVariables against each other."""

# Normalize exclude
def _normalize_exclude(exclude):
"""Normalize the exclude parameter to a tuple of hashable dimension names."""
if exclude is None:
exclude = ()
return ()
elif isinstance(exclude, str):
exclude = (exclude,)
return (exclude,)
elif isinstance(exclude, Sequence):
exclude = tuple(exclude)
return tuple(exclude)
else:
raise TypeError(f"exclude must be None, str, or Sequence, got {type(exclude)}")


def broadcast(*args, exclude: str | Sequence[str] | None = None):
"""Broadcast any number of XTensorVariables against each other."""
exclude = _normalize_exclude(exclude)
return XBroadcast(exclude=exclude)(*args)


class XBroadcastLike(XOp):
"""Broadcast this tensor against another XTensorVariable."""

__props__ = ("exclude",)

def __init__(self, exclude: tuple[Hashable, ...] = ()):
if not all(isinstance(dim, Hashable) for dim in exclude):
raise TypeError("All items in `exclude` must be hashable dimension names.")

self.exclude = exclude

def make_node(self, x, other):
x = as_xtensor(x)
other = as_xtensor(other)

# Get the union shape for included dims,
# preserving order from other first, then x
dims_and_shape = combine_dims_and_shape([other, x])

broadcast_dims = tuple(d for d in dims_and_shape if d not in self.exclude)

broadcast_shape = tuple(dims_and_shape[d] for d in broadcast_dims)
dtype = upcast(*[x.type.dtype, other.type.dtype])

# Preserve excluded dims from this input
excluded_dims = [d for d in x.type.dims if d in self.exclude]
excluded_shapes = [dims_and_shape[d] for d in excluded_dims]

output = xtensor(
dtype=dtype,
shape=broadcast_shape + tuple(excluded_shapes),
dims=broadcast_dims + tuple(excluded_dims),
)
return Apply(self, [x, other], [output])


def broadcast_like(x, other, exclude=None):
"""Broadcast this tensor against another XTensorVariable."""
exclude = _normalize_exclude(exclude)
return XBroadcastLike(exclude=exclude)(x, other)
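
A pure-Python sketch of how `XBroadcastLike.make_node` orders the output dims: broadcast dims follow `other` first and then `x`, with `x`'s excluded dims appended at the end. It assumes `combine_dims_and_shape` returns an ordered mapping seeded by `other`'s dims and then `x`'s; the names below are illustrative.

```python
x_dims, x_shape = ("a", "b"), (3, 4)
other_dims, other_shape = ("c", "b"), (5, 4)
exclude = ("a",)

dims_and_shape = {}
for dims, shape in ((other_dims, other_shape), (x_dims, x_shape)):
    for d, s in zip(dims, shape):
        dims_and_shape.setdefault(d, s)        # {"c": 5, "b": 4, "a": 3}

broadcast_dims = tuple(d for d in dims_and_shape if d not in exclude)   # ("c", "b")
excluded_dims = tuple(d for d in x_dims if d in exclude)                # ("a",)

out_dims = broadcast_dims + excluded_dims                               # ("c", "b", "a")
out_shape = tuple(dims_and_shape[d] for d in out_dims)                  # (5, 4, 3)
print(out_dims, out_shape)
```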
4 changes: 4 additions & 0 deletions pytensor/xtensor/type.py
@@ -657,6 +657,10 @@ def broadcast(self, *others, exclude=None):
"""Broadcast this tensor against other XTensorVariables."""
return px.shape.broadcast(self, *others, exclude=exclude)

def broadcast_like(self, other, exclude=None):
"""Broadcast this tensor against another XTensorVariable."""
return px.shape.broadcast_like(self, other, exclude=exclude)


class XTensorConstantSignature(TensorConstantSignature):
pass
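
Method-style usage of the new `broadcast_like`, as exercised in the tests below. This is a sketch; the import path for the `xtensor` constructor is an assumption.

```python
from pytensor.xtensor import xtensor

x = xtensor("x", dims=("a", "b"), shape=(3, 4))
y = xtensor("y", dims=("c", "d"), shape=(5, 6))

# exclude accepts None, a single dim name, or a sequence of names.
out = x.broadcast_like(y)                   # other's dims first, then x's remaining dims
out_keep_b = x.broadcast_like(y, exclude="b")
out_keep_ab = x.broadcast_like(y, exclude=("a", "b"))
```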
83 changes: 82 additions & 1 deletion tests/xtensor/test_shape.py
@@ -534,7 +534,7 @@ def test_broadcast_errors():
y = xtensor("y", dims=("c", "d"), shape=(5, 6))
z = xtensor("z", dims=("b", "d"), shape=(4, 6))

with pytest.raises(TypeError, match="not iterable"):
with pytest.raises(TypeError, match="exclude must be None, str, or Sequence"):
broadcast(x, y, z, exclude=1)

# Test with conflicting shapes
@@ -544,3 +544,84 @@ def test_broadcast_errors():

with pytest.raises(ValueError, match="Dimension .* has conflicting shapes"):
broadcast(x, y, z)


def test_broadcast_like():
"""Test broadcast_like method"""
# Create test data
x = xtensor("x", dims=("a", "b"), shape=(3, 4))
y = xtensor("y", dims=("c", "d"), shape=(5, 6))
z = xtensor("z", dims=("b", "d"), shape=(4, 6))

x_test = xr_arange_like(x)
y_test = xr_arange_like(y)
z_test = xr_arange_like(z)

# Basic broadcasting
x2_expected = x_test.broadcast_like(y_test)
x2 = x.broadcast_like(y)
fn = xr_function([x, y], x2)
x2_result = fn(x_test, y_test)
xr_assert_allclose(x2_result, x2_expected)

y2_expected = y_test.broadcast_like(z_test)
y2 = y.broadcast_like(z)
fn = xr_function([y, z], y2)
y2_result = fn(y_test, z_test)
xr_assert_allclose(y2_result, y2_expected)

# Test with excluded dims
x2_expected = x_test.broadcast_like(y_test, exclude=["b"])
x2 = x.broadcast_like(y, exclude=["b"])
fn = xr_function([x, y], x2)
x2_result = fn(x_test, y_test)
xr_assert_allclose(x2_result, x2_expected)

y2_expected = y_test.broadcast_like(z_test, exclude=["b", "c"])
y2 = y.broadcast_like(z, exclude=["b"])
fn = xr_function([y, z], y2)
y2_result = fn(y_test, z_test)
xr_assert_allclose(y2_result, y2_expected)

# Test with symbolic sizes
x = xtensor("x", dims=("a", "b"), shape=(None, 4))
y = xtensor("y", dims=("c", "d"), shape=(5, None))
z = xtensor("z", dims=("b", "d"), shape=(None, 6))

x_test = xr_arange_like(xtensor(dims=x.dims, shape=(3, 4)))
y_test = xr_arange_like(xtensor(dims=y.dims, shape=(5, 6)))
z_test = xr_arange_like(xtensor(dims=z.dims, shape=(4, 6)))

x2_expected = x_test.broadcast_like(y_test)
x2 = x.broadcast_like(y)
fn = xr_function([x, y], x2)
x2_result = fn(x_test, y_test)
xr_assert_allclose(x2_result, x2_expected)

y2_expected = y_test.broadcast_like(z_test, exclude=["b", "c"])
y2 = y.broadcast_like(z, exclude=["b"])
fn = xr_function([y, z], y2)
y2_result = fn(y_test, z_test)
xr_assert_allclose(y2_result, y2_expected)


def test_broadcast_like_errors():
"""Test error handling in broadcast_like."""
x = xtensor("x", dims=("a", "b"), shape=(3, 4))
y = xtensor("y", dims=("c", "d"), shape=(5, 6))

with pytest.raises(TypeError, match="exclude must be None, str, or Sequence"):
x.broadcast_like(y, exclude=1)

with pytest.raises(
TypeError, match="All items in `exclude` must be hashable dimension names."
):
x.broadcast_like(y, exclude=[np.array([1, 2])])

# Test with conflicting shapes
x = xtensor("x", dims=("a", "b"), shape=(3, 4))
y = xtensor("y", dims=("c", "d"), shape=(5, 6))
z = xtensor("z", dims=("b", "d"), shape=(4, 7))

with pytest.raises(ValueError, match="Dimension .* has conflicting shapes"):
y.broadcast_like(z)