
Commit e921915

Implement broadcast for XTensorVariables

Co-authored-by: Ricardo <[email protected]>
Parent: e1ce1c3

6 files changed: +311 −4 lines

pytensor/xtensor/__init__.py

1 addition, 1 deletion:

```diff
@@ -3,7 +3,7 @@
 import pytensor.xtensor.rewriting
 from pytensor.xtensor import linalg, random
 from pytensor.xtensor.math import dot
-from pytensor.xtensor.shape import concat
+from pytensor.xtensor.shape import broadcast, concat
 from pytensor.xtensor.type import (
     as_xtensor,
     xtensor,
```
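Since the package `__init__` now pulls in `broadcast`, it becomes importable from the root of `pytensor.xtensor`; a one-line usage sketch:

```python
# broadcast is now re-exported at the package root, next to concat
from pytensor.xtensor import broadcast, concat
```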

pytensor/xtensor/rewriting/shape.py

60 additions, 0 deletions:

```diff
@@ -1,3 +1,4 @@
+import pytensor.tensor as pt
 from pytensor.graph import node_rewriter
 from pytensor.tensor import (
     broadcast_to,
@@ -11,6 +12,7 @@
 from pytensor.xtensor.rewriting.basic import register_lower_xtensor
 from pytensor.xtensor.rewriting.utils import lower_aligned
 from pytensor.xtensor.shape import (
+    Broadcast,
     Concat,
     ExpandDims,
     Squeeze,
@@ -157,3 +159,61 @@ def lower_expand_dims(fgraph, node):
     # Convert result back to xtensor
     result = xtensor_from_tensor(result_tensor, dims=out.type.dims)
     return [result]
+
+
+@register_lower_xtensor
+@node_rewriter(tracks=[Broadcast])
+def lower_broadcast(fgraph, node):
+    """Rewrite XBroadcast using tensor operations."""
+
+    excluded_dims = node.op.exclude
+
+    tensor_inputs = [
+        lower_aligned(inp, out.type.dims)
+        for inp, out in zip(node.inputs, node.outputs, strict=True)
+    ]
+
+    if not excluded_dims:
+        # Simple case: all dimensions are broadcast
+        tensor_outputs = pt.broadcast_arrays(*tensor_inputs)
+
+    else:
+        # Complex case: some dimensions are excluded from broadcasting.
+        # Pick the first known length for each broadcast dim.
+        broadcast_dims = {
+            d: None for d in node.outputs[0].type.dims if d not in excluded_dims
+        }
+        for xtensor_inp in node.inputs:
+            for dim, dim_length in xtensor_inp.sizes.items():
+                if dim in broadcast_dims and broadcast_dims[dim] is None:
+                    # If the dimension is not excluded, record its length
+                    broadcast_dims[dim] = dim_length
+        assert not any(
+            value is None for value in broadcast_dims.values()
+        ), "All dimensions must have a length"
+
+        # Create zeros with the broadcast dimensions, to then broadcast each input against.
+        # PyTensor will rewrite this to use only the shapes of the zeros tensor.
+        broadcast_dims = pt.zeros(
+            tuple(broadcast_dims.values()),
+            dtype=node.outputs[0].type.dtype,
+        )
+        n_broadcast_dims = broadcast_dims.ndim
+
+        tensor_outputs = []
+        for tensor_inp, xtensor_out in zip(tensor_inputs, node.outputs, strict=True):
+            n_excluded_dims = tensor_inp.type.ndim - n_broadcast_dims
+            # Excluded dimensions sit on the right of the output tensor, so pad broadcast_dims on the right.
+            # pt.second is the PyTensor equivalent of `np.broadcast_arrays(x, y)[1]`.
+            tensor_outputs.append(
+                pt.second(
+                    pt.shape_padright(broadcast_dims, n_excluded_dims),
+                    tensor_inp,
+                )
+            )
+
+    new_outs = [
+        xtensor_from_tensor(out_tensor, dims=out.type.dims)
+        for out_tensor, out in zip(tensor_outputs, node.outputs)
+    ]
+    return new_outs
```
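The `pt.second` call is the heart of the excluded-dims branch: each input is broadcast against a zeros "template" carrying the common dims, and later rewrites ensure only the template's shape, not its values, is ever computed. A minimal standalone sketch of the trick (not part of the commit, plain `pytensor.tensor` only):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")           # runtime shape (3,)
template = pt.zeros((2, 3))  # zeros carrying the target broadcast shape

# pt.second(a, b) is elementwise "return b", so the result is b broadcast
# against a: the PyTensor equivalent of np.broadcast_arrays(a, b)[1]
y = pt.second(template, x)

fn = pytensor.function([x], y)
print(fn(np.arange(3, dtype="float64")))
# [[0. 1. 2.]
#  [0. 1. 2.]]
```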

pytensor/xtensor/shape.py

62 additions, 1 deletion:

```diff
@@ -13,7 +13,8 @@
 from pytensor.tensor.type import integer_dtypes
 from pytensor.tensor.utils import get_static_shape_from_size_variables
 from pytensor.xtensor.basic import XOp
-from pytensor.xtensor.type import as_xtensor, xtensor
+from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
+from pytensor.xtensor.vectorization import combine_dims_and_shape


 class Stack(XOp):
@@ -504,3 +505,63 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwargs):
     x = Transpose(dims=tuple(target_dims))(x)

     return x
+
+
+class Broadcast(XOp):
+    """Broadcast multiple XTensorVariables against each other."""
+
+    __props__ = ("exclude",)
+
+    def __init__(self, exclude: Sequence[str] = ()):
+        self.exclude = tuple(exclude)
+
+    def make_node(self, *inputs):
+        inputs = [as_xtensor(x) for x in inputs]
+
+        exclude = self.exclude
+        dims_and_shape = combine_dims_and_shape(inputs, exclude=exclude)
+
+        broadcast_dims = tuple(dims_and_shape.keys())
+        broadcast_shape = tuple(dims_and_shape.values())
+        dtype = upcast(*[x.type.dtype for x in inputs])
+
+        outputs = []
+        for x in inputs:
+            x_dims = x.type.dims
+            x_shape = x.type.shape
+            # The output has excluded dimensions in the order they appear in the op argument
+            excluded_dims = tuple(d for d in exclude if d in x_dims)
+            excluded_shape = tuple(x_shape[x_dims.index(d)] for d in excluded_dims)
+
+            output = xtensor(
+                dtype=dtype,
+                shape=broadcast_shape + excluded_shape,
+                dims=broadcast_dims + excluded_dims,
+            )
+            outputs.append(output)
+
+        return Apply(self, inputs, outputs)
+
+
+def broadcast(
+    *args, exclude: str | Sequence[str] | None = None
+) -> tuple[XTensorVariable, ...]:
+    """Broadcast any number of XTensorVariables against each other.
+
+    Parameters
+    ----------
+    *args : XTensorVariable
+        The tensors to broadcast against each other.
+    exclude : str or Sequence[str] or None, optional
+        Dimensions to exclude from broadcasting.
+    """
+    if not args:
+        return ()
+
+    if exclude is None:
+        exclude = ()
+    elif isinstance(exclude, str):
+        exclude = (exclude,)
+    elif not isinstance(exclude, Sequence):
+        raise TypeError(f"exclude must be None, str, or Sequence, got {type(exclude)}")
+    # xarray broadcast always returns a tuple, even if there's only one tensor
+    return tuple(Broadcast(exclude=exclude)(*args, return_list=True))  # type: ignore
```
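A usage sketch of the helper above, showing the dim ordering that `make_node` establishes: common dims come first, in order of first appearance across inputs, and each input's excluded dims are appended on the right in the order they were passed to `exclude`:

```python
from pytensor.xtensor import xtensor
from pytensor.xtensor.shape import broadcast

x = xtensor("x", dims=("a", "b"), shape=(3, 4))
y = xtensor("y", dims=("c",), shape=(5,))

# No exclusions: every output carries the union of dims
x2, y2 = broadcast(x, y)
assert x2.type.dims == y2.type.dims == ("a", "b", "c")

# A single string is promoted to a tuple; excluded dims stay per-input
x3, y3 = broadcast(x, y, exclude="b")
assert x3.type.dims == ("a", "c", "b")  # "b" exists only in x
assert y3.type.dims == ("a", "c")
```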

pytensor/xtensor/type.py

9 additions, 0 deletions:

```diff
@@ -736,6 +736,15 @@ def dot(self, other, dim=None):
         """Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
         return px.math.dot(self, other, dim=dim)

+    def broadcast(self, *others, exclude=None):
+        """Broadcast this tensor against other XTensorVariables."""
+        return px.shape.broadcast(self, *others, exclude=exclude)
+
+    def broadcast_like(self, other, exclude=None):
+        """Broadcast this tensor against another XTensorVariable."""
+        _, self_bcast = px.shape.broadcast(other, self, exclude=exclude)
+        return self_bcast
+

 class XTensorConstantSignature(TensorConstantSignature):
     pass
```
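A small sketch of how `broadcast_like` behaves: `other` is passed first, so the common dims follow `other`'s ordering, and only `self`'s broadcast output (the second one) is returned:

```python
from pytensor.xtensor import xtensor

x = xtensor("x", dims=("b",), shape=(4,))
y = xtensor("y", dims=("a", "b"), shape=(3, 4))

x_like_y = x.broadcast_like(y)
assert x_like_y.type.dims == ("a", "b")
assert x_like_y.type.shape == (3, 4)
```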

pytensor/xtensor/vectorization.py

12 additions, 2 deletions:

```diff
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from itertools import chain

 import numpy as np
@@ -13,13 +14,22 @@
     get_static_shape_from_size_variables,
 )
 from pytensor.xtensor.basic import XOp
-from pytensor.xtensor.type import as_xtensor, xtensor
+from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor


-def combine_dims_and_shape(inputs):
+def combine_dims_and_shape(
+    inputs: Sequence[XTensorVariable], exclude: Sequence[str] | None = None
+) -> dict[str, int | None]:
+    """Combine information of static dimensions and shapes from multiple xtensor inputs.
+
+    Dimensions listed in `exclude` are skipped.
+    """
+    exclude_set: set[str] = set() if exclude is None else set(exclude)
     dims_and_shape: dict[str, int | None] = {}
     for inp in inputs:
         for dim, dim_length in zip(inp.type.dims, inp.type.shape):
+            if dim in exclude_set:
+                continue
             if dim not in dims_and_shape:
                 dims_and_shape[dim] = dim_length
             elif dim_length is not None:
```
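An illustrative sketch of the extended function, assuming (as in the pre-existing tail of the loop, cut off above) that an unknown static length (`None`) is filled in once another input provides it:

```python
from pytensor.xtensor import xtensor
from pytensor.xtensor.vectorization import combine_dims_and_shape

x = xtensor("x", dims=("a", "b"), shape=(3, None))
y = xtensor("y", dims=("b", "c"), shape=(4, 5))

# "b" is None in x but known in y, so the known length wins
assert combine_dims_and_shape([x, y]) == {"a": 3, "b": 4, "c": 5}
# Excluded dims are dropped from the result entirely
assert combine_dims_and_shape([x, y], exclude=("b",)) == {"a": 3, "c": 5}
```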

tests/xtensor/test_shape.py

167 additions, 0 deletions:

```diff
@@ -9,10 +9,12 @@

 import numpy as np
 from xarray import DataArray
+from xarray import broadcast as xr_broadcast
 from xarray import concat as xr_concat

 from pytensor.tensor import scalar
 from pytensor.xtensor.shape import (
+    broadcast,
     concat,
     stack,
     unstack,
@@ -466,3 +468,168 @@ def test_expand_dims_errors():
     # Test with a numpy array as dim (not supported)
     with pytest.raises(TypeError, match="unhashable type"):
         y.expand_dims(np.array([1, 2]))
+
+
+class TestBroadcast:
+    @pytest.mark.parametrize(
+        "exclude",
+        [
+            None,
+            [],
+            ["b"],
+            ["b", "d"],
+            ["a", "d"],
+            ["b", "c", "d"],
+            ["a", "b", "c", "d"],
+        ],
+    )
+    def test_compatible_excluded_shapes(self, exclude):
+        # Create test data
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
+        z = xtensor("z", dims=("b", "d"), shape=(4, 6))
+
+        x_test = xr_arange_like(x)
+        y_test = xr_arange_like(y)
+        z_test = xr_arange_like(z)
+
+        # Test with excluded dims
+        x2_expected, y2_expected, z2_expected = xr_broadcast(
+            x_test, y_test, z_test, exclude=exclude
+        )
+        x2, y2, z2 = broadcast(x, y, z, exclude=exclude)
+        fn = xr_function([x, y, z], [x2, y2, z2])
+        x2_result, y2_result, z2_result = fn(x_test, y_test, z_test)
+
+        xr_assert_allclose(x2_result, x2_expected)
+        xr_assert_allclose(y2_result, y2_expected)
+        xr_assert_allclose(z2_result, z2_expected)
+
+    def test_incompatible_excluded_shapes(self):
+        # Test that excluded dims are allowed to have different sizes
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
+        z = xtensor("z", dims=("b", "d"), shape=(4, 7))
+        out = broadcast(x, y, z, exclude=["d"])
+
+        x_test = xr_arange_like(x)
+        y_test = xr_arange_like(y)
+        z_test = xr_arange_like(z)
+        fn = xr_function([x, y, z], out)
+        results = fn(x_test, y_test, z_test)
+        expected_results = xr_broadcast(x_test, y_test, z_test, exclude=["d"])
+        for res, expected_res in zip(results, expected_results, strict=True):
+            xr_assert_allclose(res, expected_res)
+
+    @pytest.mark.parametrize("exclude", [[], ["b"], ["b", "c"], ["a", "b", "d"]])
+    def test_runtime_shapes(self, exclude):
+        x = xtensor("x", dims=("a", "b"), shape=(None, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, None))
+        z = xtensor("z", dims=("b", "d"), shape=(None, None))
+        out = broadcast(x, y, z, exclude=exclude)
+
+        x_test = xr_arange_like(xtensor(dims=x.dims, shape=(3, 4)))
+        y_test = xr_arange_like(xtensor(dims=y.dims, shape=(5, 6)))
+        z_test = xr_arange_like(xtensor(dims=z.dims, shape=(4, 6)))
+        fn = xr_function([x, y, z], out)
+        results = fn(x_test, y_test, z_test)
+        expected_results = xr_broadcast(x_test, y_test, z_test, exclude=exclude)
+        for res, expected_res in zip(results, expected_results, strict=True):
+            xr_assert_allclose(res, expected_res)
+
+        # Test that an invalid runtime shape raises an error.
+        # Note: We might decide not to raise an error in the lowered graphs for performance reasons.
+        if "d" not in exclude:
+            z_test_bad = xr_arange_like(xtensor(dims=z.dims, shape=(4, 7)))
+            with pytest.raises(Exception):
+                fn(x_test, y_test, z_test_bad)
+
+    def test_broadcast_excluded_dims_in_different_order(self):
+        """Test that broadcast excluded dims are aligned with the user-provided order."""
+        x = xtensor("x", dims=("a", "c", "b"), shape=(3, 4, 5))
+        y = xtensor("y", dims=("a", "b", "c"), shape=(3, 5, 4))
+        out = (out_x, out_y) = broadcast(x, y, exclude=["c", "b"])
+        assert out_x.type.dims == ("a", "c", "b")
+        assert out_y.type.dims == ("a", "c", "b")
+
+        x_test = xr_arange_like(x)
+        y_test = xr_arange_like(y)
+        fn = xr_function([x, y], out)
+        results = fn(x_test, y_test)
+        expected_results = xr_broadcast(x_test, y_test, exclude=["c", "b"])
+        for res, expected_res in zip(results, expected_results, strict=True):
+            xr_assert_allclose(res, expected_res)
+
+    def test_broadcast_errors(self):
+        """Test error handling in broadcast."""
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
+        z = xtensor("z", dims=("b", "d"), shape=(4, 6))
+
+        with pytest.raises(TypeError, match="exclude must be None, str, or Sequence"):
+            broadcast(x, y, z, exclude=1)
+
+        # Test with conflicting shapes
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
+        z = xtensor("z", dims=("b", "d"), shape=(4, 7))
+
+        with pytest.raises(ValueError, match="Dimension .* has conflicting shapes"):
+            broadcast(x, y, z)
+
+    def test_broadcast_no_input(self):
+        assert broadcast() == xr_broadcast()
+        assert broadcast(exclude=("a",)) == xr_broadcast(exclude=("a",))
+
+    def test_broadcast_single_input(self):
+        """Test broadcasting a single input."""
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        # Broadcast with a single input can still imply a transpose via the exclude parameter
+        outs = [
+            *broadcast(x),
+            *broadcast(x, exclude=("a", "b")),
+            *broadcast(x, exclude=("b", "a")),
+            *broadcast(x, exclude=("b",)),
+        ]
+
+        fn = xr_function([x], outs)
+        x_test = xr_arange_like(x)
+        results = fn(x_test)
+        expected_results = [
+            *xr_broadcast(x_test),
+            *xr_broadcast(x_test, exclude=("a", "b")),
+            *xr_broadcast(x_test, exclude=("b", "a")),
+            *xr_broadcast(x_test, exclude=("b",)),
+        ]
+        for res, expected_res in zip(results, expected_results, strict=True):
+            xr_assert_allclose(res, expected_res)
+
+    @pytest.mark.parametrize("exclude", [None, ["b"], ["b", "c"]])
+    def test_broadcast_like(self, exclude):
+        """Test the broadcast_like method."""
+        # Create test data
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
+        z = xtensor("z", dims=("b", "d"), shape=(4, 6))
+
+        # Order matters, so we test both orders
+        outs = [
+            x.broadcast_like(y, exclude=exclude),
+            y.broadcast_like(x, exclude=exclude),
+            y.broadcast_like(z, exclude=exclude),
+            z.broadcast_like(y, exclude=exclude),
+        ]
+
+        x_test = xr_arange_like(x)
+        y_test = xr_arange_like(y)
+        z_test = xr_arange_like(z)
+        fn = xr_function([x, y, z], outs)
+        results = fn(x_test, y_test, z_test)
+        expected_results = [
+            x_test.broadcast_like(y_test, exclude=exclude),
+            y_test.broadcast_like(x_test, exclude=exclude),
+            y_test.broadcast_like(z_test, exclude=exclude),
+            z_test.broadcast_like(y_test, exclude=exclude),
+        ]
+        for res, expected_res in zip(results, expected_results, strict=True):
+            xr_assert_allclose(res, expected_res)
```
