
I think your question is ill-posed: within an outer jit-compiled function, inner jit-compiled functions no longer exist as separate units that can be benchmarked, because XLA compiles the outer function as a single program and is free to inline and fuse their operations. As a simple example, consider these three functions:

from jax import jit
import jax.numpy as jnp

@jit
def func1(x):
  return x + 1

@jit
def func2(x):
  return x - 1

@jit
def func2_of_func1(x):
  return func2(func1(x))
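Before looking at compiled HLO, note that the nesting is still visible at the jaxpr level, i.e. the traced program before XLA optimizes it. The sketch below (a self-contained copy of the three functions above) prints that jaxpr with `jax.make_jaxpr`. The exact printed form (the primitive is shown as `pjit` in many JAX versions) may vary between releases.

```python
import jax
import jax.numpy as jnp

@jax.jit
def func1(x):
  return x + 1

@jax.jit
def func2(x):
  return x - 1

@jax.jit
def func2_of_func1(x):
  return func2(func1(x))

x = jnp.arange(5)

# The jaxpr is the traced, pre-XLA program. Here the inner jit-compiled
# functions still appear as nested call equations carrying name=func1
# and name=func2; it is XLA compilation that later erases these boundaries.
jaxpr_text = str(jax.make_jaxpr(func2_of_func1)(x))
print(jaxpr_text)
```

These nested calls are a tracing-time structure only; they say nothing about how the compiled executable is organized.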

We can look at the compiled HLO for the first two functions like this:

x = jnp.arange(5)
print(func1.lower(x).compile().as_text())
# HloModule jit_func1.14, entry_computation_layout={(s32[5]{0})->s32[5]{0}}

# %fused_computation (param_0: s32[5]) -> s32[5] {
#   %param_0 = s32[5]{0} pa…
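Since only the outer compiled function exists as an executable unit, the meaningful thing to time is that function as a whole. The following is a minimal sketch, assuming wall-clock timing with `time.perf_counter` is adequate for your purposes; the helper name `time_per_call` and the `n_iter` parameter are hypothetical. It excludes compilation via a warm-up call and uses `block_until_ready()` because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def func1(x):
  return x + 1

@jax.jit
def func2(x):
  return x - 1

@jax.jit
def func2_of_func1(x):
  return func2(func1(x))

def time_per_call(f, *args, n_iter=1000):
  """Hypothetical helper: average wall-clock seconds per call of a jitted f."""
  # Warm-up call: tracing and compilation happen here, outside the timed loop.
  f(*args).block_until_ready()
  start = time.perf_counter()
  for _ in range(n_iter):
    # Dispatch is asynchronous, so block to include actual execution time.
    f(*args).block_until_ready()
  return (time.perf_counter() - start) / n_iter

x = jnp.arange(5)

# The compiled outer function is one HLO module; func1 and func2 are not
# separately executable pieces inside it.
hlo_text = func2_of_func1.lower(x).compile().as_text()
print(hlo_text.splitlines()[0])

print(f"func2_of_func1: {time_per_call(func2_of_func1, x):.2e} s/call")
```

If you want per-piece numbers, the only well-defined option is to time `func1` and `func2` as standalone jitted calls, with the caveat that their sum need not match the fused outer function.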

Answer selected by StoneT2000
Category: Q&A
3 participants