What does jax.jit know about arithmetic? #7730
-
Hi all, while playing with printing the optimized code produced by `jit`, I observed the following:
Code:

```python
import jax.numpy as jnp

def f(x, y):
    return x + 0.0, 1.0 * y, jnp.tanh(x), jnp.tanh(x), (x * y), (y * x), (y - y)
```
Optimized output:

```
ENTRY (parameter: f32[], parameter: f32[]) -> (f32[], f32[], f32[], f32[], f32[], /*index=5*/f32[], f32[]) {
  tmp_0 = parameter(0)
  tmp_1 = copy(tmp_0)               # x + 0.0
  tmp_2 = parameter(1)
  tmp_3 = copy(tmp_2)               # 1.0 * y
  tmp_4 = tanh(tmp_0)
  tmp_5 = copy(tmp_4)               # only one call to tanh
  tmp_6 = multiply(tmp_0, tmp_2)
  tmp_7 = multiply(tmp_2, tmp_0)    # no copy here :cry:
  tmp_8 = subtract(tmp_2, tmp_2)    # no zeros here :cry:
  ROOT tmp_9 = tuple(tmp_1, tmp_3, tmp_4, tmp_5, ...(+3))
}
```

Are these simplifications still performed somehow at another stage? (For instance, I know that floating-point rounding is affected by the order of operations, but the same happens with integers.)
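For anyone who wants to reproduce this kind of dump, here is a minimal sketch, assuming a recent JAX release where `jax.jit(f).lower(...)` returns a lowered object whose `.compile()` result provides `as_text()` (the optimized HLO). The function `f` is repeated from the question so the block runs on its own; the printed text will likely be more verbose than the excerpt above (types on every instruction, metadata, possibly fusions) and differs across JAX/XLA versions and backends.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return x + 0.0, 1.0 * y, jnp.tanh(x), jnp.tanh(x), (x * y), (y * x), (y - y)


# Trace and lower f for concrete f32 scalar arguments, then let XLA
# compile (and optimize) the program for the default backend.
x, y = jnp.float32(1.0), jnp.float32(2.0)
lowered = jax.jit(f).lower(x, y)
compiled = lowered.compile()

# HLO text after XLA's optimization passes (e.g. algebraic simplification, CSE).
optimized_hlo = compiled.as_text()
print(optimized_hlo)

# The compiled object is callable and returns the 7-tuple of results.
results = compiled(x, y)
```

Comparing `lowered.as_text()` (before XLA optimization) with `compiled.as_text()` is one way to see which rewrites XLA actually applied.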
-
I don't know of any documentation that summarizes the algebraic simplification rules implemented by XLA, but you can see it in the source code. For example, I believe this is the relevant implementation for XLA on CPU: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/algebraic_simplifier.cc
The header file may be easier to skim through, and has some useful comments: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/algebraic_simplifier.h
For example, here we see that `A - A` is simplified to `0` only for integer `A`.
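To illustrate why such a rule would be restricted to integers, here is a small pure-Python check (an illustration of IEEE 754 float behavior, not XLA code): for floats, `A - A` is not `0` when `A` is infinite or NaN, so folding it to `0` could change results, while integer subtraction has no such special values.

```python
import math


def self_subtract(values):
    """Compute a - a for each value, as the unsimplified program would."""
    return [a - a for a in values]


# Integers: a - a is always exactly 0, so rewriting it to 0 is safe.
int_results = self_subtract([0, 7, -3, 2**62])

# Floats: finite values give 0.0, but infinities and NaN give NaN
# (inf - inf is NaN), so the rewrite A - A -> 0 is not value-preserving.
float_inputs = [1.5, -0.0, math.inf, -math.inf, math.nan]
float_results = self_subtract(float_inputs)

for a, r in zip(float_inputs, float_results):
    print(f"{a!r} - {a!r} = {r!r}")
```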