I'm running a large computation with scans, with […]. Before the main […] How do I avoid this? I'm doing lots of sampling, so it seems unavoidable to call […]. Also, what are best practices for keeping track of keys and reducing key allocations?
Replies: 1 comment 1 reply
Can you make use of `jax.random.fold_in` within the scan to avoid upfront allocation via split? For example, if you increment an integer counter between each scan iteration, you can fold that in to a fixed "root" scan key to produce a local subkey for each scan iteration.
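A minimal sketch of that idea, in case it helps. The names `sample_chain` and `step` are made up for illustration, and the `jax.random.normal` draw is only a stand-in for whatever sampling your scan body actually does. The loop index from `jnp.arange` plays the role of the integer counter.

```python
import jax
import jax.numpy as jnp


def sample_chain(root_key, num_steps):
    """Hypothetical example: draw one sample per scan step without pre-splitting keys."""

    def step(total, i):
        # Derive this iteration's subkey from the fixed root key and the counter i.
        subkey = jax.random.fold_in(root_key, i)
        # Placeholder sampling; replace with the real per-step sampling.
        x = jax.random.normal(subkey)
        return total + x, x

    # The counter is supplied as the scanned input, so no array of keys is allocated.
    return jax.lax.scan(step, jnp.float32(0.0), jnp.arange(num_steps))


root_key = jax.random.PRNGKey(0)
num_steps = 8
total, samples = sample_chain(root_key, num_steps)
```

Only the single root key is created up front; each subkey exists just for the iteration that uses it. If a step needs several independent draws, one option is to split `subkey` locally inside `step`, or to fold in a second integer that identifies the draw.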