
Maybe you could share more information about your real™ use case?

There are some (hopefully useful) general suggestions I can give now:

  1. `.at[...].set(...)` is functional: it returns a new, updated array and never modifies the original in place, so you should write `A = A.at[i, j].set(i * j)`.
  2. Prefer approaches with no explicit loop at all, such as
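```python
from functools import partial
import jax.numpy as jnp
from jax import Array, jit

@partial(jit, static_argnums=(0,))  # shape is a tuple, so it must be static under jit
def f(shape: tuple[int, int]) -> Array:
  I, J = jnp.indices(shape)  # row and column index grids
  A = jnp.zeros(shape).at[I, J].set(I * J)
  return A
```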

or, even simpler for your toy example:

```python
A = jnp.fromfunction(jnp.multiply, (m, n))  # A[i, j] = i * j
```
  3. JIT-compile your computation and benchmark it carefully, following this guide.
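To make the difference between suggestions 1 and 2 concrete, here is a minimal sketch. The helper names `fill_loop` and `fill_vectorized` are hypothetical, and the toy example is assumed to be `A[i, j] = i * j`. It shows that discarding the result of `.at[...].set(...)` leaves the array unchanged, and that the loop and loop-free versions produce the same values.

```python
import jax
import jax.numpy as jnp

# .at[...].set(...) returns an updated copy; JAX arrays are never mutated.
A = jnp.zeros((2, 3))
A.at[1, 2].set(5.0)      # result discarded: A is still all zeros here
A = A.at[1, 2].set(5.0)  # reassign to keep the update

def fill_loop(m: int, n: int) -> jax.Array:
  # Loop version of the toy example: each iteration rebinds the name
  # to the updated copy. Only reasonable for small m and n.
  out = jnp.zeros((m, n))
  for i in range(m):
    for j in range(n):
      out = out.at[i, j].set(i * j)
  return out

def fill_vectorized(m: int, n: int) -> jax.Array:
  # Same values with no Python loop, using index grids.
  I, J = jnp.indices((m, n))
  return jnp.zeros((m, n)).at[I, J].set(I * J)

print(bool(jnp.array_equal(fill_loop(3, 4), fill_vectorized(3, 4))))  # True
```

The loop version issues one update per element, while the vectorized version performs a single scatter over the whole index grid, which is why it is preferred.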
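For suggestion 3, a minimal benchmarking sketch, assuming the standard-library `timeit` module and a hypothetical `compute` workload standing in for your real computation. It illustrates two common pitfalls: including compilation time in the measurement, and timing only JAX's asynchronous dispatch instead of the actual computation.

```python
import timeit

import jax
import jax.numpy as jnp

@jax.jit
def compute(x: jax.Array) -> jax.Array:
  # Hypothetical workload: sum of the outer product x[i] * x[j].
  return (x[:, None] * x[None, :]).sum()

x = jnp.arange(1000, dtype=jnp.float32)

# Warm-up call: the first call traces and compiles, so keep it out of timing.
compute(x).block_until_ready()

# JAX dispatches asynchronously; block_until_ready() waits for the result,
# otherwise only the dispatch would be timed.
n_runs = 100
elapsed = timeit.timeit(lambda: compute(x).block_until_ready(), number=n_runs)
print(f"{elapsed / n_runs * 1e6:.1f} microseconds per call")
```

Calling with a new input shape triggers a fresh compilation, so warm up once per shape you intend to measure.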

Answer selected by jakevdp
Category
Q&A