Functions of Index Arrays -- Efficient Calculation #14770
Hi, I'm currently doing some work to compute Christoffel symbols using I have a function to compute the metric tensor Constructing My first attempt at solving this was simple nested for loops:

```python
import itertools as it
from typing import Callable
import jax
import jax.numpy as jnp

def fk_christoffel(x: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    # dgdx[a, b, c] = d g_ab / d x^c
    dgdx = jax.jacobian(metric)(x)
    christ_fk = jnp.zeros_like(dgdx)
    for k, i, j in it.product(*map(range, dgdx.shape)):
        christ_fk = christ_fk.at[k, i, j].set(dgdx[k, i, j] + dgdx[k, j, i] - dgdx[i, j, k])
    christ_fk /= 2.0
    return christ_fk
```

I considered writing this explicitly with I then considered first generating all of the possible permutations and using

```python
import functools as ft
import itertools as it
from typing import Callable

import jax
import jax.numpy as jnp


@ft.lru_cache(maxsize=16)
def index_permutations_of_shape(shape: tuple[int, ...]) -> jnp.ndarray:
    # Every (k, i, j) index triple for the given shape, as an (N, 3) integer array
    return jnp.array([p for p in it.product(*map(range, shape))])


def fk_christoffel(x: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    dgdx = jax.jacobian(metric)(x)

    def set_value(kij: jnp.ndarray, _christ_fk: jnp.ndarray) -> jnp.ndarray:
        k, i, j = kij
        return _christ_fk.at[k, i, j].set(dgdx[k, i, j] + dgdx[k, j, i] - dgdx[i, j, k])

    christ_fk = jnp.zeros_like(dgdx)
    perms = index_permutations_of_shape(dgdx.shape)
    # Each vmapped call returns a full array with a single entry set; summing combines them
    christ_fk = jax.vmap(set_value, in_axes=(0, None))(perms, christ_fk).sum(axis=0)
    christ_fk /= 2.0
    return christ_fk
```

While this seems to work okay, the memory requirements for large shapes will be huge - does I would like to use something like
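To see where the memory goes in the vmap version, `jax.eval_shape` can report the shape of the batched intermediate without computing it. This is a sketch under assumptions: `n = 4`, a zero array standing in for `dgdx`, and the index triples built with `jnp.indices` instead of the cached helper (same row order as `it.product`). Because each of the n³ vmapped calls returns a full `(n, n, n)` array, the stacked result has n⁶ entries before `.sum(axis=0)`, i.e. 4096 versus 64 for `n = 4`. Whether XLA avoids materializing all of it under `jax.jit` is not something this check shows.

```python
import math

import jax
import jax.numpy as jnp

n = 4  # illustrative dimension (assumption)
dgdx = jnp.zeros((n, n, n))  # placeholder for jax.jacobian(metric)(x); only its shape matters here

# All (k, i, j) triples in the same row-major order as it.product(...): shape (n**3, 3)
perms = jnp.indices(dgdx.shape).reshape(3, -1).T


def set_value(kij: jnp.ndarray, christ: jnp.ndarray) -> jnp.ndarray:
    k, i, j = kij
    return christ.at[k, i, j].set(dgdx[k, i, j] + dgdx[k, j, i] - dgdx[i, j, k])


# eval_shape traces the vmapped function abstractly; the batched result is never computed
batched = jax.eval_shape(jax.vmap(set_value, in_axes=(0, None)), perms, jnp.zeros_like(dgdx))
elements = math.prod(batched.shape)
print(batched.shape)  # (64, 4, 4, 4)
print(elements, "elements before .sum(axis=0), versus", dgdx.size, "in the result")
```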
I believe I have found an efficient implementation which uses `jnp.vectorize`:

```python
from typing import Callable
import jax
import jax.numpy as jnp

def fk_christoffel(x: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    dgdx = jax.jacobian(metric)(x)
    def get_value(k, i, j) -> jnp.ndarray:
        return 0.5 * (dgdx[k, i, j] + dgdx[k, j, i] - dgdx[i, j, k])
    # vectorize maps get_value elementwise over the index grids produced by jnp.indices
    return jnp.vectorize(get_value)(*jnp.indices(dgdx.shape))
```

I would appreciate any feedback on whether there's a better way to do this. If not, please feel free to close. Thank you!
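One possible alternative, offered as a sketch rather than a benchmarked recommendation: with the convention `dgdx[a, b, c] = ∂g_ab/∂x^c` that `jax.jacobian` produces when `metric` maps an `(n,)` vector to an `(n, n)` matrix, the formula Γ_kij = ½(∂_j g_ki + ∂_i g_kj − ∂_k g_ij) is a sum of three axis permutations of `dgdx`, so no index arrays or gathers are needed. The names `christoffel_from_dgdx`, `fk_christoffel_einsum`, `loop_reference` and `polar_metric` are hypothetical. The check uses the flat plane in polar coordinates, g = diag(1, r²), for which Γ_r,θθ = −r and Γ_θ,rθ = Γ_θ,θr = r.

```python
from typing import Callable

import jax
import jax.numpy as jnp


def christoffel_from_dgdx(dgdx: jnp.ndarray) -> jnp.ndarray:
    """Gamma[k, i, j] = 0.5 * (dgdx[k, i, j] + dgdx[k, j, i] - dgdx[i, j, k])."""
    # term2[k, i, j] = dgdx[k, j, i]: swap the last two axes
    term2 = jnp.einsum("kji->kij", dgdx)
    # term3[k, i, j] = dgdx[i, j, k]: move the derivative axis to the front
    term3 = jnp.einsum("ijk->kij", dgdx)
    return 0.5 * (dgdx + term2 - term3)


def fk_christoffel_einsum(x: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    # dgdx[a, b, c] = d g_ab / d x^c, shape (n, n, n)
    dgdx = jax.jacobian(metric)(x)
    return christoffel_from_dgdx(dgdx)


def loop_reference(dgdx: jnp.ndarray) -> jnp.ndarray:
    """Direct transcription of the index formula, only for cross-checking."""
    n = dgdx.shape[0]
    return 0.5 * jnp.array(
        [[[dgdx[k, i, j] + dgdx[k, j, i] - dgdx[i, j, k] for j in range(n)] for i in range(n)] for k in range(n)]
    )


def polar_metric(x: jnp.ndarray) -> jnp.ndarray:
    # Flat plane in polar coordinates x = (r, theta): g = diag(1, r**2)
    r = x[0]
    return jnp.diag(jnp.stack([jnp.ones_like(r), r ** 2]))


gamma = fk_christoffel_einsum(jnp.array([2.0, 0.5]), polar_metric)
# Expected at r = 2: Gamma_{r, theta theta} = -r, Gamma_{theta, r theta} = Gamma_{theta, theta r} = r
print(float(gamma[0, 1, 1]), float(gamma[1, 0, 1]), float(gamma[1, 1, 0]))  # -2.0 2.0 2.0

# Cross-check the axis permutations on an arbitrary, non-symmetric 3x3x3 array
test_dgdx = jnp.arange(27.0).reshape(3, 3, 3)
print(bool(jnp.allclose(christoffel_from_dgdx(test_dgdx), loop_reference(test_dgdx))))  # True
```

`jnp.swapaxes(dgdx, 1, 2)` and `jnp.transpose(dgdx, (2, 0, 1))` would express the same two permutations; the einsum subscripts are used here only because they spell out the index mapping.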