Skip to content

Commit c038109

Browse files
Add is_left_expand_dims and is_right_expand_dims attributes to DimShuffle
1 parent 04ddb46 commit c038109

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

pytensor/tensor/elemwise.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -182,14 +182,12 @@ def __init__(self, input_broadcastable, new_order):
182182
self.drop = drop
183183

184184
input_ndim = len(input_broadcastable)
185-
186-
self.is_left_expand_dims = "x" in new_order and (
187-
input_ndim == 0 or new_order[:input_ndim] == list(range(input_ndim))
188-
)
189-
190-
self.is_right_expand_dims = "x" in new_order and (
185+
self.is_left_expand_dims = self.augment and (
191186
input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
192187
)
188+
self.is_right_expand_dims = self.augment and new_order[:input_ndim] == list(
189+
range(input_ndim)
190+
)
193191

194192
if self.inplace:
195193
self.view_map = {0: [0]}

0 commit comments

Comments
 (0)