Skip to content

Commit 3202c4c

Browse files
Update pytensor/xtensor/rewriting/shape.py
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 98d297e commit 3202c4c

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,9 @@ def lower_transpose(fgraph, node):
124124
def local_squeeze_reshape(fgraph, node):
125125
"""Rewrite Squeeze to tensor.squeeze."""
126126
x = node.inputs[0]
127-
dim = node.op.dims
128-
129127
x_tensor = tensor_from_xtensor(x)
130128
x_dims = x.type.dims
131-
dims_to_remove = [dim] if isinstance(dim, str) else dim
129+
dims_to_remove = node.op.dims
132130
axes_to_squeeze = tuple(x_dims.index(d) for d in dims_to_remove)
133131
x_tensor_squeezed = squeeze(x_tensor, axis=axes_to_squeeze)
134132

0 commit comments

Comments
 (0)