@@ -383,6 +383,7 @@ def group_cast(
383383 post_perm_idx : torch .Tensor | None = None ,
384384 config : GrpCollConfig | None = None ,
385385 previous_event : EventOverlap | None = None ,
386+ kernel_barrier = None ,
386387 async_op : bool = False ,
387388 allocate_on_comm_stream : bool = False ,
388389 cast_lse : bool = False ,
@@ -495,6 +496,7 @@ def group_cast(
495496 is_token_in_rank = is_token_in_rank ,
496497 post_perm_idx = post_perm_idx ,
497498 previous_event = previous_event ,
499+ kernel_barrier = kernel_barrier ,
498500 async_op = async_op ,
499501 allocate_on_comm_stream = allocate_on_comm_stream ,
500502 cast_lse = cast_lse ,
@@ -514,6 +516,7 @@ def group_cast(
514516 is_token_in_rank = is_token_in_rank ,
515517 post_perm_idx = post_perm_idx ,
516518 previous_event = previous_event ,
519+ kernel_barrier = kernel_barrier ,
517520 async_op = async_op ,
518521 allocate_on_comm_stream = allocate_on_comm_stream ,
519522 cast_lse = cast_lse ,
@@ -531,6 +534,7 @@ def group_reduce(
531534 pre_perm_idx : torch .Tensor | None = None ,
532535 config : GrpCollConfig | None = None ,
533536 previous_event : EventOverlap | None = None ,
537+ kernel_barrier = None ,
534538 async_op : bool = False ,
535539 allocate_on_comm_stream : bool = False ,
536540 comm_dtype : torch .dtype | None = None ,
@@ -625,6 +629,7 @@ def group_reduce(
625629 acc_reduce = acc_reduce ,
626630 pre_perm_idx = pre_perm_idx ,
627631 previous_event = previous_event ,
632+ kernel_barrier = kernel_barrier ,
628633 async_op = async_op ,
629634 allocate_on_comm_stream = allocate_on_comm_stream ,
630635 comm_dtype = comm_dtype ,
@@ -643,6 +648,7 @@ def group_reduce(
643648 acc_reduce = acc_reduce ,
644649 pre_perm_idx = pre_perm_idx ,
645650 previous_event = previous_event ,
651+ kernel_barrier = kernel_barrier ,
646652 async_op = async_op ,
647653 allocate_on_comm_stream = allocate_on_comm_stream ,
648654 comm_dtype = comm_dtype ,
@@ -661,6 +667,7 @@ def _intranode_group_cast(
661667 is_token_in_rank : torch .Tensor | None = None ,
662668 post_perm_idx : torch .Tensor | None = None ,
663669 previous_event : EventOverlap | None = None ,
670+ kernel_barrier = None ,
664671 async_op : bool = False ,
665672 allocate_on_comm_stream : bool = False ,
666673 cast_lse : bool = False ,
@@ -747,6 +754,7 @@ def _intranode_group_cast(
747754 post_perm_idx ,
748755 config .to_kernel_config (),
749756 getattr (previous_event , "event" , None ),
757+ kernel_barrier ,
750758 async_op ,
751759 allocate_on_comm_stream ,
752760 )
@@ -791,6 +799,7 @@ def _intranode_group_reduce(
791799 acc_reduce : bool = False ,
792800 pre_perm_idx : torch .Tensor | None = None ,
793801 previous_event : EventOverlap | None = None ,
802+ kernel_barrier = None ,
794803 async_op : bool = False ,
795804 allocate_on_comm_stream : bool = False ,
796805 comm_dtype : torch .dtype | None = None ,
@@ -843,6 +852,7 @@ def _intranode_group_reduce(
843852 pre_perm_idx ,
844853 config .to_kernel_config (),
845854 getattr (previous_event , "event" , None ),
855+ kernel_barrier ,
846856 async_op ,
847857 allocate_on_comm_stream ,
848858 reduce_op ,
@@ -873,6 +883,7 @@ def _internode_group_cast(
873883 is_token_in_rank : torch .Tensor | None = None ,
874884 post_perm_idx : torch .Tensor | None = None ,
875885 previous_event : EventOverlap | None = None ,
886+ kernel_barrier = None ,
876887 async_op : bool = False ,
877888 allocate_on_comm_stream : bool = False ,
878889 cast_lse : bool = False ,
@@ -975,6 +986,7 @@ def _internode_group_cast(
975986 post_perm_idx ,
976987 config .to_kernel_config (),
977988 getattr (previous_event , "event" , None ),
989+ kernel_barrier ,
978990 async_op ,
979991 allocate_on_comm_stream ,
980992 )
@@ -1023,6 +1035,7 @@ def _internode_group_reduce(
10231035 acc_reduce : bool = False ,
10241036 pre_perm_idx : torch .Tensor | None = None ,
10251037 previous_event : EventOverlap | None = None ,
1038+ kernel_barrier = None ,
10261039 async_op : bool = False ,
10271040 allocate_on_comm_stream : bool = False ,
10281041 comm_dtype : torch .dtype | None = None ,
@@ -1078,6 +1091,7 @@ def _internode_group_reduce(
10781091 pre_perm_idx ,
10791092 config .to_kernel_config (),
10801093 getattr (previous_event , "event" , None ),
1094+ kernel_barrier ,
10811095 async_op ,
10821096 allocate_on_comm_stream ,
10831097 reduce_op ,
0 commit comments