In the compiled function, the number of flops required for `jnp.sum(x) * 5` is far lower than for `jnp.sum(x * 5)`, yet the compiler does not rewrite the second form into the first. Why is this optimization skipped?

```python
from jax import jit
import jax.numpy as jnp

x = jnp.ones((1000, 1000))

f = jit(lambda x: jnp.sum(x) * 5)   # reduce first, then a single multiply
print(f.lower(x).compile().cost_analysis())

g = jit(lambda x: jnp.sum(x * 5))   # one multiply per element, then reduce
print(g.lower(x).compile().cost_analysis())
```
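For reference, here is one way to compare the two estimates directly. This is a hedged sketch: the exact return type of `cost_analysis()` has varied across JAX versions (a dict in recent releases, a one-element list of dicts in older ones), so the helper below normalizes both cases.

```python
def flops(compiled):
    # cost_analysis() exposes XLA's per-op cost model; 'flops' is one of its keys.
    analysis = compiled.cost_analysis()
    if isinstance(analysis, list):  # older JAX versions wrap the dict in a list
        analysis = analysis[0]
    return analysis["flops"]

# On a CPU backend, f should report roughly half the flops of g: f does about
# 1e6 additions plus a single multiply, while g performs an extra multiply for
# every one of the 1e6 elements before reducing.
print(flops(f.lower(x).compile()))
print(flops(g.lower(x).compile()))
```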
I mentioned this in #31257 as well, but I suspect this is a deliberately skipped optimization, because there are cases where changing the order of operations in this way would lead to overflow. Here's a simple example:

```python
In [1]: import jax.numpy as jnp

In [2]: a = jnp.float32(1E-4)

In [3]: x = 1E37 * jnp.arange(10)

In [4]: (a * x).sum()
Out[4]: Array(4.5e+34, dtype=float32)

In [5]: a * x.sum()
Out[5]: Array(inf, dtype=float32)
```

If the compiler were to automatically rewrite `(a * x).sum()` into the cheaper `a * x.sum()`, a program that originally produced a finite result would instead overflow to `inf`.

Note that the source of truth for these kinds of compiler decisions is not in JAX, but rather in https://github.com/openxla/xla – so folks at that repository might know more about this type of optimization, and whether there might be compiler flags that would enable it.
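Given that, if you know your data's range makes the rewrite safe, one option is simply to apply it by hand in your own code. This is a sketch, not something suggested in the thread:

```python
import jax
import jax.numpy as jnp

@jax.jit
def scaled_sum(x, a):
    # Hand-applied rewrite: a single multiply after the reduction instead of
    # one per element. Only safe if you know jnp.sum(x) itself cannot overflow.
    return a * jnp.sum(x)
```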
Beta Was this translation helpful? Give feedback.
I think you're misunderstanding me: the goal of the compiler is not to avoid overflow, or improve floating point error accumulation; the goal of the compiler is to optimize code without significantly affecting the numerics as expressed in the original program. If the original program does not overflow, the compiled program should maintain that property. If the original program does overflow, the compiled program should maintain that property.
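To make that concrete, here is a minimal sketch (not from the thread) of the converse case: a program that overflows as written, where the flop-saving rewrite would silently change the result:

```python
import jax.numpy as jnp

big = jnp.float32(3e38)   # close to the float32 maximum (~3.4e38)
x = jnp.array([big, -big])

# As written: 2 * x overflows elementwise to [inf, -inf], which sums to nan.
print((2 * x).sum())   # nan
# Rewritten form: x.sum() is exactly 0.0, so the product is 0.0 instead.
print(2 * x.sum())     # 0.0
```

A compiler that rewrote one form into the other would turn a `nan` into a `0.0` (or vice versa), which is exactly the kind of semantic change described above.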
Of course, floating point math being what it is, you can never guarantee exact bitwise equivalence before and after a compiler rewrite, but what you can do is avoid certain optimizations that have been found to be problematic in practice (and XLA has…