Skip to content

Commit 322f6ee

Browse files
Implement Dimshuffle using expand_dims/squeeze
1 parent b66d859 commit 322f6ee

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

pytensor/link/jax/dispatch/elemwise.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,8 @@ def jax_funcify_DimShuffle(op, **kwargs):
7474
def dimshuffle(x):
7575
res = jnp.transpose(x, op.transposition)
7676

77-
shape = list(res.shape[: len(op.shuffle)])
78-
79-
for augm in op.augment:
80-
shape.insert(augm, 1)
81-
82-
res = jnp.reshape(res, shape)
77+
res = jax.lax.expand_dims(res, op.augment)
78+
res = jax.lax.squeeze(res, op.drop)
8379

8480
if not op.inplace:
8581
res = jnp.copy(res)

0 commit comments

Comments
 (0)