We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 98d297e commit 3202c4cCopy full SHA for 3202c4c
pytensor/xtensor/rewriting/shape.py
@@ -124,11 +124,9 @@ def lower_transpose(fgraph, node):
124
def local_squeeze_reshape(fgraph, node):
125
"""Rewrite Squeeze to tensor.squeeze."""
126
x = node.inputs[0]
127
- dim = node.op.dims
128
-
129
x_tensor = tensor_from_xtensor(x)
130
x_dims = x.type.dims
131
- dims_to_remove = [dim] if isinstance(dim, str) else dim
+ dims_to_remove = node.op.dims
132
axes_to_squeeze = tuple(x_dims.index(d) for d in dims_to_remove)
133
x_tensor_squeezed = squeeze(x_tensor, axis=axes_to_squeeze)
134
0 commit comments