I have a batch of arrays that I want to transpose with `vmap`, using a different permutation of axes for each one. Here's my attempt: …
This throws an error. I think it occurs because … After trying various things (like …
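Here's an approach that might be sufficient for your needs. It requires every dimension of each array to have the same size, because the output shape cannot depend on the dynamic `axes`:

```python
import itertools
import jax.numpy as jnp
import numpy as np
from jax import random, vmap

def dynamic_transpose(x, axes):
  axes = jnp.asarray(axes)
  # The output shape must not depend on `axes`, so all dimensions must be equal.
  assert len(set(x.shape)) == 1
  assert axes.shape == (x.ndim,)
  # mgrid gives every element's coordinates; input axis m is indexed by output axis argsort(axes)[m].
  indices = jnp.mgrid[tuple(slice(s) for s in x.shape)]
  indices = indices[jnp.argsort(axes)]
  return x[tuple(indices[i] for i in range(indices.shape[0]))]

# Hypothetical test batch: one 4x4x4 array per permutation of three axes.
transposed_axes = np.array(list(itertools.permutations(range(3))))
data = random.normal(random.PRNGKey(0), (len(transposed_axes), 4, 4, 4))
result_slow = jnp.array([x.transpose(p) for x, p in zip(data, transposed_axes)])
result_vmap = vmap(dynamic_transpose)(data, transposed_axes)
np.testing.assert_allclose(result_slow, result_vmap)
```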
The `axes` argument to `jnp.transpose` is required to be static, so it is not possible to `vmap` over it. This requirement of static transpose axes comes from XLA, so it is not something that can be easily modified; you have to transpose the array in some other way that accepts dynamic inputs, which is what `dynamic_transpose` does by building a grid of indices and gathering from it.
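A small sketch of the static-axes restriction, assuming a recent JAX release: under `vmap`, each row of the axes batch becomes a tracer rather than a tuple of Python ints, so `jnp.transpose` raises. The exact exception type has varied between JAX versions, so the sketch simply catches `Exception`. It then shows one alternative that is not from the answer above: enumerate every permutation of a small number of axes and dispatch with `lax.switch`, where each branch calls `jnp.transpose` with static axes. `PERMS` and `switch_transpose` are hypothetical names. Like `dynamic_transpose`, this needs all dimensions to be equal (every branch must return the same shape).

```python
import itertools

import jax
import jax.numpy as jnp
from jax import lax

# A hypothetical batch of two 3x3x3 arrays, each with its own axis permutation.
data = jnp.arange(54.0).reshape(2, 3, 3, 3)
axes_batch = jnp.array([[1, 2, 0], [2, 0, 1]])

# Under vmap, `axes` is a tracer, so jnp.transpose cannot read it as static ints.
try:
    jax.vmap(jnp.transpose)(data, axes_batch)
    failed = False
except Exception:  # the concrete error class depends on the JAX version
    failed = True

# Every static permutation of three axes.
PERMS = list(itertools.permutations(range(3)))


def switch_transpose(x, axes):
    # One branch per permutation; each branch calls jnp.transpose with static axes.
    branches = [lambda v, p=p: jnp.transpose(v, p) for p in PERMS]
    # Find the row of the permutation table that matches the (traced) `axes`.
    index = jnp.argmax(jnp.all(jnp.array(PERMS) == axes, axis=1))
    # All branches must return the same shape, so x must have equal dimensions.
    return lax.switch(index, branches, x)


out = jax.vmap(switch_transpose)(data, axes_batch)
expected = jnp.stack([data[0].transpose(1, 2, 0), data[1].transpose(2, 0, 1)])
```

With a batched index, `vmap` turns `lax.switch` into a select over the outputs of all branches, so every permutation is computed; this is only practical for a small number of axes (3 axes give 6 branches, 4 give 24).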