Skip to content

Commit 3b7b973

Browse files
committed
Fixing ExpandDims rewrite
1 parent 802a536 commit 3b7b973

File tree

3 files changed

+45
-53
lines changed

3 files changed

+45
-53
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,12 @@
11
from pytensor.graph import node_rewriter
2-
from pytensor.raise_op import Assert
32
from pytensor.tensor import (
43
broadcast_to,
5-
get_scalar_constant_value,
6-
gt,
4+
expand_dims,
75
join,
86
moveaxis,
97
specify_shape,
108
squeeze,
119
)
12-
from pytensor.tensor import (
13-
shape as tensor_shape,
14-
)
1510
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
1611
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
1712
from pytensor.xtensor.shape import (
@@ -144,27 +139,30 @@ def local_squeeze_reshape(fgraph, node):
144139
@register_lower_xtensor
145140
@node_rewriter([ExpandDims])
146141
def local_expand_dims_reshape(fgraph, node):
147-
"""Rewrite ExpandDims to tensor.expand_dims and optionally broadcast_to or specify shape."""
142+
"""Rewrite ExpandDims using tensor operations."""
148143
x, size = node.inputs
149144
out = node.outputs[0]
150-
# Lower to tensor.expand_dims(x, axis=0)
151-
from pytensor.tensor import expand_dims as tensor_expand_dims
152-
153-
expanded = tensor_expand_dims(tensor_from_xtensor(x), 0)
154-
# Optionally broadcast to the correct shape if size is not 1
155-
from pytensor.tensor import broadcast_to
156-
157-
# Ensure size is positive
158-
expanded = Assert(msg="size must be positive")(expanded, gt(size, 0))
159-
# If size is not 1, broadcast
160-
try:
161-
static_size = get_scalar_constant_value(size)
162-
except Exception:
163-
static_size = None
164-
if static_size is not None and static_size == 1:
165-
result = expanded
145+
146+
# Convert inputs to tensors
147+
x_tensor = tensor_from_xtensor(x)
148+
size_tensor = tensor_from_xtensor(size)
149+
150+
# Position of the new dimension
151+
new_axis = 0 # Always insert at front
152+
153+
# Use tensor operations
154+
if out.type.shape[0] == 1:
155+
# Simple case: just expand with size 1
156+
result_tensor = expand_dims(x_tensor, new_axis)
166157
else:
167-
# Broadcast to (size, ...)
168-
new_shape = (size,) + tuple(tensor_shape(expanded))[1:]
169-
result = broadcast_to(expanded, new_shape)
170-
return [xtensor_from_tensor(result, out.type.dims)]
158+
# First expand with size 1
159+
expanded = expand_dims(x_tensor, new_axis)
160+
# Then broadcast to the requested size
161+
result_tensor = broadcast_to(expanded, (size_tensor, *x_tensor.shape))
162+
163+
# Preserve static shape information
164+
result_tensor = specify_shape(result_tensor, out.type.shape)
165+
166+
# Convert result back to xtensor
167+
result = xtensor_from_tensor(result_tensor, dims=out.type.dims)
168+
return [result]

pytensor/xtensor/shape.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77

88
from pytensor.graph import Apply
9+
from pytensor.graph.basic import Constant
910
from pytensor.scalar import discrete_dtypes, upcast
1011
from pytensor.tensor import as_tensor, get_scalar_constant_value
1112
from pytensor.tensor.exceptions import NotScalarConstantError
@@ -391,43 +392,47 @@ class ExpandDims(XOp):
391392
__props__ = ("dim",)
392393

393394
def __init__(self, dim):
    """Store the name of the dimension that this op inserts.

    Parameters
    ----------
    dim : str
        Name of the new dimension.

    Raises
    ------
    TypeError
        If ``dim`` is not a string.
    """
    if not isinstance(dim, str):
        # Format the message from the argument itself: ``self.dim`` is not
        # assigned until after this check, so reading it here would fail
        # on the attribute lookup instead of raising the intended TypeError.
        raise TypeError(f"`dim` must be a string, got: {type(dim)}")

    self.dim = dim
395399

396400
def make_node(self, x, size):
397401
x = as_xtensor(x)
398402

399-
if not isinstance(self.dim, str):
400-
raise TypeError(f"`dim` must be a string or None, got: {type(self.dim)}")
401-
402403
if self.dim in x.type.dims:
403404
raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}")
404-
if isinstance(size, int | np.integer):
405-
if size <= 0:
406-
raise ValueError(f"size must be positive, got: {size}")
407-
elif not (
408-
hasattr(size, "ndim")
409-
and getattr(size, "ndim", None) == 0 # symbolic scalar
405+
406+
# Check if size is a valid type before converting
407+
if not (
408+
isinstance(size, int | np.integer)
409+
or (hasattr(size, "ndim") and getattr(size, "ndim", None) == 0)
410410
):
411411
raise TypeError(
412412
f"size must be an int or scalar variable, got: {type(size)}"
413413
)
414414

415-
# Convert size to tensor
416-
size = as_tensor(size, ndim=0)
417-
418-
# Insert new dim at front
419-
new_dims = (self.dim, *x.type.dims)
420-
421415
# Determine shape
422416
try:
423417
static_size = get_scalar_constant_value(size)
424418
except NotScalarConstantError:
425419
static_size = None
420+
426421
if static_size is not None:
427422
new_shape = (int(static_size), *x.type.shape)
428423
else:
429424
new_shape = (None, *x.type.shape) # symbolic size
430425

426+
# Convert size to a 0-d (scalar) xtensor
427+
size = as_xtensor(size, dims=())
428+
429+
# Check if size is a constant and validate it
430+
if isinstance(size, Constant) and size.data < 0:
431+
raise ValueError(f"size must be 0 or positive, got: {size.data}")
432+
433+
# Insert new dim at front
434+
new_dims = (self.dim, *x.type.dims)
435+
431436
out = xtensor(
432437
dtype=x.type.dtype,
433438
shape=new_shape,

tests/xtensor/test_shape.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -540,10 +540,6 @@ def test_expand_dims_errors():
540540
with pytest.raises(ValueError, match="already exists"):
541541
expand_dims(y, "city")
542542

543-
# Size = 0 is invalid
544-
with pytest.raises(ValueError, match="size must be.*positive"):
545-
expand_dims(x, "batch", size=0)
546-
547543
# Invalid dim type
548544
with pytest.raises(TypeError, match="Invalid type for `dim`"):
549545
expand_dims(x, 123)
@@ -557,13 +553,6 @@ def test_expand_dims_errors():
557553
with pytest.raises(ValueError, match="already exists"):
558554
expand_dims(y, "new")
559555

560-
# Symbolic size with invalid runtime value
561-
size_sym = scalar("size_sym", dtype="int64")
562-
y = expand_dims(x, "batch", size=size_sym)
563-
fn = xr_function([x, size_sym], y, on_unused_input="ignore")
564-
with pytest.raises(Exception):
565-
fn(xr_arange_like(x), 0)
566-
567556

568557
def test_expand_dims_multiple():
569558
"""Test expanding multiple dimensions at once using a list of strings."""

0 commit comments

Comments
 (0)