
The axes argument to jnp.transpose must be static, so it is not possible to vmap over this argument. This requirement comes from XLA, whose transpose operation takes the permutation as a compile-time constant (the permutation also determines the output shape), so it is not something that can be easily modified. You'd have to come up with some alternative way to transpose the array that can work with dynamic inputs.
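To see the static-axes restriction concretely, here is a minimal sketch. The input x (shape (3, 3, 3)) and the stacked array perms of all permutations are illustrative assumptions, not part of the original question. Jitting jnp.transpose with axes marked static works, since each distinct tuple is compiled separately, while vmapping over a batch of permutations fails because axes becomes a traced value. The exact exception type may differ between JAX versions, so the sketch catches Exception broadly.

```python
import itertools

import jax
import jax.numpy as jnp

# Illustrative inputs (assumed): a cubic array and all 6 permutations of its axes.
x = jnp.arange(27).reshape(3, 3, 3)
perms = jnp.array(list(itertools.permutations(range(3))))  # shape (6, 3)

# Static axes work: the tuple is hashed and baked into the compiled program.
static_transpose = jax.jit(jnp.transpose, static_argnums=1)
ok = static_transpose(x, (1, 2, 0))

# Dynamic axes do not: under vmap, each row of perms is a tracer, not a tuple of ints.
try:
    jax.vmap(jnp.transpose, in_axes=(None, 0))(x, perms)
    vmap_failed = False
except Exception as e:  # exact error type varies across JAX versions
    vmap_failed = True
    print(type(e).__name__)
```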

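Here's an approach that might be sufficient for your needs. It builds a grid of indices with jnp.mgrid, reorders the index arrays according to axes, and gathers from x with advanced indexing. Because the output shape must not depend on the traced axes, it only works when every dimension of x has the same size: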

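import jax.numpy as jnp

def dynamic_transpose(x, axes):
  # axes may be a traced array (e.g. under vmap or jit).
  axes = jnp.asarray(axes)
  # Output shape must not depend on axes, so all dimensions must be equal.
  assert len(set(x.shape)) == 1
  assert axes.shape == (x.ndim,)
  # indices[k] holds the k-th output coordinate at every output position.
  indices = jnp.mgrid[tuple(slice(s) for s in x.shape)]
  # Output axis k is input axis axes[k], so input axis m reads output
  # coordinate argsort(axes)[m], i.e. use the inverse permutation.
  indices = indices[jnp.argsort(axes)]
  return x[tuple(indices[i] for i in range(indices.shape[0]))]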

result_slow = jnp.array([x.transpose(p) for x
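As a rough self-contained check (a sketch, assuming a (3, 3, 3) input and all six permutations generated with itertools.permutations), the snippet below vmaps a dynamic_transpose over the permutations and compares the result with a Python loop of ordinary x.transpose(p) calls. It re-defines dynamic_transpose so it runs on its own. Note that it reorders the index grids with jnp.argsort(axes), the inverse permutation: output axis k of x.transpose(axes) is input axis axes[k], so input axis m must read output coordinate argsort(axes)[m]. Indexing the grids with axes directly would compute x.transpose(argsort(axes)) instead, which only coincides for self-inverse permutations (for example, every permutation of 2 axes).

```python
import itertools

import jax
import jax.numpy as jnp

def dynamic_transpose(x, axes):
    axes = jnp.asarray(axes)
    assert len(set(x.shape)) == 1       # output shape must not depend on axes
    assert axes.shape == (x.ndim,)
    # indices[k] is the k-th output coordinate at every output position.
    indices = jnp.mgrid[tuple(slice(s) for s in x.shape)]
    # Input axis m reads output coordinate argsort(axes)[m] (inverse permutation).
    indices = indices[jnp.argsort(axes)]
    return x[tuple(indices[i] for i in range(indices.shape[0]))]

# Illustrative inputs (assumed): a cubic array and every permutation of its axes.
x = jnp.arange(27).reshape(3, 3, 3)
perm_list = list(itertools.permutations(range(3)))
perms = jnp.array(perm_list)  # shape (6, 3)

# Batch over the permutations only; x is shared across the batch.
result = jax.vmap(dynamic_transpose, in_axes=(None, 0))(x, perms)

# Reference: one static transpose per permutation.
result_slow = jnp.stack([x.transpose(p) for p in perm_list])

print(bool((result == result_slow).all()))  # True
```

The vmapped version runs as a single batched gather, at the cost of materializing an index grid with x.ndim * x.size entries per permutation.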

Replies: 1 comment, 2 replies (from @mgbukov and @jakevdp)
Answer selected by mgbukov
Category: Q&A (2 participants)