Using a custom data structure in JIT-compiled code #19830
-
Hello! An algorithm I want to implement requires efficient removal, selection, and random choice. A pure Python data structure that accomplishes this task was discussed here, for example. Since JAX requires statically shaped arrays, I've attempted to develop a class that accomplishes the same task using only JAX arrays. This is my first time using a custom data structure in JAX, and I would appreciate any feedback on whether this will fail in any spectacular way.
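For reference, one common way to get removal, selection, and random choice on a static shape is to keep a fixed-capacity boolean mask recording which slots are still present. The sketch below is not the class from the question (its code is not reproduced here). The helper names `make_pool`, `remove`, `count`, `alive_indices`, and `random_choice` are hypothetical, and elements are assumed to be identified by slot index `0..n-1`. Every operation returns a new array instead of mutating state.

```python
import jax
import jax.numpy as jnp


def make_pool(n):
    # Hypothetical helper: all n slots start out present. The shape (n,)
    # never changes, which is what jit requires.
    return jnp.ones(n, dtype=bool)


@jax.jit
def remove(mask, i):
    # Functional update: returns a new mask with slot i cleared.
    return mask.at[i].set(False)


@jax.jit
def count(mask):
    # Number of slots still present.
    return mask.sum()


@jax.jit
def alive_indices(mask):
    # "Selection": jnp.nonzero needs a static `size` under jit, so the
    # result always has length n and unused entries are padded with -1.
    return jnp.nonzero(mask, size=mask.shape[0], fill_value=-1)[0]


@jax.jit
def random_choice(key, mask):
    # Uniform choice among present slots: removed slots get probability 0.
    # Assumes at least one slot is present (otherwise p contains NaN).
    p = mask / mask.sum()
    return jax.random.choice(key, mask.shape[0], p=p)


key = jax.random.PRNGKey(0)
mask = make_pool(5)
mask = remove(mask, 0)
mask = remove(mask, 3)
print(int(count(mask)))                # 3
print(alive_indices(mask).tolist())    # [1, 2, 4, -1, -1]
print(int(random_choice(key, mask)))   # some index in {1, 2, 4}
```

The trade-off is that `random_choice` and `alive_indices` are O(n) in the capacity rather than O(1), but they keep a fixed shape, so each function compiles once and can be used inside other jitted code.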
I am currently running some tests to see if this code behaves as expected. Using the discussion here, I have
which appears to yield the correct result
Assuming the logic of the class is correct, may I use this class normally in my code moving forward? Are there any JAX-specific quirks I'm not accounting for? For my use case, … Suggestions are greatly appreciated!
-
Hi! This looks good; the only concerning thing I see is the use of in-place mutation in your class. For example, `res.kill(0)` will result in `res` being modified, but `jax.jit(res.kill)(0)` will not, because JAX requires pure functions. As long as you don't write code that expects mutation side-effects to persist across transformation boundaries, it should be fine. For example, `res.kill` is impure and therefore problematic, but your `body_fun` as defined is non-problematic despite calling `val.kill`, because the mutations are reflected in the return value of the function.
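The difference between a mutation that is lost and one that survives can be seen in a small sketch. `Pool` is a hypothetical stand-in for the class in the question, registered as a pytree with `jax.tree_util.register_pytree_node_class` so it can be passed to `jax.jit` and used as a loop carry. `kill_and_discard` and `kill_and_return` are illustrative names, and `body_fun` is assumed here to be driven by `jax.lax.fori_loop`. The object is passed in as an argument rather than closed over, since jitting the bound method directly can leave a tracer stored on `res`.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Pool:
    """Hypothetical stand-in for the question's class."""

    def __init__(self, mask):
        self.mask = mask

    def kill(self, i):
        # In-place (Python-level) mutation: rebinds self.mask on this object.
        self.mask = self.mask.at[i].set(False)

    def tree_flatten(self):
        # The mask is the only array leaf; no static auxiliary data.
        return (self.mask,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


res = Pool(jnp.ones(4, dtype=bool))
res.kill(0)  # no transformation involved: res itself is updated
print(res.mask.tolist())  # [False, True, True, True]


@jax.jit
def kill_and_discard(pool, i):
    # Inside jit, `pool` is a new Pool rebuilt from traced leaves,
    # so this mutation never reaches `res` and is simply dropped.
    pool.kill(i)


@jax.jit
def kill_and_return(pool, i):
    pool.kill(i)
    return pool  # the change survives only because it is returned


kill_and_discard(res, 1)
print(res.mask.tolist())  # [False, True, True, True]

out = kill_and_return(res, 1)
print(out.mask.tolist())  # [False, False, True, True]


def body_fun(j, val):
    # Mutating the loop carry is fine because the carry is returned.
    val.kill(j)
    return val


looped = jax.lax.fori_loop(0, 3, body_fun, Pool(jnp.ones(4, dtype=bool)))
print(looped.mask.tolist())  # [False, False, False, True]
```

In each case only what flows out through the return value is visible to the caller, which is why the `body_fun` pattern is safe while relying on `res` being updated by a jitted call is not.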