Any known reasons for differing numerical outputs after 0.3.0 update? #9556
Hi, apologies in advance for the vagueness of this question, but are there any known or expected situations in which the numerical output of "core" JAX operations should differ between 0.3.0 and the previous release? (By "core" I mean that the code I'm looking at doesn't use any newer JAX functionality; it's all built from simple […])

Long story short, I have some JAX code that, after upgrading to 0.3.0, outputs substantially different numerical values than when it's run on the previous release. The problem occurs whether I'm on jaxlib 0.3.0 or 0.1.76, so I think the difference comes from the jax 0.2.28 -> 0.3.0 change itself. (Basically, I can switch which release is installed and observe the numerical difference, with relative errors on the order of […].)

I need to keep digging to isolate where the difference occurs (at which point I can share more detailed info), but for now I wanted to check whether anything like this is known or has come up elsewhere. Thanks! Dan
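To put a number on that relative error, a small helper like the one below could compare outputs saved from each installation. This is a minimal sketch that assumes both runs produce arrays of the same shape. `max_relative_error` is a hypothetical name, and `eps` exists only to avoid dividing by zero where the reference value is exactly 0.

```python
import numpy as np


def max_relative_error(new, ref, eps=np.finfo(np.float64).tiny):
    """Largest elementwise |new - ref| / |ref| between two result arrays."""
    new = np.asarray(new, dtype=np.float64)
    ref = np.asarray(ref, dtype=np.float64)
    if new.shape != ref.shape:
        raise ValueError(f"shape mismatch: {new.shape} vs {ref.shape}")
    # Clamp the denominator so exact zeros in the reference don't divide by 0.
    denom = np.maximum(np.abs(ref), eps)
    return float(np.max(np.abs(new - ref) / denom))
```

For example, you could call it as `max_relative_error(np.load("new.npy"), np.load("old.npy"))`, substituting whatever filenames the two runs actually wrote (the names here are hypothetical). If a pass/fail check is all you need, `np.testing.assert_allclose` with an explicit `rtol` is the built-in alternative.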
Despite the version number bump, 0.3 didn't actually change much code (it's mostly a new version-numbering scheme), so I think any numerical difference is unintended. We'd need a repro to say more.
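A repro is easier to act on when each run records exactly which versions produced it. One possible sketch follows. It assumes NumPy is available and that the computation under test is wrapped in a callable. Both `fn` and `save_tagged_output` are hypothetical names. The idea is to run it once per installation and compare the saved files afterwards.

```python
import importlib.metadata
import os

import numpy as np


def installed_version(package: str) -> str:
    """Version string of an installed distribution, or 'not-installed'."""
    try:
        return importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        return "not-installed"


def save_tagged_output(fn, args, out_dir="."):
    """Run fn(*args) and save the result as .npy, tagged with jax/jaxlib versions.

    `fn` and `args` stand in for the (hypothetical) computation under test.
    """
    result = np.asarray(fn(*args))
    tag = f"jax-{installed_version('jax')}_jaxlib-{installed_version('jaxlib')}"
    path = os.path.join(out_dir, f"output_{tag}.npy")
    np.save(path, result)
    return path
```

Run this once with jax 0.2.28 installed and once with 0.3.0. You get two files whose names record which versions produced each one. `np.asarray` also accepts JAX arrays, so `fn` can return one directly.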
Okay, based on a suggestion from a colleague, I've identified which commit in the 0.3.0 release introduces the difference (I installed from source and binary-searched the commit list in jax-v0.2.28...jax-v0.3.0). The test I'm running "passes" (produces expected results consistent with previous commits) up to and including commit […] but starts failing on commit […]. I still need to identify exactly where in our code this ultimately gets used.
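The binary search over the commit list can be sketched as follows. The sketch makes four assumptions: the commits are in chronological order, the first one passes, the last one fails, and the failure persists once introduced. `passes` is a hypothetical callable that would check out, install, and test a given commit. `git bisect run` automates the same search directly on the repository.

```python
from typing import Callable, Sequence


def first_failing_commit(commits: Sequence[str], passes: Callable[[str], bool]) -> str:
    """Return the first commit for which `passes` is False (binary search).

    Assumes commits[0] passes, commits[-1] fails, and that every commit
    after the first failing one also fails.
    """
    if len(commits) < 2:
        raise ValueError("need at least a known-good and a known-bad commit")
    lo, hi = 0, len(commits) - 1  # invariant: commits[lo] passes, commits[hi] fails
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if passes(commits[mid]):
            lo = mid  # regression is later than mid
        else:
            hi = mid  # regression is at or before mid
    return commits[hi]
```

With N commits in the range, this needs only about log2(N) calls to `passes`. That efficiency matters when each step is a full source install.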