Thanks for the question! There are a couple of things going on here:

First, while NumPy has special scalar types that are distinct from (and behave differently than) zero-dimensional arrays, JAX made an early design choice not to have separate scalar types. As a result, any JAX function that returns a scalar value represents it as a zero-dimensional DeviceArray.
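To make the difference concrete, here is a small sketch comparing the two libraries on the same reduction. The variable names are just for illustration, and the exact type names printed depend on your NumPy and JAX versions (the JAX array class has been renamed over time), so only the shape and dimensionality are the point here.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: a full reduction returns a NumPy scalar type, not an ndarray.
np_total = np.sum(np.arange(4))

# JAX: the same reduction returns a zero-dimensional array.
jax_total = jnp.sum(jnp.arange(4))

print(type(np_total))                      # a NumPy scalar type
print(type(jax_total))                     # a JAX array type (name varies by version)
print(jax_total.shape, jax_total.ndim)     # empty shape, zero dimensions

is_numpy_scalar = isinstance(np_total, np.generic)
is_numpy_array = isinstance(np_total, np.ndarray)
jax_shape = jax_total.shape

# Outside of transformations, a zero-dimensional JAX array can still be
# converted to a plain Python number when one is needed.
jax_as_int = int(jax_total)
```

So if you need a Python scalar outside of JIT, an explicit `int(...)` or `float(...)` conversion on the zero-dimensional result should do it.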

Second, this design choice aside, it sounds like you are trying to create an array whose shape depends on the content of another array. This is fine if you're doing it outside of JIT or other transformations, but inside transformations it is not possible. Why? JAX's compilation and transformation model depends on array shapes and data type…
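As a rough illustration of the shape restriction described above, the sketch below builds an array whose length comes from the contents of another array. The function name `zeros_from_sum` is hypothetical and chosen only for this example. It works when called eagerly, but under `jax.jit` the value of `x.sum()` is not available at trace time, so converting it to a Python integer raises an error.

```python
import jax
import jax.numpy as jnp

def zeros_from_sum(x):
    # Hypothetical example: the output *shape* depends on the *values* in x.
    n = int(x.sum())
    return jnp.zeros(n)

x = jnp.array([1, 2, 3])

# Outside of JIT, x.sum() is a concrete value, so this is fine.
eager = zeros_from_sum(x)
print(eager.shape)

# Inside JIT, x is an abstract tracer with a known shape and dtype but no
# concrete values, so int(x.sum()) cannot be evaluated.
try:
    jax.jit(zeros_from_sum)(x)
    jit_failed = False
except Exception as err:  # the exact error class depends on the JAX version
    jit_failed = True
    print(type(err).__name__)
```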

Answer selected by ozencgungor
@jakevdp
@ozencgungor
Category
Q&A