@@ -1524,6 +1524,8 @@ def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, b
15241524 (world_sz - len (params ) % world_sz ) % world_sz )
15251525 gathered_momentums_pad = gathered_params_momentums + [torch .empty_like (gathered_params_momentums [- 1 ])] * (
15261526 (world_sz - len (gathered_params_momentums ) % world_sz ) % world_sz )
1527+ grad_handles = []
1528+ momentum_handles = []
15271529 for base_i in range (len (params ))[::world_sz ]:
15281530 if base_i + rank < len (params ):
15291531 param = params [base_i + rank ]
@@ -1534,10 +1536,14 @@ def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, b
15341536 _ , _ , grad_offset = group_items [base_i + rank ]
15351537 buffer_to_reduce .narrow (0 , grad_offset ,
15361538 param .grad .numel ()).data .copy_ (g .view (- 1 ), non_blocking = False )
1537- dist .all_gather (grads_pad [base_i :base_i + world_sz ], grads_pad [base_i + rank ])
1538- dist .all_gather (gathered_momentums_pad [base_i :base_i + world_sz ],
1539- gathered_momentums_pad [base_i + rank ])
1540-
1539+ grad_handle = dist .all_gather (grads_pad [base_i :base_i + world_sz ], grads_pad [base_i + rank ], async_op = True )
1540+ grad_handles .append (grad_handle )
1541+ momentum_handle = dist .all_gather (gathered_momentums_pad [base_i :base_i + world_sz ],
1542+ gathered_momentums_pad [base_i + rank ], async_op = True )
1543+ momentum_handles .append (momentum_handle )
1544+
1545+ for handle in momentum_handles :
1546+ handle .wait ()
15411547 for idx , (param , dest_offset , _ ) in enumerate (group_items ):
15421548 gathered_momentum = gathered_params_momentums [idx ]
15431549 chunk_sz = math .ceil (param .grad .numel () / world_sz )
@@ -1564,6 +1570,11 @@ def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, b
15641570 )
15651571 if self ._swappable_optimizer_subgroup (i ) and not self .save_muon_momentum_buffer_in_memory :
15661572 self .optimizer_swapper .swap_out_optimizer_state (parameter = self .fp32_partitioned_groups_flat [i ])
1573+ for handle in grad_handles :
1574+ handle .wait ()
1575+ for param , _ , params_size_offset in group_items :
1576+ buffer_to_reduce .narrow (0 , params_size_offset ,
1577+ param .grad .numel ()).data .copy_ (param .grad .view (- 1 ), non_blocking = False )
15671578
15681579 @instrument_w_nvtx
15691580 def __avg_scatter_contiguous_grads (self , buffer_to_reduce : Tensor ,
0 commit comments