Hi all! I've written a few interpreters using JAX now, and I typically use the following pattern:
I've never had to write my own … Is there a downside to staging first?
Others might chime in, but in general the easiest way to implement a custom transformation is what we often call initial style: stage out to a jaxpr, then process it with a jaxpr interpreter. The alternative is final style: implement the transformation on the fly using a custom tracer. Final style is harder to implement, but it often makes debugging easier, because the transformation happens within the context of the traced code rather than as a post-processing step after tracing. It is also more flexible; for example, in some cases it allows Python control flow to depend on traced data.

If you want to see a side-by-side comparison of the two approaches, one place to find it in the JAX codebase is the two implementations of the … There's also some discussion of these two approaches in Autodidax.
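As a minimal sketch of the initial-style approach (stage out with `jax.make_jaxpr`, then walk the equations), the interpreter below follows the pattern used in JAX's "Writing custom Jaxpr interpreters" tutorial. The `rewrite` hook and the sin-to-cos swap are hypothetical, chosen only to show where a transformation rule plugs in. The sketch does not recurse into sub-jaxprs such as those carried by `jit`, `cond`, or `scan`, which a real transformation would need to handle.

```python
import jax
import jax.numpy as jnp
from jax import lax

try:  # newer JAX exposes jaxpr types under jax.extend.core
    from jax.extend.core import Literal
except ImportError:  # older releases
    from jax.core import Literal


def eval_jaxpr(jaxpr, consts, *args, rewrite=None):
    """Evaluate a jaxpr equation by equation.

    `rewrite` is a hypothetical hook: a dict mapping a primitive name to a
    replacement function. That hook is where a custom transformation goes.
    """
    env = {}

    def read(v):
        # Literals carry their value inline; everything else lives in env.
        return v.val if isinstance(v, Literal) else env[v]

    def write(v, val):
        env[v] = val

    for v, val in zip(jaxpr.constvars, consts):
        write(v, val)
    for v, val in zip(jaxpr.invars, args):
        write(v, val)

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        if rewrite and eqn.primitive.name in rewrite:
            # Transformation rule: replace the primitive with our own function.
            outvals = rewrite[eqn.primitive.name](*invals)
        else:
            # Default rule: just apply the original primitive.
            outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for v, val in zip(eqn.outvars, outvals):
            write(v, val)

    return [read(v) for v in jaxpr.outvars]


def f(x):
    return lax.sin(x) * 2.0


x = jnp.arange(3.0)

# 1. Stage out: trace f once with an example argument to get a jaxpr.
closed = jax.make_jaxpr(f)(x)

# 2. Interpret: plain evaluation reproduces f ...
(plain,) = eval_jaxpr(closed.jaxpr, closed.consts, x)

# ... while the hook swaps every `sin` for `cos` (a toy transformation).
(swapped,) = eval_jaxpr(closed.jaxpr, closed.consts, x,
                        rewrite={"sin": lax.cos})
```

The transformation only ever sees the finished jaxpr, which is what makes this style simple: there is no tracer state to manage, just a loop over equations.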
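For contrast, here is a toy final-style transformation in plain Python, loosely in the spirit of the forward-mode example in Autodidax. `JVPTracer`, `lift`, `jvp`, and `relu_ish` are hypothetical names, and this is not JAX's own tracer API; it only illustrates a transformation that runs while the function executes, which is why the data-dependent `if` works.

```python
from dataclasses import dataclass


@dataclass
class JVPTracer:
    """Toy forward-mode tracer carrying a primal value and a tangent.

    Hypothetical illustration only; this is NOT JAX's Tracer class.
    """
    primal: float
    tangent: float

    def __add__(self, other):
        other = lift(other)
        return JVPTracer(self.primal + other.primal,
                         self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        # Product rule, applied on the fly.
        other = lift(other)
        return JVPTracer(self.primal * other.primal,
                         self.tangent * other.primal + self.primal * other.tangent)

    __rmul__ = __mul__

    def __neg__(self):
        return JVPTracer(-self.primal, -self.tangent)

    def __gt__(self, other):
        # Comparisons are answered from the concrete primal, so ordinary
        # Python `if` statements on traced values just work.
        return self.primal > lift(other).primal


def lift(x):
    # Treat plain numbers as constants with zero tangent.
    return x if isinstance(x, JVPTracer) else JVPTracer(float(x), 0.0)


def jvp(f, x, t):
    # The transformation happens *while* f runs: every operation on a
    # JVPTracer computes both the primal and the tangent immediately.
    out = lift(f(JVPTracer(x, t)))
    return out.primal, out.tangent


def relu_ish(x):
    # Data-dependent Python control flow: fine under this tracer, but it
    # would raise an error if staged out with jax.make_jaxpr.
    if x > 0:
        return x * x
    return -x
```

Here `jvp(relu_ish, 3.0, 1.0)` returns `(9.0, 6.0)` and `jvp(relu_ish, -2.0, 1.0)` returns `(2.0, -1.0)`. The price of this flexibility is that every operation the function might use needs a rule on the tracer itself, rather than a single loop over a finished jaxpr.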