@@ -96,6 +96,7 @@ def __init__(
         master_weights: bool = True,
         extra_dp_group: Optional[ProcessGroup] = None,
         verbose: bool = False,
+        enable_async_reduce: bool = True,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@@ -178,6 +179,7 @@ def __init__(
             if is_ddp_ignored(p):
                 continue
             if p.requires_grad:
+                assert not hasattr(p, "_grad_handle")
                 p._grad_handle = p.register_hook(
                     partial(
                         GeminiDDP.grad_handle,
@@ -187,6 +189,7 @@ def __init__(
                         master_weights=self.master_weights,
                         enable_gradient_accumulation=self.enable_gradient_accumulation,
                         p=p,
+                        async_reduce=enable_async_reduce,
                     )
                 )

@@ -334,6 +337,11 @@ def _pre_backward(self):
                 setattr(param, "_gemini_reduced", False)

     def _post_backward(self):
+        for param in self.param2name:
+            if hasattr(param, "_release_grad_chunk_cb"):
+                param._release_grad_chunk_cb()
+                delattr(param, "_release_grad_chunk_cb")
+
         if self.chunk_manager.accessed_mem != 0:
             error_params = ["Reduction failed at followed parameters:"]
             for param in self.param2name:
@@ -371,6 +379,7 @@ def grad_handle(
         master_weights: bool,
         enable_gradient_accumulation: bool,
         p: nn.Parameter,
+        async_reduce: bool,
     ):
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
@@ -406,31 +415,57 @@ def grad_handle(
                 grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
             else:
                 grad_chunk.add_tensor_to_chunk_slice(p, grad)
-            reduced = chunk_manager.reduce_chunk(grad_chunk)
-            if reduced:
-                if not chunk_manager.reuse_fp16_chunk:
-                    if chunk.keep_gathered:
-                        chunk_manager.fake_release_chunk(chunk)
-                    else:
-                        chunk_manager.release_chunk(chunk)
-                if grad_chunk.is_gathered:
-                    grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
-                    if chunk.extra_dp_group is not None:
-                        grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
+            reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce)
+            if reduced:  # if not async, release immediately; otherwise release once the reduce work has finished
+                if async_reduce:
+                    # dirty fix by installing callback
+                    assert not hasattr(p, "_release_grad_chunk_cb")
+
+                    def _release_grad_chunk_cb():
+                        grad_chunk.wait_async_reduce()
+                        GeminiDDP.release_grad_chunk_handle(
+                            chunk_manager,
+                            grads_device,
+                            master_weights,
+                            enable_gradient_accumulation,
+                            p,
+                            chunk,
+                            grad_chunk,
+                        )
+
+                    p._release_grad_chunk_cb = _release_grad_chunk_cb
                 else:
-                    grad_chunk.cuda_shard.div_(chunk.pg_size)
-                    if chunk.extra_dp_group is not None:
-                        grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
-                # check overflow elements
-                chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
-                # record l2 norm for gradient clipping. flag is bound to fp16 chunk
-                if chunk.l2_norm_flag:
-                    grad_chunk.set_l2_norm()
-                chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
-                if not (master_weights) or (enable_gradient_accumulation):
-                    chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+                    GeminiDDP.release_grad_chunk_handle(
+                        chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
+                    )
         return empty_grad

+    @staticmethod
+    def release_grad_chunk_handle(
+        chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
+    ):
+        if not chunk_manager.reuse_fp16_chunk:
+            if chunk.keep_gathered:
+                chunk_manager.fake_release_chunk(chunk)
+            else:
+                chunk_manager.release_chunk(chunk)
+        if grad_chunk.is_gathered:
+            grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
+            if chunk.extra_dp_group is not None:
+                grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
+        else:
+            grad_chunk.cuda_shard.div_(chunk.pg_size)
+            if chunk.extra_dp_group is not None:
+                grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
+        # check overflow elements
+        chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
+        # record l2 norm for gradient clipping. flag is bound to fp16 chunk
+        if chunk.l2_norm_flag:
+            grad_chunk.set_l2_norm()
+        chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+        if not (master_weights) or (enable_gradient_accumulation):
+            chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+
     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)

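Taken together, the diff threads a new `enable_async_reduce` flag from the `GeminiDDP` constructor into `grad_handle`, where the chunk reduction is issued with `async_op=async_reduce` and the release/bookkeeping step is deferred to a per-parameter callback that `_post_backward` flushes. A hedged usage sketch follows; it assumes the `GeminiPlugin` wrapper forwards this flag to `GeminiDDP`, which is not shown in this diff:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Assumption: GeminiPlugin exposes enable_async_reduce and passes it through to GeminiDDP.
# True (the default in this diff) overlaps gradient reduction with the rest of backward;
# False restores the previous synchronous release path.
plugin = GeminiPlugin(enable_async_reduce=True)
booster = Booster(plugin=plugin)
```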
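For readers new to the deferred-release pattern used above, here is a minimal, self-contained sketch of the control flow: the gradient hook launches the reduction, stashes a release callback on the parameter when running asynchronously, and the post-backward pass flushes every callback exactly once. All names here (`FakeAsyncWork`, `Param`, `grad_hook`, `post_backward`) are illustrative stand-ins, not ColossalAI APIs:

```python
import threading
from typing import Callable, List, Optional


class FakeAsyncWork:
    """Stand-in for the handle returned by an async collective, e.g. dist.all_reduce(..., async_op=True)."""

    def __init__(self, fn: Callable[[], None]):
        self._thread = threading.Thread(target=fn)
        self._thread.start()

    def wait(self) -> None:
        self._thread.join()


class Param:
    """Toy parameter holding a gradient and an optional deferred-release callback."""

    def __init__(self, grad: List[float]):
        self.grad = grad
        self.release_cb: Optional[Callable[[], None]] = None


def grad_hook(p: Param, world_size: int, async_reduce: bool) -> None:
    """Mimics grad_handle: launch the reduction, then either release now or defer via a callback."""
    work = FakeAsyncWork(lambda: None)  # pretend this is the in-flight all-reduce

    def release() -> None:
        work.wait()  # grad_chunk.wait_async_reduce() in the real code
        p.grad = [g / world_size for g in p.grad]  # the div_/move_chunk bookkeeping

    if async_reduce:
        p.release_cb = release  # p._release_grad_chunk_cb = _release_grad_chunk_cb
    else:
        release()


def post_backward(params: List[Param]) -> None:
    """Mimics _post_backward: flush every deferred callback exactly once."""
    for p in params:
        if p.release_cb is not None:
            p.release_cb()
            p.release_cb = None


params = [Param([4.0, 8.0]), Param([2.0, 2.0])]
for p in params:
    grad_hook(p, world_size=2, async_reduce=True)
post_backward(params)
print([p.grad for p in params])  # [[2.0, 4.0], [1.0, 1.0]]
```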