
There are two ways to dump the optimized HLO. One is to set the environment variable XLA_FLAGS=--xla_dump_to=/tmp/somewhere before running your program; XLA then writes the HLO of the modules it compiles into that directory.
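A minimal sketch of the environment-variable route, assuming a recent JAX running on CPU: set XLA_FLAGS from Python before jax is imported (XLA reads it when the backend starts), trigger one compilation, then list the dump directory. The temporary directory is a stand-in for /tmp/somewhere, and filtering on names ending in after_optimizations.txt reflects how recent XLA builds name their dump files, not a documented contract.

```python
import os
import tempfile

# Stand-in for /tmp/somewhere. XLA_FLAGS must be set before the XLA
# backend is initialized, so do it before importing jax.
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
os.environ["XLA_FLAGS"] = f"--xla_dump_to={dump_dir}"

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    return jnp.sin(x)


# Calling the jitted function compiles it, which writes the dump files.
f(1.0).block_until_ready()

# Assumption: the post-optimization HLO lands in files whose names end in
# "after_optimizations.txt" (e.g. a "cpu_" prefix on the CPU backend).
optimized = sorted(
    name for name in os.listdir(dump_dir)
    if name.endswith("after_optimizations.txt")
)
for name in optimized:
    print(name)
```

From a shell, the equivalent is simply prefixing the command, e.g. XLA_FLAGS=--xla_dump_to=/tmp/somewhere python my_script.py (my_script.py being a hypothetical script name).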

Another is the following undocumented and internal API (no promises of stability):

In [34]: import jax
    ...: import jax.numpy as jnp

In [35]: def f(x): return jnp.sin(x)

In [36]: c = jax.xla_computation(f)(1.)

In [37]: backend = jax.lib.xla_bridge.get_backend()
    ...: e = backend.compile(c)
    ...: print(e.hlo_modules()[0].to_string())
HloModule xla_computation_f__1.5

ENTRY %xla_computation_f__1.5 (parameter.1: f64[]) -> (f64[]) {
  %parameter.1 = f64[] parameter(0)
  %sine.3 = f64[] sine(f64[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file=…
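jax.xla_computation has since been deprecated and removed in newer JAX releases, so the transcript above may not run as-is. A sketch of the same inspection through JAX's ahead-of-time lowering API, assuming a JAX version that provides jax.jit(...).lower(...).compile(): Compiled.as_text() returns the HLO after XLA's optimization passes.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x)


# Lower f for a concrete example argument, then have XLA compile it.
lowered = jax.jit(f).lower(1.0)
compiled = lowered.compile()

# compiled.as_text() is the optimized HLO; lowered.as_text() would instead
# give the module before XLA's optimization passes.
optimized_hlo = compiled.as_text()
print(optimized_hlo)
```

Unless 64-bit mode is enabled (jax.config.update("jax_enable_x64", True)), the parameter will show up as f32[] rather than the f64[] in the output above.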

Replies: 2 comments 11 replies

6 replies
@soraros

@mariogeiger

@awf

@jecampagne

@hawkinsp

Answer selected by mariogeiger
5 replies
@joaospinto

@jakevdp

@joaospinto

@jakevdp

@joaospinto

Category
Q&A
7 participants