Is it possible or expected that JAX programs using the PRNG give different results when the same seed is used and all software dependencies are identical (running within a container), but the runs happen on different generations of Intel CPUs? If so, are there ways to modify this behavior (e.g., XLA flags)?
Please note this is a follow-up question based on an open discussion on the NumPyro forum (NumPyro is built on JAX). Another NumPyro user reported numerical results that differed across machines, and I've seen the same behavior when running time-series models written in NumPyro on a distributed computing system where we set a seed and run inside a container. I consistently see (usually small) differences in forecast posteriors for the same time series when models use the same seed but run on machines with different Intel CPUs; I get identical results when models run with the same seed and time series on the same Intel CPU type.

In my use case I sometimes need to control for model behavior exactly, because the models run inside larger numerical experiments where other differences need to be evaluated separately. In this setting I cannot exactly control for the JAX-dependent model without also restricting the experiments to a single CPU type. I haven't been able to find much information about whether this is expected behavior, though there are many discussions of similar behavior when broadly comparing JAX programs run on CPUs versus GPUs.
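For what it's worth, here is a minimal check I was thinking of running on each machine to try to separate the two possible causes: the PRNG bit stream itself versus floating-point kernel differences. It only uses stock jax / jax.numpy APIs, and the shapes and sizes are arbitrary; this is just a sketch of how I'd compare outputs across CPU types, not something from the models themselves.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# (1) Integer draws are exact, so any mismatch across machines here would
#     point at the PRNG bit stream itself rather than floating point.
ints = jax.random.randint(key, (8,), 0, 2**31 - 1)
print("integer draws:", [int(v) for v in ints])

# (2) A floating-point reduction over PRNG samples: differences here, with
#     identical integer draws above, would point at CPU-specific numerics
#     (e.g. vectorized codegen) rather than the PRNG.
x = jax.random.normal(key, (100_000,), dtype=jnp.float32)
print("sum of normals:", repr(float(jnp.sum(x))))

print("backend:", jax.default_backend(), "| jax", jax.__version__)
```

If the integer draws match everywhere but the floating-point sum differs by CPU generation, that would suggest the divergence is in compiled numerics rather than in the seeded PRNG.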
The related NumPyro forum question is here for reference: "Different HMC results between devices, for same seed"
Thanks for any insights you can share.