Without more details it's hard to say how the compiler will optimize any particular sequence of operations, but in general it should. The most definitive way to check is to print the compiled HLO and inspect it. Here's a simple example:

```python
import jax
import jax.numpy as jnp

def f(x):
  y = jnp.sin(x)
  z = jnp.sin(x)
  return y + z

print(jax.jit(f).lower(1.0).compile().as_text())
```

```
HloModule jit_f, entry_computation_layout={(f32[])->f32[]}

%fused_computation (param_0.1: f32[]) -> f32[] {
  %param_0.1 = f32[] parameter(0)
  %sine.0 = f32[] sine(f32[] %param_0.1), metadata={op_name="jit(f)/jit(main)/sin" source_file="<ipython-input-34-90031d99c954>" source_line=5}
  ROOT %add.0 = f32[] add(f32[] %sin…
```
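If reading the full HLO dump by eye is tedious, one option is to compare what JAX traced with what the compiler emitted. The following is a minimal sketch using the same `f`. The helper `count_sine_ops` and its regular expression are illustrative assumptions, not part of JAX, and the exact HLO text can differ by backend and JAX version, so treat the count as a rough check rather than a guarantee.

```python
import re

import jax
import jax.numpy as jnp

def f(x):
    y = jnp.sin(x)
    z = jnp.sin(x)
    return y + z

# The jaxpr records what JAX traced, before XLA has optimized anything.
jaxpr = jax.make_jaxpr(f)(1.0)
prims = [eqn.primitive.name for eqn in jaxpr.jaxpr.eqns]
print("traced primitives:", prims)

# The compiled HLO text reflects the program after XLA's optimization passes.
optimized_hlo = jax.jit(f).lower(1.0).compile().as_text()

def count_sine_ops(hlo_text):
    # Hypothetical helper: counts occurrences of the `sine(` opcode in HLO text.
    # This is a plain text match, so it depends on how the backend prints HLO.
    return len(re.findall(r"\bsine\(", hlo_text))

print("sine instructions after compilation:", count_sine_ops(optimized_hlo))
```

If the traced program contains more `sin` primitives than the compiled HLO contains `sine` instructions, that suggests the compiler merged the duplicated computation.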

Answer selected by mchagneux