Built-in differential testing mechanism in JAX #26567
Unanswered
PragmaTwice asked this question in Ideas
-
Why not just run the same function under different backends one after another? I think that approach would be simpler than building multi-dispatch into `jit`.
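A minimal sketch of the reply's suggestion follows. It commits the inputs to one device, runs the function under `jax.jit` there (computation follows committed data), and then repeats this on a second device and compares the final outputs within a tolerance. `run_on` and `compare_backends` are hypothetical helper names, not JAX APIs. The sketch also assumes that the plugin under test is registered as JAX's default backend, so `jax.devices()[0]` is one of its devices. On a CPU-only machine it simply compares CPU with itself.

```python
import jax
import jax.numpy as jnp
import numpy as np


def run_on(fn, device, *args):
    """Commit every input to `device` and run `fn` under jit there.

    Computation follows committed data, so the compiled program executes
    on `device`; the result is then copied back to host memory.
    """
    placed = [jax.device_put(a, device) for a in args]
    return np.asarray(jax.jit(fn)(*placed))


def compare_backends(fn, args, reference, candidate, rtol=1e-5, atol=1e-6):
    """Run `fn` on two devices in turn and compare the final outputs.

    Returns (agree, max_abs_diff); the tolerances absorb the precision
    differences expected between backends.
    """
    ref = run_on(fn, reference, *args)
    cand = run_on(fn, candidate, *args)
    agree = bool(np.allclose(ref, cand, rtol=rtol, atol=atol))
    return agree, float(np.max(np.abs(ref - cand)))


def f(x, y):
    return jnp.tanh(x @ y) + 1.0


x = jnp.ones((4, 4))
y = jnp.arange(16.0).reshape(4, 4)
cpu = jax.devices("cpu")[0]
# Assumption: the plugin under test is the default backend.
candidate = jax.devices()[0]
agree, max_diff = compare_backends(f, (x, y), cpu, candidate)
print(agree, max_diff)
```

This approach needs no changes to JAX, but it only compares end results. It tells you *that* the backends disagree, not *which* intermediate operation diverged. Localizing the divergence is what the question below asks for.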
-
Hi, everyone. I recently discovered that if we develop a new PJRT plugin for JAX (different from the existing XLA:CPU, XLA:CUDA, etc.), we run into testing difficulties, especially when the JAX program is very complex.
I think a built-in differential testing mechanism in JAX would help. For example, when we want to test the correctness of a new GPU backend, we could compare its results against those of the CPU backend (ignoring precision differences). That is, when a JAX operator such as `jax.numpy.add` is called (assuming here that `jax.jit` is not enabled), JAX (in `pjit` or `pxla`) would delegate the computation not only to the new GPU backend but also to the CPU (or XLA:CUDA) backend, and then compare the results. This would make it easy to determine which intermediate result is incorrect (i.e. where the PJRT plugin implementation has a problem), which should be valuable when debugging a new PJRT plugin on large JAX programs.

I made some modifications to `pjit.py`/`pxla.py` in JAX to show the feasibility of this idea, but I ran into some problems (I got some device assignment errors). I was wondering whether the JAX team has any thoughts or suggestions on this?