Skip to content

Commit d034680

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Always annotate block initialization in the profiles
This helps establish a shared timeline between different warpgroups and shows how expensive it really was. PiperOrigin-RevId: 703105898
1 parent d5ead57 commit d034680

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

jax/experimental/mosaic/gpu/core.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -806,16 +806,6 @@ def _launch(
806806
)
807807
)
808808

809-
smem_ref_tree = _construct_smem_reftree(
810-
cluster, dynamic_smem, smem_buffers
811-
)
812-
# TODO(apaszke): Skip the following if no barriers were initialized.
813-
nvvm.fence_mbarrier_init()
814-
if math.prod(cluster) != 1:
815-
nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get())
816-
nvvm.cluster_wait(aligned=ir.UnitAttr.get())
817-
gpu.barrier()
818-
819809
if profiler_spec:
820810
prof_smem = memref.view(
821811
ir.MemRefType.get(
@@ -832,7 +822,19 @@ def _launch(
832822

833823
ptr_ty = ir.Type.parse("!llvm.ptr")
834824
scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr])
835-
yield LaunchContext(launch_op, scratch_ptr, cluster, prof), smem_ref_tree
825+
ctx = LaunchContext(launch_op, scratch_ptr, cluster, prof)
826+
with ctx.named_region("Init"):
827+
smem_ref_tree = _construct_smem_reftree(
828+
cluster, dynamic_smem, smem_buffers
829+
)
830+
# TODO(apaszke): Skip the following if no barriers were initialized.
831+
nvvm.fence_mbarrier_init()
832+
if math.prod(cluster) != 1:
833+
nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get())
834+
nvvm.cluster_wait(aligned=ir.UnitAttr.get())
835+
gpu.barrier()
836+
837+
yield ctx, smem_ref_tree
836838
if prof is not None:
837839
prof.finalize(grid=grid, block=block)
838840
gpu.terminator()

0 commit comments

Comments
 (0)