Hi, thanks for the question. scan iterates over the leading axis of JAX arrays (or pytrees of arrays); it cannot iterate over the elements of a Python list. Since JAX does not support ragged arrays, your use case cannot be implemented via scan.
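To make the restriction concrete, here is a minimal sketch of what scan does accept, assuming all entries share one shape (which is not the case in your question). The name step and the stacked shapes are made up for illustration:

```python
import jax
from jax import numpy as jnp

# Illustrative uniform case: two (5, 5, 3) arrays and two (3,) vectors,
# stacked along a new leading axis that scan can iterate over.
aa = jnp.ones((2, 5, 5, 3))
bb = jnp.ones((2, 3))

def step(carry, xs):
    # xs holds one slice along the leading axis of each scanned array.
    a, b = xs
    return carry, jnp.einsum('ijk,k->ij', a, b)

# No carry is needed here, so None is passed as the initial carry.
_, out = jax.lax.scan(step, None, (aa, bb))

# Ragged entries cannot be stacked into a single array in the first place.
try:
    jnp.stack([jnp.zeros((5, 5, 2)), jnp.zeros((5, 5, 3))])
    stacked_ok = True
except ValueError:
    stacked_ok = False
```

With mismatched trailing dimensions there is no single array to hand to scan, which is why the list-based approach is needed.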

For what you have in mind, I'd suggest a simple Python loop. For example:

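import jax
from jax import numpy as jnp

# Ragged inputs: the trailing dimension differs between list entries (2 vs. 3).
aa = [jnp.zeros((5,5,2)), jnp.zeros((5,5,3))]
bb = [jnp.zeros((2,)), jnp.zeros((3,))]

# Contract each pair over its last axis; every result has shape (5, 5).
result = [jnp.einsum('ijk,k->ij', a, b) for a, b in zip(aa, bb)]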

This will be compatible with both forward and reverse autodiff. If your lists are large, then wrapping the outer function in jit will likely lead to long compile times, because the for loop is flattened into a long sequence …
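As a rough check of the autodiff claim, the sketch below differentiates through the same list comprehension in both modes. The scalar loss function and the use of ones instead of zeros are assumptions added only so the gradients are non-trivial:

```python
import jax
from jax import numpy as jnp

aa = [jnp.ones((5, 5, 2)), jnp.ones((5, 5, 3))]
bb = [jnp.ones((2,)), jnp.ones((3,))]

def loss(bb, aa):
    # Same Python loop as above, reduced to a scalar so grad applies.
    outs = [jnp.einsum('ijk,k->ij', a, b) for a, b in zip(aa, bb)]
    return sum(jnp.sum(o) for o in outs)

# Reverse mode: the gradient has the same list structure as bb.
grads = jax.grad(loss)(bb, aa)

# Forward mode: directional derivative along a tangent of all ones.
tangents = [jnp.ones_like(b) for b in bb]
value, directional = jax.jvp(lambda bs: loss(bs, aa), (bb,), (tangents,))
```

Both transformations treat the Python lists as pytrees, so grads comes back as a list of arrays with shapes (2,) and (3,).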

Answer selected by fishjojo