@@ -118,7 +118,6 @@ def clear(self):
118118 self .params .clear ()
119119 self .grads .clear ()
120120 self .elements = 0
121- self .index = 0
122121 self .has_moe_params = False
123122
124123
@@ -1052,11 +1051,8 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
10521051 bucket = self .ipg_buckets [comm_dtype ]
10531052 if bucket .elements + param .numel () > self .reduce_bucket_size :
10541053 self .report_ipg_memory_usage ("In ipg_remove_grads before reduce_ipg_grads" , param .numel ())
1055- self .reduce_ipg_grads ()
1054+ self .reduce_ipg_grads (comm_dtype = comm_dtype )
10561055 if self .contiguous_gradients and self .overlap_comm :
1057- if not get_accelerator ().resolves_data_dependency ():
1058- self .reduction_stream .wait_stream (get_accelerator ().current_stream ())
1059- get_accelerator ().current_stream ().wait_stream (self .reduction_stream )
10601056 # Swap index between 0 and 1
10611057 bucket .index = 1 - bucket .index
10621058 self .report_ipg_memory_usage ("In ipg_remove_grads after reduce_ipg_grads" , param .numel ())
@@ -1500,8 +1496,11 @@ def copy_grads_in_partition(self, param):
15001496 #print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}")
15011497 self .grads_in_partition_offset += param .numel ()
15021498
1503- def reduce_ipg_grads (self ):
1504- for comm_dtype in sort_dtypes (self .ipg_buckets .keys ()):
1499+ def reduce_ipg_grads (self , comm_dtype = None ):
1500+ dtypes = sort_dtypes (self .ipg_buckets .keys ())
1501+ if comm_dtype is not None :
1502+ dtypes = [comm_dtype ]
1503+ for comm_dtype in dtypes :
15051504 bucket = self .ipg_buckets [comm_dtype ]
15061505
15071506 if self .contiguous_gradients :
@@ -1536,7 +1535,7 @@ def reduce_ipg_grads(self):
15361535 stream = get_accelerator ().current_stream ()
15371536
15381537 with get_accelerator ().stream (stream ):
1539- for comm_dtype in sort_dtypes ( self . ipg_buckets . keys ()) :
1538+ for comm_dtype in dtypes :
15401539 bucket = self .ipg_buckets [comm_dtype ]
15411540
15421541 for group_idx , param_idx_in_group , param_id in bucket .params :
0 commit comments