Within the loop function, all arguments are traced. This means that s['n'] is a traced value, so when you write shape=(n,) you are requesting a dynamically shaped array, which is not currently supported in JAX.
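As a rough illustration of the same failure outside the loop, here is a minimal sketch using jax.jit, where arguments are traced in the same way. The function name make_zeros is hypothetical and not from the original question.

```python
import jax
import jax.numpy as jnp

@jax.jit
def make_zeros(n):
    # Under jit, `n` is a tracer rather than a concrete Python int,
    # so it cannot be used to define an array shape.
    return jnp.zeros((n,))

try:
    make_zeros(3)
    shape_error = None
except TypeError as err:
    # JAX rejects shapes that depend on traced values.
    shape_error = err

print("failed as expected:", shape_error is not None)
```

The same restriction applies to values carried through loop state, which is why s['n'] cannot be used as a shape.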

To fix this, you'll need to use a static value for n. One way to do this is to use the global n value; that is, change this line:

s["key"], bin_array = random_binary_array(key2, s['n'])

to this:

s["key"], bin_array = random_binary_array(key2, n)

Of course this assumes that n is not changing from iteration to iteration; if it is, then you really do have dynamic array sizes and you'll have to find a different approach to your algorithm within JAX.
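For reference, here is a minimal, self-contained sketch of the fix. The original loop is not shown here, so the body of random_binary_array, the use of lax.fori_loop, and the "total" entry in the state are all assumptions made for illustration; only the call random_binary_array(key2, n) comes from the question.

```python
import jax
import jax.numpy as jnp
from jax import lax

n = 8  # static Python int, known when the loop body is traced

def random_binary_array(key, size):
    # Hypothetical stand-in for the helper in the question: returns a fresh
    # key and a 0/1 array of length `size`. `size` must be a concrete int.
    key, subkey = jax.random.split(key)
    bits = jax.random.bernoulli(subkey, 0.5, shape=(size,)).astype(jnp.int32)
    return key, bits

def body(i, s):
    key1, key2 = jax.random.split(s["key"])
    # Use the global `n` (static) rather than s["n"] (traced).
    new_key, bin_array = random_binary_array(key2, n)
    return {
        "key": new_key,
        "n": s["n"],
        "total": s["total"] + bin_array.sum(),  # assumed accumulator
    }

init = {"key": jax.random.PRNGKey(0), "n": jnp.int32(n), "total": jnp.int32(0)}
final = lax.fori_loop(0, 5, body, init)
print(final["total"])
```

Here n is captured from the enclosing scope as an ordinary Python integer, so every array created in the loop body has a shape that is fixed at trace time.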

Answer selected by yiminghwang