After digging a little further into the JAX source code, I found how to obtain the compiled (post-partitioning) HLO:

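with device_mesh:
    # Lower f for these arguments, compile it with XLA, then ask the
    # compiled object for its HLO modules (i.e. the post-partitioning IR).
    modules = f.lower(A, x).compile().compiler_ir()
    for hlo in modules:
        # Print each module as HLO text.
        print(hlo.to_string())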

The resulting HLO looks like the listing below. Notice the all-reduce and fusion operations in the entry function.

HloModule pjit_dot.7

%add (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %x, f32[] %y)
}

%add.1 (x.1: f32[], y.1: f32[]) -> f32[] {
  %x.1 = f32[] parameter(0)
  %y.1 = f32[] parameter(1)
  ROOT %add.2 = f32[] add(f32[] %x.1, f32[] %y.1)
}

%fused_computation (param_0.1:…
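
Since the full dump is long, it can be easier to confirm the all-reduce and fusion operations mentioned above by counting opcodes in the printed text instead of reading it line by line. Below is a minimal sketch of that idea in plain Python. The helper name `count_opcodes` and the regular expression are my own assumptions, not part of JAX or XLA; the pattern only relies on the one-instruction-per-line layout visible in the listing (`%name = <shape> opcode(...)`) and is not a real HLO parser.

```python
import re
from collections import Counter

# Assumed line layout, taken from the listing above:
#   [ROOT] %name = <shape> opcode(operands...), attributes...
# The opcode is the first lowercase token followed by "(" after the "=".
_INSTRUCTION = re.compile(
    r"^\s*(?:ROOT\s+)?%?[\w.\-]+\s*=\s*.*?\s([a-z][a-z0-9\-]*)\("
)


def count_opcodes(hlo_text: str) -> Counter:
    """Count HLO opcodes in a text dump (hypothetical helper, regex-based)."""
    counts = Counter()
    for line in hlo_text.splitlines():
        match = _INSTRUCTION.match(line)
        if match:  # computation headers and closing braces do not match
            counts[match.group(1)] += 1
    return counts


# The first computation from the listing above, copied verbatim.
sample = """\
%add (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
"""

counts = count_opcodes(sample)
print(counts["parameter"], counts["add"], counts["all-reduce"])
```

Passing the text of each module from the loop above (`hlo.to_string()`) to `count_opcodes` and checking `counts["all-reduce"]` and `counts["fusion"]` would then show whether those operations are present. A missing opcode simply reads as 0, because `Counter` returns 0 for unknown keys.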

Answer selected by leiteg