print optimized code #7068
-
Hi, I would like to see what optimizations `jit` and XLA are able to perform. For example, given:

```python
import jax

def f(a, b, c):
    return a * b * c

print(jax.make_jaxpr(jax.grad(f))(1., 2., 3.))
```

prints

I guess that the optimized code should look like
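For reference, ∂(a·b·c)/∂a = b·c, so at (1., 2., 3.) the gradient should be 2·3 = 6. The sketch below prints the jaxpr and checks that value numerically. Keep in mind that a jaxpr is JAX's traced program before XLA compilation, so it is not the optimized code XLA eventually runs and may still contain work that XLA later removes.

```python
import jax

def f(a, b, c):
    return a * b * c

# Trace grad(f) into a jaxpr: JAX's intermediate program,
# produced before XLA applies any of its optimization passes.
grad_jaxpr = jax.make_jaxpr(jax.grad(f))(1., 2., 3.)
print(grad_jaxpr)

# jax.grad differentiates w.r.t. the first argument by default,
# so the result should equal b * c = 2. * 3.
g = jax.grad(f)(1., 2., 3.)
print(g)  # 6.0
```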
-
In case people come across this question in searches, there is now an easier API for printing optimized computations, described in the "Ahead-of-time lowering and compilation" documentation. For example:

```python
import jax

def f(a, b, c):
    return a * b * c

print(jax.jit(jax.grad(f)).lower(1., 2., 3.).compile().as_text())
```
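To compare what goes into XLA with what comes out, the same API can print both stages. A minimal sketch: `Lowered.as_text()` gives the program handed to the compiler (the exact IR format, e.g. StableHLO, depends on your JAX version), while `Compiled.as_text()` gives the HLO after XLA's optimization passes.

```python
import jax

def f(a, b, c):
    return a * b * c

# Stage 1: lower to the compiler's input IR (no XLA optimizations yet).
lowered = jax.jit(jax.grad(f)).lower(1., 2., 3.)
before = lowered.as_text()

# Stage 2: compile with XLA; the compiled object's text is the optimized HLO.
compiled = lowered.compile()
after = compiled.as_text()

print("=== before optimization ===")
print(before)
print("=== after optimization ===")
print(after)

# The compiled object is also callable with arguments of the same shapes/dtypes.
print(compiled(1., 2., 3.))
```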
-
There are two ways to dump the optimized HLO. One is to set the following environment variable:

```
XLA_FLAGS=--xla_dump_to=/tmp/somewhere
```

Another is the following undocumented and internal API (no promises of stability):
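As a sketch of the environment-variable option, the flag can also be set from Python, as long as it happens before JAX initializes XLA. This assumes a fresh process; the temporary directory stands in for `/tmp/somewhere`, and the assignment overwrites any `XLA_FLAGS` you already had. The optimized modules are typically the dumped files whose names contain `after_optimizations`.

```python
import os
import tempfile

# XLA parses XLA_FLAGS once, when it initializes, so set the variable
# before importing jax (or export it in the shell instead).
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
os.environ["XLA_FLAGS"] = f"--xla_dump_to={dump_dir}"

import jax

def f(a, b, c):
    return a * b * c

# The first call compiles the function, which writes HLO modules to dump_dir.
jax.jit(jax.grad(f))(1., 2., 3.)

# List what XLA dumped for this process.
for name in sorted(os.listdir(dump_dir)):
    print(name)
```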