2 changes: 1 addition & 1 deletion pytensor/xtensor/__init__.py
@@ -3,7 +3,7 @@
import pytensor.xtensor.rewriting
from pytensor.xtensor import linalg, random
from pytensor.xtensor.math import dot
from pytensor.xtensor.shape import concat
from pytensor.xtensor.shape import broadcast, concat
from pytensor.xtensor.type import (
as_xtensor,
xtensor,
4 changes: 2 additions & 2 deletions pytensor/xtensor/math.py
@@ -145,7 +145,7 @@ def softmax(x, dim=None):
return exp_x / exp_x.sum(dim=dim)


class XDot(XOp):
class Dot(XOp):
"""Matrix multiplication between two XTensorVariables.

This operation performs matrix multiplication between two tensors, automatically
@@ -247,6 +247,6 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
if d not in union:
raise ValueError(f"Dimension {d} not found in either input")

result = XDot(dims=tuple(dim_set))(x, y)
result = Dot(dims=tuple(dim_set))(x, y)

return result
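For context, a minimal sketch of how the renamed op is reached through the public dot helper (the dim names "a"/"b"/"c" are invented, and an `xtensor("name", dims=...)` constructor signature is assumed, as used elsewhere in the module):

    import pytensor.xtensor as px

    x = px.xtensor("x", dims=("a", "b"))
    y = px.xtensor("y", dims=("b", "c"))
    # Contract over the shared "b" dimension; the result has dims ("a", "c")
    z = px.dot(x, y, dim="b")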
4 changes: 2 additions & 2 deletions pytensor/xtensor/rewriting/math.py
@@ -4,12 +4,12 @@
from pytensor.tensor import einsum
from pytensor.tensor.shape import specify_shape
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.math import XDot
from pytensor.xtensor.math import Dot
from pytensor.xtensor.rewriting.utils import register_lower_xtensor


@register_lower_xtensor
@node_rewriter(tracks=[XDot])
@node_rewriter(tracks=[Dot])
def lower_dot(fgraph, node):
"""Rewrite XDot to tensor.dot.

71 changes: 62 additions & 9 deletions pytensor/xtensor/rewriting/shape.py
@@ -1,3 +1,4 @@
import pytensor.tensor as pt
from pytensor.graph import node_rewriter
from pytensor.tensor import (
broadcast_to,
@@ -9,7 +10,9 @@
)
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.rewriting.utils import lower_aligned
from pytensor.xtensor.shape import (
Broadcast,
Concat,
ExpandDims,
Squeeze,
@@ -70,15 +73,7 @@ def lower_concat(fgraph, node):
concat_axis = out_dims.index(concat_dim)

# Convert input XTensors to Tensors and align batch dimensions
tensor_inputs = []
for inp in node.inputs:
inp_dims = inp.type.dims
order = [
inp_dims.index(out_dim) if out_dim in inp_dims else "x"
for out_dim in out_dims
]
tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
tensor_inputs.append(tensor_inp)
tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]

# Broadcast non-concatenated dimensions of each input
non_concat_shape = [None] * len(out_dims)
@@ -164,3 +159,61 @@ def lower_expand_dims(fgraph, node):
# Convert result back to xtensor
result = xtensor_from_tensor(result_tensor, dims=out.type.dims)
return [result]


@register_lower_xtensor
@node_rewriter(tracks=[Broadcast])
def lower_broadcast(fgraph, node):
"""Rewrite XBroadcast using tensor operations."""

excluded_dims = node.op.exclude

tensor_inputs = [
lower_aligned(inp, out.type.dims)
for inp, out in zip(node.inputs, node.outputs, strict=True)
]

if not excluded_dims:
# Simple case: All dimensions are broadcasted
tensor_outputs = pt.broadcast_arrays(*tensor_inputs)

else:
# Complex case: Some dimensions are excluded from broadcasting
# Pick the first known dimension length for each broadcast dim
broadcast_dims = {
d: None for d in node.outputs[0].type.dims if d not in excluded_dims
}
for xtensor_inp in node.inputs:
for dim, dim_length in xtensor_inp.sizes.items():
if dim in broadcast_dims and broadcast_dims[dim] is None:
# If the dimension is not excluded, set its shape
broadcast_dims[dim] = dim_length
assert not any(
value is None for value in broadcast_dims.values()
), "All dimensions must have a length"

# Create a zeros tensor with the broadcast dimensions, against which each input is broadcast
# PyTensor will later rewrite this to use only the shape of the zeros tensor
broadcast_dims = pt.zeros(
tuple(broadcast_dims.values()),
dtype=node.outputs[0].type.dtype,
)
n_broadcast_dims = broadcast_dims.ndim

tensor_outputs = []
for tensor_inp, xtensor_out in zip(tensor_inputs, node.outputs, strict=True):
n_excluded_dims = tensor_inp.type.ndim - n_broadcast_dims
# Excluded dimensions sit on the right side of the output tensor, so we pad broadcast_dims on the right
# `second(x, y)` is the PyTensor equivalent of `np.broadcast_arrays(x, y)[1]`
tensor_outputs.append(
pt.second(
pt.shape_padright(broadcast_dims, n_excluded_dims),
tensor_inp,
)
)

new_outs = [
xtensor_from_tensor(out_tensor, dims=out.type.dims)
for out_tensor, out in zip(tensor_outputs, node.outputs)
]
return new_outs
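For readers unfamiliar with the trick used above, a small stand-alone sketch (shapes are invented): `pt.second(a, b)` returns `b` broadcast against `a`, so broadcasting against a right-padded zeros template uses only the template's shape.

    import pytensor.tensor as pt

    template = pt.zeros((3, 4))              # carries only the broadcast dims
    x = pt.ones((1, 4, 5))                   # one excluded dim of length 5 on the right
    out = pt.second(pt.shape_padright(template, 1), x)
    # out has shape (3, 4, 5): the broadcast dims come from the template,
    # while the trailing excluded dim is left untouched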
12 changes: 12 additions & 0 deletions pytensor/xtensor/rewriting/utils.py
@@ -1,7 +1,12 @@
import typing
from collections.abc import Sequence

from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import NodeRewriter, in2out
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
from pytensor.tensor.rewriting.ofg import inline_ofg_expansion
from pytensor.tensor.variable import TensorVariable
from pytensor.xtensor.type import XTensorVariable


lower_xtensor_db = EquilibriumDB(ignore_newtrees=False)
@@ -49,3 +54,10 @@ def register(inner_rewriter: RewriteDatabase | NodeRewriter):
**kwargs,
)
return node_rewriter


def lower_aligned(x: XTensorVariable, out_dims: Sequence[str]) -> TensorVariable:
"""Lower an XTensorVariable to a TensorVariable so that it's dimensions are aligned with "out_dims"."""
inp_dims = {d: i for i, d in enumerate(x.type.dims)}
ds_order = tuple(inp_dims.get(dim, "x") for dim in out_dims)
return typing.cast(TensorVariable, x.values.dimshuffle(ds_order))
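A sketch of what the helper does (dim names are invented; `x.values` is the underlying TensorVariable):

    import pytensor.xtensor as px
    from pytensor.xtensor.rewriting.utils import lower_aligned

    x = px.xtensor("x", dims=("b", "a"))
    t = lower_aligned(x, out_dims=("a", "b", "c"))
    # Equivalent to x.values.dimshuffle(1, 0, "x"): the existing dims are
    # reordered to match out_dims, with a broadcastable axis inserted for
    # the missing "c" dimension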
44 changes: 11 additions & 33 deletions pytensor/xtensor/rewriting/vectorization.py
@@ -2,8 +2,8 @@
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.utils import compute_batch_shape
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
from pytensor.xtensor.basic import xtensor_from_tensor
from pytensor.xtensor.rewriting.utils import lower_aligned, register_lower_xtensor
from pytensor.xtensor.vectorization import XRV, XBlockwise, XElemwise


Expand All @@ -13,15 +13,7 @@ def lower_elemwise(fgraph, node):
out_dims = node.outputs[0].type.dims

# Convert input XTensors to Tensors and align batch dimensions
tensor_inputs = []
for inp in node.inputs:
inp_dims = inp.type.dims
order = [
inp_dims.index(out_dim) if out_dim in inp_dims else "x"
for out_dim in out_dims
]
tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
tensor_inputs.append(tensor_inp)
tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]

tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
*tensor_inputs, return_list=True
@@ -42,17 +34,10 @@ def lower_blockwise(fgraph, node):
batch_dims = node.outputs[0].type.dims[:batch_ndim]

# Convert input Tensors to XTensors, align batch dimensions and place core dimension at the end
tensor_inputs = []
for inp, core_dims in zip(node.inputs, op.core_dims[0]):
inp_dims = inp.type.dims
# Align the batch dims of the input, and place the core dims on the right
batch_order = [
inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
for batch_dim in batch_dims
]
core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
tensor_inputs.append(tensor_inp)
tensor_inputs = [
lower_aligned(inp, batch_dims + core_dims)
for inp, core_dims in zip(node.inputs, op.core_dims[0], strict=True)
]

signature = op.signature or getattr(op.core_op, "gufunc_signature", None)
if signature is None:
@@ -92,17 +77,10 @@ def lower_rv(fgraph, node):
param_batch_dims = old_out.type.dims[len(op.extra_dims) : batch_ndim]

# Convert params Tensors to XTensors, align batch dimensions and place core dimension at the end
tensor_params = []
for inp, core_dims in zip(params, op.core_dims[0]):
inp_dims = inp.type.dims
# Align the batch dims of the input, and place the core dims on the right
batch_order = [
inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
for batch_dim in param_batch_dims
]
core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
tensor_params.append(tensor_inp)
tensor_params = [
lower_aligned(inp, param_batch_dims + core_dims)
for inp, core_dims in zip(params, op.core_dims[0], strict=True)
]

size = None
if op.extra_dims:
63 changes: 62 additions & 1 deletion pytensor/xtensor/shape.py
@@ -13,7 +13,8 @@
from pytensor.tensor.type import integer_dtypes
from pytensor.tensor.utils import get_static_shape_from_size_variables
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor
from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
from pytensor.xtensor.vectorization import combine_dims_and_shape


class Stack(XOp):
@@ -504,3 +505,63 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa
x = Transpose(dims=tuple(target_dims))(x)

return x


class Broadcast(XOp):
"""Broadcast multiple XTensorVariables against each other."""

__props__ = ("exclude",)

def __init__(self, exclude: Sequence[str] = ()):
self.exclude = tuple(exclude)

def make_node(self, *inputs):
inputs = [as_xtensor(x) for x in inputs]

exclude = self.exclude
dims_and_shape = combine_dims_and_shape(inputs, exclude=exclude)

broadcast_dims = tuple(dims_and_shape.keys())
broadcast_shape = tuple(dims_and_shape.values())
dtype = upcast(*[x.type.dtype for x in inputs])

outputs = []
for x in inputs:
x_dims = x.type.dims
x_shape = x.type.shape
# The output has excluded dimensions in the order they appear in the op argument
excluded_dims = tuple(d for d in exclude if d in x_dims)
excluded_shape = tuple(x_shape[x_dims.index(d)] for d in excluded_dims)

output = xtensor(
dtype=dtype,
shape=broadcast_shape + excluded_shape,
dims=broadcast_dims + excluded_dims,
)
outputs.append(output)

return Apply(self, inputs, outputs)


def broadcast(
*args, exclude: str | Sequence[str] | None = None
) -> tuple[XTensorVariable, ...]:
"""Broadcast any number of XTensorVariables against each other.

Parameters
----------
*args : XTensorVariable
The tensors to broadcast against each other.
exclude : str or Sequence[str] or None, optional
Dimensions to exclude from broadcasting; they are appended to the right of the broadcast dimensions in each output.
"""
if not args:
return ()

if exclude is None:
exclude = ()
elif isinstance(exclude, str):
exclude = (exclude,)
elif not isinstance(exclude, Sequence):
raise TypeError(f"exclude must be None, str, or Sequence, got {type(exclude)}")
# xarray broadcast always returns a tuple, even if there's only one tensor
return tuple(Broadcast(exclude=exclude)(*args, return_list=True)) # type: ignore
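A minimal usage sketch of the new helper (dim names are invented; `px.broadcast` is the re-export added in `__init__.py` above):

    import pytensor.xtensor as px

    x = px.xtensor("x", dims=("a",))
    y = px.xtensor("y", dims=("b",))

    x2, y2 = px.broadcast(x, y)
    # both outputs have dims ("a", "b")

    x3, y3 = px.broadcast(x, y, exclude="b")
    # "b" is kept out of broadcasting: x3 has dims ("a",),
    # while y3 keeps its excluded "b" on the right, i.e. ("a", "b")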
9 changes: 9 additions & 0 deletions pytensor/xtensor/type.py
@@ -736,6 +736,15 @@ def dot(self, other, dim=None):
"""Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
return px.math.dot(self, other, dim=dim)

def broadcast(self, *others, exclude=None):
"""Broadcast this tensor against other XTensorVariables."""
return px.shape.broadcast(self, *others, exclude=exclude)

def broadcast_like(self, other, exclude=None):
"""Broadcast this tensor against another XTensorVariable."""
_, self_bcast = px.shape.broadcast(other, self, exclude=exclude)
return self_bcast


class XTensorConstantSignature(TensorConstantSignature):
pass
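And a sketch of the two new methods (same invented dim names as before):

    import pytensor.xtensor as px

    x = px.xtensor("x", dims=("a",))
    y = px.xtensor("y", dims=("a", "b"))

    x_b, y_b = x.broadcast(y)        # method form of px.broadcast(x, y)
    x_like_y = x.broadcast_like(y)   # only the broadcast version of x, with dims ("a", "b")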
14 changes: 12 additions & 2 deletions pytensor/xtensor/vectorization.py
@@ -1,3 +1,4 @@
from collections.abc import Sequence
from itertools import chain

import numpy as np
@@ -13,13 +14,22 @@
get_static_shape_from_size_variables,
)
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor
from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor


def combine_dims_and_shape(inputs):
def combine_dims_and_shape(
inputs: Sequence[XTensorVariable], exclude: Sequence[str] | None = None
) -> dict[str, int | None]:
"""Combine information of static dimensions and shapes from multiple xtensor inputs.

Dimensions listed in `exclude` are skipped.
"""
exclude_set: set[str] = set() if exclude is None else set(exclude)
dims_and_shape: dict[str, int | None] = {}
for inp in inputs:
for dim, dim_length in zip(inp.type.dims, inp.type.shape):
if dim in exclude_set:
continue
if dim not in dims_and_shape:
dims_and_shape[dim] = dim_length
elif dim_length is not None:
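For reference, a sketch of the extended helper's behaviour (the dims and static shapes below are invented):

    # x: dims ("a", "b"), shape (2, None);  y: dims ("b", "c"), shape (3, 4)
    # combine_dims_and_shape([x, y])                 -> {"a": 2, "b": 3, "c": 4}
    # combine_dims_and_shape([x, y], exclude=["c"])  -> {"a": 2, "b": 3}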