
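In internal JAX code, I often use a pattern like the following to handle scalar inputs: record the input's shape, promote the input to at least 1-D, do the work, then reshape the result back to the original shape: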

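```python
import jax.numpy as jnp

def f(x):
  out_shape = jnp.shape(x)       # remember the caller's shape (() for scalars)
  x = jnp.atleast_1d(x)          # promote scalars so the body can assume ndim >= 1
  res = x**3
  return res.reshape(out_shape)  # restore the original shape
```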
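The pattern matters most when the body genuinely needs a leading axis, for example when it maps a per-element function with `jax.vmap`, which raises an error if asked to map over a 0-d array. The sketch below is a hypothetical variant, not code from the original post: `_cube_scalar` and `f_vmapped` are illustrative names, and the extra `.ravel()` is an assumption added so that inputs of any rank are mapped element by element.

```python
import jax
import jax.numpy as jnp

def _cube_scalar(v):
    # Hypothetical helper written for a single scalar element.
    return v**3

def f_vmapped(x):
    out_shape = jnp.shape(x)         # () for Python floats and 0-d arrays
    x = jnp.atleast_1d(x).ravel()    # vmap needs a leading axis to map over
    res = jax.vmap(_cube_scalar)(x)  # apply the scalar helper to every element
    return res.reshape(out_shape)    # hand back the caller's original shape

print(f_vmapped(2.0).shape)                            # ()
print(f_vmapped(jnp.arange(6.0).reshape(2, 3)).shape)  # (2, 3)
```

Because the shape is captured before promotion, a scalar input comes back as a 0-d array rather than a length-1 vector, so callers never see the internal `atleast_1d` step.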


@jecampagne

Answer selected by jecampagne
Category
Q&A