
Hi, thanks for the question. Currently the `n` argument to `jnp.linalg.matrix_power` must be a static integer. Because mapped arguments are never static, it is not possible to `vmap` over this argument. In that sense, this is working as expected given the current implementation of `jnp.linalg.matrix_power`.
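To make the static/mapped distinction concrete, here is a small sketch using the same shapes as the question (N = 2, M = 3). It maps over both the matrices and the exponents, which fails, and then maps over the matrices only with a fixed Python-int exponent, which works. The exact exception type raised in the first case is not pinned down here and may differ between JAX versions, so the sketch catches a generic `Exception`.

```python
import jax
import jax.numpy as jnp

T_t = jnp.ones((2, 3, 3))  # batch of two 3x3 matrices
n = jnp.arange(1, 3)       # one exponent per matrix

# Mapping over the exponent: inside vmap, n is a batched tracer rather than
# a Python int, so matrix_power cannot use it as its static exponent.
try:
    jax.vmap(jnp.linalg.matrix_power)(T_t, n)
    mapped_n_failed = False
except Exception as err:  # exception type may vary across JAX versions
    mapped_n_failed = True
    print(f"{type(err).__name__}: cannot vmap over the exponent")

# Mapping over the matrices only: the exponent stays a static Python int.
squares = jax.vmap(lambda mat: jnp.linalg.matrix_power(mat, 2))(T_t)
print(squares.shape)
```

Only the mapped argument becomes a tracer; a constant closed over by the `lambda` is still an ordinary Python int, which is why the second call is accepted.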

Unless that changes, the best approach here is probably to avoid `vmap` and use a Python for-loop instead, e.g. something like this:

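```python
from jax import numpy as jnp

N = 2
M = 3
T_t = jnp.ones((N, M, M))  # batch of N (M x M) matrices
n = jnp.arange(1, N + 1)   # one exponent per matrix

# matrix_power needs a static Python int, so loop in Python and
# convert each exponent with int() before calling it.
result = jnp.stack([jnp.linalg.matrix_power(mat, int(p))
                    for mat, p in zip(T_t, n)])
print(result)
```

```
[[[1. 1. 1.]
  [1. 1. 1…
```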
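The loop above relies on each exponent being concrete, so `int(p)` would fail if `n` were a traced value (for example inside `jax.jit`). One possible workaround, which is not part of the answer above and is offered only as a sketch, is binary exponentiation (repeated squaring) inside `jax.lax.while_loop`, where the exponent is an ordinary traced array and the function can be vmapped and jitted. `matrix_power_dynamic` is a hypothetical helper name; it assumes non-negative integer exponents and, unlike `jnp.linalg.matrix_power`, does not handle negative powers.

```python
import jax
import jax.numpy as jnp
from jax import lax


def matrix_power_dynamic(a, n):
    """Hypothetical helper: a**n for a traced, non-negative integer n.

    Uses repeated squaring in a lax.while_loop, so n does not need to be
    static and the function can be vmapped or jitted.
    """
    n = jnp.asarray(n, dtype=jnp.int32)
    eye = jnp.eye(a.shape[-1], dtype=a.dtype)

    def cond(state):
        _, _, k = state
        return k > 0  # stop once every bit of the exponent is consumed

    def body(state):
        result, base, k = state
        # Multiply in the current power-of-two power when the low bit of k is set.
        result = jnp.where(k % 2 == 1, result @ base, result)
        return result, base @ base, k // 2

    result, _, _ = lax.while_loop(cond, body, (eye, a, n))
    return result


T_t = jnp.ones((2, 3, 3))
n = jnp.arange(1, 3)

# Here the exponents are mapped (and traced under jit).
result = jax.jit(jax.vmap(matrix_power_dynamic))(T_t, n)
print(result)
```

The trade-off is that every exponent now goes through a data-dependent loop: under `vmap`, the loop runs until the largest exponent in the batch is exhausted, which is usually cheap since it takes only about log2(max n) iterations.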

Answer selected by rkmalaiya
Category: Q&A