@@ -110,12 +110,8 @@ def __init__(self,
     @classmethod
     def create(cls, zenflow_config):
         if zenflow_config.overlap_step:
-            # print("Yes!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
-            print("Yes!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!ZenFlowZeroOptimizerParallel")
             return ZenFlowZeroOptimizerParallel
         else:
-            # print("No!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
-            print("No!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!ZenFlowZeroOptimizerSequential")
             return ZenFlowZeroOptimizerSequential
 
     def _configure_zenflow(self, zenflow_config):
@@ -182,7 +178,7 @@ def sync_fp32_param_from_gpu(self):
             fp32_partition.copy_(bit16_partitions[partition_id].to(dtype=fp32_partition.dtype,
                                                                    device=fp32_partition.device))
 
-    def update_selected_channels(self, tensor, total_size):
+    def update_selected_channels(self, tensor, total_size, communication_data_type):
         curr_size = 0
         curr_index_buffer_size = 0
         rank_and_offsets = []
@@ -194,7 +190,8 @@ def update_selected_channels(self, tensor, total_size):
         self.index_buffer = torch.empty(total_size, dtype=torch.int32, device='cuda')
 
         # count = 0
-        for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+        bucket = self.ipg_buckets[communication_data_type]
+        for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]
 
             if len(param.shape) == 1:
@@ -255,7 +252,7 @@ def update_selected_channels(self, tensor, total_size):
             index_slice = self.index_buffer.narrow(0, offset, num_select)
             dist.broadcast(index_slice, src=src_rank, group=process_group)
 
-        for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+        for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]
 
             if len(param.shape) == 1:
@@ -281,15 +278,15 @@ def update_selected_channels(self, tensor, total_size):
 
         self.index_buffer = None
 
-    def process_selected_fp32_groups_grad(self, tensor, total_size):
+    def _process_selected_fp32_groups_grad(self, tensor, total_size, communication_data_type):
         """
         Process gradients for the selected columns in the FP32 groups.
         Args:
             tensor: flattened gradient tensor for the current bucket
             total_size: total number of selected gradient elements
             communication_data_type: dtype key of the IPG bucket to process
         """
-        print("Yes!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!process_selected_fp32_groups_grad")
+
         curr_size = 0
         curr_grad_buffer_size = 0
         curr_sum_buffer_size = 0
@@ -309,7 +306,8 @@ def process_selected_fp32_groups_grad(self, tensor, total_size):
         group_to_paramlist = {}
 
         # count = 0
-        for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+        bucket = self.ipg_buckets[communication_data_type]
+        for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]
 
             if not hasattr(param, 'selected_indices'):
@@ -389,7 +387,7 @@ def process_selected_fp32_groups_grad(self, tensor, total_size):
             sum_slice = self.sum_buffer.narrow(0, sum_offset, sum_num)
             dist.broadcast(sum_slice, src=src_rank, group=process_group)
 
-        for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+        for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]
 
             selected_grad = None
@@ -450,7 +448,7 @@ def process_selected_fp32_groups_grad(self, tensor, total_size):
         if self.auto_update:
             self.sum_buffer = None
 
-    def average_tensor(self, tensor):
+    def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dtype):
         if self.overlap_comm:
             stream = self.reduction_stream
             if not get_accelerator().resolves_data_dependency():
@@ -478,12 +476,13 @@ def average_tensor(self, tensor):
 
         process_group = self.dp_process_group
         # count = 0
-        for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+        bucket = self.ipg_buckets[communication_data_type]
+        for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]
 
             process_group = self.dp_process_group
 
-            if self.ipg_bucket_has_moe_params:
+            if bucket.has_moe_params:
                 process_group = self.expert_dp_process_group[param.group_name] if is_moe_param(
                     param) else self.dp_process_group
@@ -546,12 +545,14 @@ def average_tensor(self, tensor):
         for bucket_key in buckets:
             if self.use_multi_rank_bucket_allreduce:
                 self.allreduce_and_scatter(buckets[bucket_key],
+                                           communication_data_type,
                                            numel_per_bucket=self.reduce_bucket_size,
                                            divide=False,
                                            process_group=bucket_key)
             else:
                 dst, process_group = bucket_key
                 self.allreduce_no_retain(buckets[bucket_key],
+                                         communication_data_type,
                                          numel_per_bucket=self.reduce_bucket_size,
                                          rank=dst,
                                          divide=False,
@@ -560,15 +561,15 @@ def average_tensor(self, tensor):
         if self.is_zenflow_select_boundary():
             self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).start()
             # print("update selected")
-            self.update_selected_channels(tensor, curr_column_size)
+            self.update_selected_channels(tensor, curr_column_size, communication_data_type)
             self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).stop()
         elif self.zenflow:
             self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).start()
             self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).stop()
 
         if self.zenflow and self.micro_step >= self.full_warm_up_rounds:
             self.timers(SELECTIVE_OPTIMIZER_PROCESS_TIMER).start()
-            self.process_selected_fp32_groups_grad(tensor, curr_selected_reduce_size)
+            self._process_selected_fp32_groups_grad(tensor, curr_selected_reduce_size, communication_data_type)
             self.timers(SELECTIVE_OPTIMIZER_PROCESS_TIMER).stop()
 
     def backward(self, loss, retain_graph=False):
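
The recurring change in this diff is that the single flat `self.params_in_ipg_bucket` list is replaced by per-dtype buckets looked up as `self.ipg_buckets[communication_data_type]`, which is why `communication_data_type` is now threaded through `update_selected_channels`, `_process_selected_fp32_groups_grad`, `average_tensor`, and the allreduce helpers. Below is a minimal, self-contained sketch of the bucket shape these call sites imply; the `IPGBucket` name and field defaults are assumptions inferred from the attribute accesses in the diff (`bucket.params`, `bucket.has_moe_params`), not DeepSpeed's actual definition.

    from dataclasses import dataclass, field
    from typing import Dict, List, Tuple

    import torch

    @dataclass
    class IPGBucket:  # name assumed, inferred from the diff's attribute accesses
        # Each entry is (group_index, param_index_in_group, param_id),
        # matching the 3-tuples unpacked by the loops in the diff.
        params: List[Tuple[int, int, int]] = field(default_factory=list)
        has_moe_params: bool = False

    # One gradient bucket per communication dtype, so every helper that
    # walks a bucket must receive communication_data_type as a key.
    ipg_buckets: Dict[torch.dtype, IPGBucket] = {
        torch.bfloat16: IPGBucket(),
        torch.float32: IPGBucket(),
    }

    bucket = ipg_buckets[torch.bfloat16]
    for i, param_idx_in_group, param_id in bucket.params:
        pass  # e.g. param = self.bit16_groups[i][param_idx_in_group]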
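The `create` classmethod in the first hunk is a small factory that returns one of the two variant classes (not an instance), dispatching on `zenflow_config.overlap_step`. A self-contained sketch of that dispatch and its intended use; the enclosing class name `ZenFlowZeroOptimizer` and the stand-in config object are assumptions, since neither appears in this diff.

    from types import SimpleNamespace

    class ZenFlowZeroOptimizerParallel: ...
    class ZenFlowZeroOptimizerSequential: ...

    class ZenFlowZeroOptimizer:  # enclosing class name assumed
        @classmethod
        def create(cls, zenflow_config):
            # Mirrors the diff: overlap_step selects the parallel variant.
            if zenflow_config.overlap_step:
                return ZenFlowZeroOptimizerParallel
            return ZenFlowZeroOptimizerSequential

    cfg = SimpleNamespace(overlap_step=True)  # stand-in for the real config
    optimizer_cls = ZenFlowZeroOptimizer.create(cfg)
    assert optimizer_cls is ZenFlowZeroOptimizerParallel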