@@ -806,16 +806,6 @@ def _launch(
806806 )
807807 )
808808
809- smem_ref_tree = _construct_smem_reftree (
810- cluster , dynamic_smem , smem_buffers
811- )
812- # TODO(apaszke): Skip the following if no barriers were initialized.
813- nvvm .fence_mbarrier_init ()
814- if math .prod (cluster ) != 1 :
815- nvvm .cluster_arrive_relaxed (aligned = ir .UnitAttr .get ())
816- nvvm .cluster_wait (aligned = ir .UnitAttr .get ())
817- gpu .barrier ()
818-
819809 if profiler_spec :
820810 prof_smem = memref .view (
821811 ir .MemRefType .get (
@@ -832,7 +822,19 @@ def _launch(
832822
833823 ptr_ty = ir .Type .parse ("!llvm.ptr" )
834824 scratch_ptr = builtin .unrealized_conversion_cast ([ptr_ty ], [scratch_arr ])
835- yield LaunchContext (launch_op , scratch_ptr , cluster , prof ), smem_ref_tree
825+ ctx = LaunchContext (launch_op , scratch_ptr , cluster , prof )
826+ with ctx .named_region ("Init" ):
827+ smem_ref_tree = _construct_smem_reftree (
828+ cluster , dynamic_smem , smem_buffers
829+ )
830+ # TODO(apaszke): Skip the following if no barriers were initialized.
831+ nvvm .fence_mbarrier_init ()
832+ if math .prod (cluster ) != 1 :
833+ nvvm .cluster_arrive_relaxed (aligned = ir .UnitAttr .get ())
834+ nvvm .cluster_wait (aligned = ir .UnitAttr .get ())
835+ gpu .barrier ()
836+
837+ yield ctx , smem_ref_tree
836838 if prof is not None :
837839 prof .finalize (grid = grid , block = block )
838840 gpu .terminator ()
0 commit comments