@@ -278,6 +278,7 @@ def _lower_tpu_kernel(
278278 module : ir .Module ,
279279 hardware_generation : int ,
280280 target_shape : tuple [int , int ],
281+ kernel_name : str | None = None ,
281282) -> ir .Module :
282283 """Runs MLIR passes lowering the given module to an MLIR module.
283284
@@ -303,8 +304,7 @@ def _lower_tpu_kernel(
303304 tpu .register_dialect (ctx )
304305 mhlo .register_mhlo_dialect (ctx )
305306 mhlo .register_mhlo_passes ()
306-
307- dump_mlir (module , "original" )
307+ dump_mlir (module , "original" , kernel_name )
308308
309309 if _MOSAIC_ALLOW_HLO .value :
310310 # Run hlo dialect conversion: hlo -> linalg -> vector.
@@ -406,6 +406,7 @@ def _lower_mosaic_module_to_asm(
406406 * ,
407407 backend : str ,
408408 device_type : str | None ,
409+ kernel_name : str | None ,
409410) -> tuple [ir .Module , tuple [bool , bool , bool , bool ]]:
410411 has_communication , has_custom_barrier = tpu .private_has_communication (
411412 module .operation
@@ -429,7 +430,7 @@ def _lower_mosaic_module_to_asm(
429430 hardware_generation = int (device_kind [len ("TPU v" )])
430431 target_shape = get_target_shape (hardware_generation )
431432 module = _lower_tpu_kernel (
432- module , hardware_generation , target_shape = target_shape
433+ module , hardware_generation , target_shape = target_shape , kernel_name = kernel_name ,
433434 )
434435 needs_hlo_passes = False
435436 needs_layout_passes = False
@@ -504,6 +505,7 @@ def _lower_to_custom_call_config(
504505 collective_id : int | None ,
505506 serialization_format : int | None ,
506507 output_memory_spaces : tuple [MemorySpace | None , ...] | None = None ,
508+ kernel_name : str | None = None ,
507509) -> CustomCallBackendConfig :
508510 lowered_module_asm , (
509511 has_communication ,
@@ -514,6 +516,7 @@ def _lower_to_custom_call_config(
514516 module ,
515517 backend = backend ,
516518 device_type = device_type ,
519+ kernel_name = kernel_name ,
517520 )
518521 return _lowered_to_custom_call_config (
519522 lowered_module_asm ,
@@ -613,6 +616,7 @@ def lower_module_to_custom_call(
613616 device_type = device_type ,
614617 serialization_format = serialization_format ,
615618 output_memory_spaces = output_memory_spaces ,
619+ kernel_name = kernel_name ,
616620 )
617621 return _tpu_custom_call_lowering (
618622 ctx ,
@@ -654,6 +658,7 @@ def as_tpu_kernel(
654658 collective_id = collective_id ,
655659 serialization_format = serialization_format ,
656660 output_memory_spaces = output_memory_spaces ,
661+ kernel_name = kernel_name ,
657662 )
658663 return _as_jax_callable (
659664 config ,
@@ -735,7 +740,7 @@ def apply_kernel(*args):
735740 return jax .jit (apply_kernel )
736741
737742
738- def dump_mlir (module : ir .Module , name : str ):
743+ def dump_mlir (module : ir .Module , name : str , kernel_name : str | None = None ):
739744 """A helper function to dump mosaic mlir module"""
740745 try :
741746 should_dump = FLAGS ["xla_mosaic_dump_to" ].value
@@ -744,6 +749,8 @@ def dump_mlir(module: ir.Module, name: str):
744749 if should_dump == "sponge" :
745750 outdir = os .environ .get ("TEST_UNDECLARED_OUTPUTS_DIR" , None )
746751 if outdir :
752+ if kernel_name :
753+ name = f"{ kernel_name } -{ name } "
747754 path = os .path .join (outdir , f"{ time .time_ns ()} -mosaic-dump-{ name } -py.txt" )
748755 with open (path , "w" ) as f :
749756 f .write (str (module ))
0 commit comments