Skip to content

Commit 1024798

Browse files
committed
Cleaning up squeeze
1 parent 9c1a0b7 commit 1024798

File tree

3 files changed

+30
-25
lines changed

3 files changed

+30
-25
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -162,28 +162,18 @@ def local_expand_dims_reshape(fgraph, node):
162162
@register_xcanonicalize
163163
@node_rewriter([Squeeze])
164164
def local_squeeze_reshape(fgraph, node):
165-
"""Rewrite rule to convert Squeeze to pytensor.tensor.squeeze."""
166-
if not isinstance(node.op, Squeeze):
167-
return False
168-
169-
[x] = node.inputs
170-
in_dims = x.type.dims
165+
"""Rewrite Squeeze to tensor.squeeze."""
166+
x = node.inputs[0]
171167
dim = node.op.dims
172168

173-
# Determine which axes to squeeze
174169
if dim is None:
175-
# Infer axes by comparing input and output dims
176-
out_dims = node.outputs[0].type.dims
177-
axes_to_squeeze = tuple(i for i, d in enumerate(in_dims) if d not in out_dims)
178-
else:
179-
dims_to_remove = [dim] if isinstance(dim, str) else dim
180-
axes_to_squeeze = tuple(in_dims.index(d) for d in dims_to_remove)
181-
182-
# Nothing to squeeze? Just return input unchanged
183-
if not axes_to_squeeze:
184170
return [x]
185171

186172
x_tensor = tensor_from_xtensor(x)
173+
x_dims = x.type.dims
174+
dims_to_remove = [dim] if isinstance(dim, str) else dim
175+
axes_to_squeeze = tuple(x_dims.index(d) for d in dims_to_remove)
187176
x_tensor_squeezed = squeeze(x_tensor, axis=axes_to_squeeze)
177+
188178
new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
189179
return [new_out]

pytensor/xtensor/shape.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -397,19 +397,21 @@ class Squeeze(XOp):
397397

398398
__props__ = ("dims",)
399399

400-
def __init__(self, dims):
401-
self.dims = dims
400+
def __init__(self, dim):
401+
self.dims = tuple(sorted(set(dim)))
402402

403403
def make_node(self, x):
404404
x = as_xtensor(x)
405405

406406
# Validate that dims exist and are size-1 if statically known
407407
dims_to_remove = []
408+
x_dims = x.type.dims
409+
x_shape = x.type.shape
408410
for d in self.dims:
409-
if d not in x.type.dims:
411+
if d not in x_dims:
410412
raise ValueError(f"Dimension {d} not found in {x.type.dims}")
411-
idx = x.type.dims.index(d)
412-
dim_size = x.type.shape[idx]
413+
idx = x_dims.index(d)
414+
dim_size = x_shape[idx]
413415
if dim_size is not None and dim_size != 1:
414416
raise ValueError(f"Dimension {d} has static size {dim_size}, not 1")
415417
dims_to_remove.append(idx)
@@ -454,9 +456,6 @@ def squeeze(x, dim=None):
454456
else:
455457
dims = tuple(dim)
456458

457-
# Normalize: deduplicate and sort
458-
dims = tuple(sorted(set(dims)))
459-
460459
if not dims:
461460
return x # no-op if nothing to squeeze
462461

tests/xtensor/test_shape.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def test_squeeze_explicit_dims():
448448
# Redundant dimensions
449449
y3c = squeeze(x3, ["b", "b"])
450450
fn3c = xr_function([x3], y3c)
451-
xr_assert_allclose(fn3c(x3_test), x3_test.squeeze("b"))
451+
xr_assert_allclose(fn3c(x3_test), x3_test.squeeze(["b", "b"]))
452452

453453
# Empty list = no-op
454454
y3d = squeeze(x3, [])
@@ -495,6 +495,22 @@ def test_squeeze_implicit_dims():
495495
x5_test = xr_arange_like(x5)
496496
xr_assert_allclose(fn5(x5_test).transpose(*x5_test.dims), x5_test)
497497

498+
"""
499+
This test documents that we intentionally don't squeeze dimensions with symbolic shapes
500+
(static_shape=None) even when they are 1 at runtime, while xarray does squeeze them.
501+
"""
502+
# Create a tensor with a symbolic dimension that will be 1 at runtime
503+
x = xtensor("x", dims=("a", "b", "c")) # shape unknown
504+
y = squeeze(x) # implicit dim=None should not squeeze symbolic dimensions
505+
x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 1, 3)))
506+
fn = xr_function([x], y)
507+
res = fn(x_test)
508+
509+
# Our implementation should not squeeze the symbolic dimension
510+
assert "b" in res.dims
511+
# While xarray would squeeze it
512+
assert "b" not in x_test.squeeze().dims
513+
498514

499515
def test_squeeze_errors():
500516
"""Test error cases for squeeze."""

0 commit comments

Comments
 (0)