Skip to content

Commit e53ff2c

Browse files
[Mosaic][Easy] - Wire up kernel names to MLIR dump
PiperOrigin-RevId: 699408419
1 parent b259fde commit e53ff2c

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

jax/_src/tpu_custom_call.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def _lower_tpu_kernel(
278278
module: ir.Module,
279279
hardware_generation: int,
280280
target_shape: tuple[int, int],
281+
kernel_name: str | None = None,
281282
) -> ir.Module:
282283
"""Runs MLIR passes lowering the given module to an MLIR module.
283284
@@ -303,8 +304,7 @@ def _lower_tpu_kernel(
303304
tpu.register_dialect(ctx)
304305
mhlo.register_mhlo_dialect(ctx)
305306
mhlo.register_mhlo_passes()
306-
307-
dump_mlir(module, "original")
307+
dump_mlir(module, "original", kernel_name)
308308

309309
if _MOSAIC_ALLOW_HLO.value:
310310
# Run hlo dialect conversion: hlo -> linalg -> vector.
@@ -406,6 +406,7 @@ def _lower_mosaic_module_to_asm(
406406
*,
407407
backend: str,
408408
device_type: str | None,
409+
kernel_name: str | None,
409410
) -> tuple[ir.Module, tuple[bool, bool, bool, bool]]:
410411
has_communication, has_custom_barrier = tpu.private_has_communication(
411412
module.operation
@@ -429,7 +430,7 @@ def _lower_mosaic_module_to_asm(
429430
hardware_generation = int(device_kind[len("TPU v")])
430431
target_shape = get_target_shape(hardware_generation)
431432
module = _lower_tpu_kernel(
432-
module, hardware_generation, target_shape=target_shape
433+
module, hardware_generation, target_shape=target_shape, kernel_name=kernel_name,
433434
)
434435
needs_hlo_passes = False
435436
needs_layout_passes = False
@@ -504,6 +505,7 @@ def _lower_to_custom_call_config(
504505
collective_id: int | None,
505506
serialization_format: int | None,
506507
output_memory_spaces: tuple[MemorySpace | None, ...] | None = None,
508+
kernel_name: str | None = None,
507509
) -> CustomCallBackendConfig:
508510
lowered_module_asm, (
509511
has_communication,
@@ -514,6 +516,7 @@ def _lower_to_custom_call_config(
514516
module,
515517
backend=backend,
516518
device_type=device_type,
519+
kernel_name=kernel_name,
517520
)
518521
return _lowered_to_custom_call_config(
519522
lowered_module_asm,
@@ -613,6 +616,7 @@ def lower_module_to_custom_call(
613616
device_type=device_type,
614617
serialization_format=serialization_format,
615618
output_memory_spaces=output_memory_spaces,
619+
kernel_name=kernel_name,
616620
)
617621
return _tpu_custom_call_lowering(
618622
ctx,
@@ -654,6 +658,7 @@ def as_tpu_kernel(
654658
collective_id=collective_id,
655659
serialization_format=serialization_format,
656660
output_memory_spaces=output_memory_spaces,
661+
kernel_name=kernel_name,
657662
)
658663
return _as_jax_callable(
659664
config,
@@ -735,7 +740,7 @@ def apply_kernel(*args):
735740
return jax.jit(apply_kernel)
736741

737742

738-
def dump_mlir(module: ir.Module, name: str):
743+
def dump_mlir(module: ir.Module, name: str, kernel_name: str | None = None):
739744
"""A helper function to dump mosaic mlir module"""
740745
try:
741746
should_dump = FLAGS["xla_mosaic_dump_to"].value
@@ -744,6 +749,8 @@ def dump_mlir(module: ir.Module, name: str):
744749
if should_dump == "sponge":
745750
outdir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", None)
746751
if outdir:
752+
if kernel_name:
753+
name = f"{kernel_name}-{name}"
747754
path = os.path.join(outdir, f"{time.time_ns()}-mosaic-dump-{name}-py.txt")
748755
with open(path, "w") as f:
749756
f.write(str(module))

0 commit comments

Comments
 (0)