I believe I have found an efficient implementation which uses jnp.vectorize:

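from typing import Callable

import jax
import jax.numpy as jnp

def fk_christoffel(x: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    # dgdx[a, b, c] holds the derivative of metric component g_ab with respect to x^c.
    dgdx = jax.jacobian(metric)(x)

    def get_value(k, i, j) -> jnp.ndarray:
        # One Christoffel symbol of the first kind, Gamma_{kij}.
        return 0.5 * (dgdx[k, i, j] + dgdx[k, j, i] - dgdx[i, j, k])

    # Evaluate get_value at every (k, i, j) index triple.
    return jnp.vectorize(get_value)(*jnp.indices(dgdx.shape))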

I would appreciate any feedback on whether there is a better way to do this. If not, please feel free to close.

Thank you!
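
For comparison, here is a minimal sketch of one possible alternative. It is not the answer given in this thread. Since `dgdx` already holds every derivative, the three terms can be written as axis permutations of that one array, which avoids the per-index gather. The names `fk_christoffel_einsum` and `polar_metric` are hypothetical and exist only for this illustration.

```python
from typing import Callable

import jax
import jax.numpy as jnp


def fk_christoffel_einsum(
    x: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]
) -> jnp.ndarray:
    # dgdx[a, b, c] holds the derivative of g_ab with respect to x^c.
    dgdx = jax.jacobian(metric)(x)
    # out[k, i, j] = 0.5 * (dgdx[k, i, j] + dgdx[k, j, i] - dgdx[i, j, k])
    return 0.5 * (
        dgdx
        + jnp.einsum("kji->kij", dgdx)  # contributes dgdx[k, j, i]
        - jnp.einsum("ijk->kij", dgdx)  # contributes dgdx[i, j, k]
    )


def polar_metric(x: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical test metric: flat 2D space in polar coordinates (r, theta),
    # g = diag(1, r^2).
    r = x[0]
    return jnp.diag(jnp.stack([jnp.ones_like(r), r**2]))


gamma = fk_christoffel_einsum(jnp.array([2.0, 0.3]), polar_metric)
```

For the polar metric, the only nonzero first-kind symbols should be `gamma[0, 1, 1] = -r` and `gamma[1, 0, 1] = gamma[1, 1, 0] = r`, which gives a quick sanity check against the `jnp.vectorize` version above.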

Answer selected by danielkelshaw