Hi All, I am trying to implement the matrix_power function for 'n' participants. The matrix_power function raises a square matrix to a whole-number power, e.g., it can raise an MxM matrix to the power 5. I have an MxM matrix for each of N participants, so I have an NxMxM array. I need to iterate over the 0th dim (N) and raise each MxM matrix to a whole-number power that is different for each participant. I assumed vmap might be a great way to do this, but I keep getting the error below. The error occurs because the matrix_power function expects a scalar integer for the power but instead gets a traced value. I thought about using a simple for loop, but I am building a likelihood function that I will then give to numpyro for MCMC inference, so I thought vmap might better facilitate autodiff and inference performance. I could be wrong; please advise whether there's a way to fix the error below or whether I should just use a for loop. Code:
Error:
The code works if I execute the matrix_power function for a single value of the power parameter n. Code:
Also, the code below works if I manually select an item from the 'n' array, but I need to iterate through all the values of 'n'.
Thanks for the help!
Hi - thanks for the question. Currently the `n` argument to `matrix_power` must be a static integer, meaning that it's not possible to `vmap` over this argument, because mapped arguments are not static. In that sense, this is working as expected given the current implementation of `jnp.linalg.matrix_power`.

Unless that changes, the optimal approach here is probably to avoid the `vmap` and use a Python `for`-loop instead; i.e. something like this:

```python
from jax import numpy as jnp

N = 2
M = 3
T_t = jnp.ones((N, M, M))   # one MxM matrix per participant
n = jnp.arange(1, N + 1)    # per-participant exponents (must be concrete, not traced)

# matrix_power needs a static Python int for the exponent, so loop in Python
# and convert each exponent with int().
result = jnp.stack([jnp.linalg.matrix_power(mat, int(p))
                    for mat, p in zip(T_t, n)])
print(result)
```
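If the exponents do need to be traced (for example, inside a `vmap`ped or `jit`-compiled likelihood), one possible workaround is to avoid `jnp.linalg.matrix_power` and write the power as a fixed-length loop of masked matrix products: on step `i` the factor is the matrix if `i < k` and the identity otherwise. This is a sketch, not part of the `jnp.linalg` API; the helper name `masked_matrix_power` and the `max_power` bound are assumptions, and it assumes every exponent satisfies `0 <= k <= max_power` with `max_power` known as a plain Python int.

```python
from functools import partial

import jax
import jax.numpy as jnp


def masked_matrix_power(a, k, max_power):
    """Hypothetical helper: a**k for a square matrix `a` and an integer
    exponent `k` that may be a tracer, assuming 0 <= k <= max_power.

    `max_power` must be a static Python int. The loop always runs max_power
    steps; on step i the factor is `a` if i < k and the identity otherwise,
    so the control flow never depends on the traced value of `k`.
    """
    eye = jnp.eye(a.shape[-1], dtype=a.dtype)
    result = eye
    for i in range(max_power):  # static loop, unrolled at trace time
        factor = jnp.where(i < k, a, eye)
        result = result @ factor
    return result


N, M = 2, 3
T_t = jnp.ones((N, M, M))      # one MxM matrix per participant
n = jnp.arange(1, N + 1)       # per-participant exponents
max_power = int(n.max())       # static upper bound, computed outside any trace

# vmap over the participant axis of both the matrices and the exponents;
# max_power is bound by partial, so it is not mapped.
batched_power = jax.vmap(partial(masked_matrix_power, max_power=max_power))
result = batched_power(T_t, n)
print(result.shape)  # (2, 3, 3)

# The function is differentiable with respect to the matrices.
grad_fn = jax.grad(lambda mats: batched_power(mats, n).sum())
print(grad_fn(T_t).shape)  # (2, 3, 3)
```

This always performs `max_power` matrix products per participant, so it is only sensible when the exponents are small; whether it beats the Python loop for a given numpyro model would need to be measured.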