Hey, I have been experimenting with `pjit`. Consider the following scenario: a matrix sharded over a 2x2 device mesh, multiplied by a replicated vector:

```python
#!/usr/bin/env python3
import os

# Expose 4 devices to JAX/XLA
# (set XLA_FLAGS before JAX initializes its backend)
# ------------------------------------------------
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

import jax
import jax.numpy as jnp
import numpy as np
# Older JAX API: later releases removed jax.experimental.maps
# and the *_axis_resources arguments of pjit.
from jax.experimental import PartitionSpec
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit

jax.config.update('jax_platform_name', 'cpu')

# Create 2x2 device mesh
# ------------------------------------------------
devices = jax.devices()
device_array = np.asarray(devices).reshape(2, 2)
device_mesh = Mesh(device_array, ("x", "y"))

# Create matrix A and vector x
# ------------------------------------------------
A = jnp.eye(16)
x = jnp.arange(16).reshape((-1, 1))

# Transform `dot` operator using `pjit`: rows of A are sharded over "x",
# columns over "y"; x and the output are replicated (None)
# ------------------------------------------------
spec = PartitionSpec("x", "y")
f = pjit(jnp.dot,
         in_axis_resources=(spec, None),
         out_axis_resources=None)

# Obtain HLO representation (lowered, not yet compiled)
# ------------------------------------------------
with device_mesh:
    ir = f.lower(A, x).compiler_ir("hlo")
    print(ir.as_hlo_text())
```

Executing this code prints the HLO module for `f`.
That HLO module includes the correct sharding annotations for all tensors, but there are no communication primitives whatsoever. I suspect that the GSPMD partitioner has not run at this point yet. Is there a way to obtain the HLO module after this compiler pass has executed? Thanks in advance!
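Whether a printed module contains communication can be checked mechanically instead of by eye, by scanning the HLO text for collective opcodes. This is a minimal sketch: `count_collectives` is a hypothetical helper, the opcode list is hand-picked and may not be exhaustive, and the sample string is a made-up fragment in HLO text style, not actual output of the script above. Async variants such as `all-reduce-start` are counted once under their base opcode, and instruction names like `%all-reduce` are not counted because only an opcode is directly followed by `(`.

```python
import re
from collections import Counter

# Hand-picked HLO collective opcodes; an optional "-start" suffix covers
# async pairs (the matching "-done" instruction is deliberately not counted).
_COLLECTIVE_RE = re.compile(
    r"\b(all-reduce|all-gather|all-to-all|collective-permute|reduce-scatter)"
    r"(?:-start)?\("
)


def count_collectives(hlo_text: str) -> Counter:
    """Count collective instructions in HLO text, keyed by base opcode."""
    return Counter(_COLLECTIVE_RE.findall(hlo_text))


# Made-up HLO-style fragment for illustration only.
sample = """
  %dot.1 = f32[8,1]{1,0} dot(f32[8,8]{1,0} %p0, f32[8,1]{1,0} %p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  %all-reduce = f32[8,1]{1,0} all-reduce(f32[8,1]{1,0} %dot.1), replica_groups={{0,1},{2,3}}, to_apply=%add
"""
print(count_collectives(sample))  # Counter({'all-reduce': 1})
```

An empty `Counter()` for the pre-partitioning module would confirm the observation that no communication has been inserted yet.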
After investigating the JAX source code a little further, I found how to obtain the compiled (post-partitioning) HLO:

```python
# Compile, then fetch the optimized HLO modules (after SPMD partitioning).
# Same older JAX API as in the question.
with device_mesh:
    modules = f.lower(A, x).compile().compiler_ir()
for hlo in modules:
    print(hlo.to_string())
```
The resulting HLO now contains `all-reduce` and `fusion` operations in the entry function.
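On a current JAX release, where `jax.experimental.maps` and pjit's `*_axis_resources` arguments are gone, the same before/after comparison can be sketched with `jax.jit` and `NamedSharding`. This is a minimal sketch, assuming a recent JAX version: `Lowered.as_text()` returns the module before partitioning, while `Compiled.as_text()` returns the optimized HLO after the SPMD partitioner has run. The `COLLECTIVES` tuple and `has_collective` helper are hypothetical, a hand-picked substring check rather than any JAX API.

```python
import os

# Must be set before JAX initializes its CPU backend.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hand-picked HLO collective opcode names (not an official list).
COLLECTIVES = ("all-reduce", "all-gather", "all-to-all",
               "collective-permute", "reduce-scatter")


def has_collective(text: str) -> bool:
    """True if the module text mentions any of the HLO collective opcodes."""
    return any(op in text for op in COLLECTIVES)


# 2x2 device mesh, same layout as the pjit example.
mesh = Mesh(np.asarray(jax.devices()).reshape(2, 2), ("x", "y"))
replicated = NamedSharding(mesh, P())

A = jnp.eye(16)
x = jnp.arange(16, dtype=jnp.float32).reshape((-1, 1))

# jax.jit with explicit shardings plays the role of pjit here:
# A is sharded (rows over "x", columns over "y"); x and the output are replicated.
f = jax.jit(
    jnp.dot,
    in_shardings=(NamedSharding(mesh, P("x", "y")), replicated),
    out_shardings=replicated,
)

lowered = f.lower(A, x)
compiled = lowered.compile()

pre_text = lowered.as_text()    # before partitioning: sharding annotations, no collectives
post_text = compiled.as_text()  # optimized HLO after SPMD partitioning

print("collectives before partitioning:", has_collective(pre_text))
print("collectives after partitioning: ", has_collective(post_text))
print(post_text)

out = f(A, x)  # A is the identity, so the result should equal x
```

Because A is sharded while the output is replicated, the partitioned program has to exchange data between devices, which is why a collective is expected only in the compiled text.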