@@ -98,8 +98,14 @@ def __init__(
         verbose: bool = False,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
+        reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
+        self.enable_gradient_accumulation = enable_gradient_accumulation
         if chunk_config_dict is not None:
-            self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
+            self.chunk_manager = ChunkManager(
+                chunk_config_dict,
+                chunk_init_device,
+                reuse_fp16_chunk=reuse_fp16_chunk,
+            )
         else:
             # some ugly hotfix for the compatibility with Lightning
             if search_range_m is None:
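Note on the flag introduced above: fp16 chunks are only reused as gradient buffers when fp32 master weights are kept and gradient accumulation is off. A minimal, hypothetical sketch of that decision (the helper name is illustrative, not part of the ChunkManager API):

def should_reuse_fp16_chunk(master_weights: bool, enable_gradient_accumulation: bool) -> bool:
    # Reuse the fp16 parameter chunk as the gradient buffer only when fp32
    # master copies exist (so fp16 data can be restored later) and gradients
    # are not accumulated across steps (accumulation needs a separate grad chunk).
    return master_weights and not enable_gradient_accumulation

# Equivalent to the conditional expression in the hunk above:
# reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
assert should_reuse_fp16_chunk(True, False) is True
assert should_reuse_fp16_chunk(True, True) is False
assert should_reuse_fp16_chunk(False, False) is False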
@@ -112,6 +118,7 @@ def __init__(
                 min_chunk_size_m=min_chunk_size_m,
                 strict_ddp_flag=strict_ddp_mode,
                 process_group=zero_group,
+                reuse_fp16_chunk=reuse_fp16_chunk,
                 verbose=verbose,
             )
         self.gemini_manager = GeminiManager(
@@ -128,7 +135,6 @@ def __init__(
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
-        self.overflow_counter = 0
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()
         self.param2name: Dict[nn.Parameter, str] = dict()
         self.name2param: Dict[str, nn.Parameter] = dict()
@@ -137,14 +143,8 @@ def __init__(
         self.zero_group = zero_group or _get_default_group()
         self.extra_dp_group = extra_dp_group

-        self.reuse_fp16_chunk = master_weights
         self.master_weights = master_weights

-        self.enable_gradient_accumulation = enable_gradient_accumulation
-        if self.enable_gradient_accumulation:
-            self.reuse_fp16_chunk = False
-        self.accumulating_grads = False  # Whether model is accumulating gradients
-
         self._logger = get_dist_logger()

         if self.gemini_manager._premade_memstats_:
@@ -178,7 +178,29 @@ def __init__(
             if is_ddp_ignored(p):
                 continue
             if p.requires_grad:
-                p.register_hook(partial(self.grad_handle, p))
+                p._grad_handle = p.register_hook(
+                    partial(
+                        GeminiDDP.grad_handle,
+                        chunk_manager=self.chunk_manager,
+                        param2name=self.param2name,
+                        grads_device=self.grads_device,
+                        master_weights=self.master_weights,
+                        enable_gradient_accumulation=self.enable_gradient_accumulation,
+                        p=p,
+                    )
+                )
+
+    def remove_hooks(self):
+        for p in self.module.parameters():
+            if is_ddp_ignored(p):
+                continue
+            if p.requires_grad:
+                assert hasattr(p, "_grad_handle")
+                p._grad_handle.remove()
+                delattr(p, "_grad_handle")
+
+    def __del__(self):
+        self.remove_hooks()

     def parameters(self, recurse: bool = True):
         return self.module.parameters(recurse)
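The registration change above keeps the RemovableHandle returned by `Tensor.register_hook` on each parameter so the hooks can be detached later via `remove_hooks` / `__del__`. A minimal sketch of the same pattern on a plain module, using only public PyTorch APIs (the names `scale_grad`, `model`, and `handles` are illustrative, not GeminiDDP's):

import torch
import torch.nn as nn

def scale_grad(grad: torch.Tensor) -> torch.Tensor:
    # Example hook body: return a replacement gradient.
    return grad * 0.5

model = nn.Linear(4, 2)
handles = []
for p in model.parameters():
    if p.requires_grad:
        # register_hook returns a RemovableHandle; keep it so the hook can be
        # detached later instead of living for the parameter's whole lifetime.
        handles.append(p.register_hook(scale_grad))

model(torch.randn(3, 4)).sum().backward()  # hooks fire during backward

for h in handles:
    h.remove()  # detach all hooks, mirroring remove_hooks() in the hunk above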
@@ -324,8 +346,8 @@ def _post_backward(self):
                 f"{error_str}",
             )
         self._setup_grads_ptr()
-        if self.enable_gradient_accumulation and not self.accumulating_grads:
-            self.accumulating_grads = True  # Turn on the state of gradient accumulation.
+        if self.enable_gradient_accumulation and not self.chunk_manager.accumulating_grads:
+            self.chunk_manager.accumulating_grads = True  # Turn on the state of gradient accumulation.
         self._logger.debug(
             f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
         )
@@ -340,25 +362,34 @@ def backward(self, loss: torch.Tensor):
     def backward_by_grad(self, tensor, grad):
         raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.")

-    def grad_handle(self, p, grad):
+    @staticmethod
+    def grad_handle(
+        grad,
+        chunk_manager: ChunkManager,
+        param2name: Dict,
+        grads_device: Dict,
+        master_weights: bool,
+        enable_gradient_accumulation: bool,
+        p: nn.Parameter,
+    ):
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
         free_storage(empty_grad)
         with torch._C.DisableTorchFunction():
-            chunk = self.chunk_manager.get_chunk(p)
+            chunk = chunk_manager.get_chunk(p)
             if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
                 raise RuntimeError(
-                    f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
+                    f"Parameter `{param2name[p]}` failed at the gradient reduction. "
                     "Some unsupported torch function is operated upon this parameter."
                 )
             grad_chunk = chunk
-            if not self.reuse_fp16_chunk:
-                if not self.accumulating_grads:
-                    grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
+            if not chunk_manager.reuse_fp16_chunk:
+                if not chunk_manager.accumulating_grads:
+                    grad_chunk = chunk_manager.init_grad_chunk(chunk)
                 else:
                     assert chunk.grad_chunk is not None
-                    if chunk.grad_chunk not in self.chunk_manager.accessed_chunks:
-                        grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
+                    if chunk.grad_chunk not in chunk_manager.accessed_chunks:
+                        grad_chunk = chunk_manager.rearrange_accumulated_grad_chunk(chunk)
                     else:
                         grad_chunk = chunk.grad_chunk
                         chunk.grad_chunk.l2_norm = None
@@ -371,33 +402,33 @@ def grad_handle(self, p, grad):
                chunk.tensor_trans_state(p, TensorState.HOLD)

            grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
-            if not self.accumulating_grads:
-                grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
+            if not chunk_manager.accumulating_grads:
+                grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
             else:
                 grad_chunk.add_tensor_to_chunk_slice(p, grad)
-            reduced = self.chunk_manager.reduce_chunk(grad_chunk)
+            reduced = chunk_manager.reduce_chunk(grad_chunk)
             if reduced:
-                if not self.reuse_fp16_chunk:
+                if not chunk_manager.reuse_fp16_chunk:
                     if chunk.keep_gathered:
-                        self.chunk_manager.fake_release_chunk(chunk)
+                        chunk_manager.fake_release_chunk(chunk)
                     else:
-                        self.chunk_manager.release_chunk(chunk)
+                        chunk_manager.release_chunk(chunk)
                 if grad_chunk.is_gathered:
                     grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
-                    if self.extra_dp_group is not None:
+                    if chunk.extra_dp_group is not None:
                         grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
                 else:
                     grad_chunk.cuda_shard.div_(chunk.pg_size)
-                    if self.extra_dp_group is not None:
+                    if chunk.extra_dp_group is not None:
                         grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
                 # check overflow elements
-                self.overflow_counter += grad_chunk.has_inf_or_nan
+                chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
                 # record l2 norm for gradient clipping. flag is bound to fp16 chunk
                 if chunk.l2_norm_flag:
                     grad_chunk.set_l2_norm()
-                self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
-                if not (self.master_weights) or (self.enable_gradient_accumulation):
-                    self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
+                chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+                if not (master_weights) or (enable_gradient_accumulation):
+                    chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
         return empty_grad

     def zero_grad(self, set_to_none: bool = False) -> None:
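With `grad_handle` now a `@staticmethod`, all state it needs is bound through `functools.partial` keyword arguments instead of being reached through `self`, presumably so the registered hooks no longer hold a reference to the GeminiDDP instance itself. A stripped-down sketch of that binding pattern (the class `TinyDDP` and its hook body are illustrative, not the real reduction logic):

from functools import partial
from typing import Dict

import torch
import torch.nn as nn

class TinyDDP:
    @staticmethod
    def grad_handle(grad: torch.Tensor, param2name: Dict[nn.Parameter, str], p: nn.Parameter):
        # Illustrative hook body: all state arrives via the bound keyword arguments,
        # while `grad` is supplied positionally by autograd when the hook fires.
        print(f"reducing grad of {param2name[p]} with norm {grad.norm().item():.4f}")
        return grad

    def __init__(self, module: nn.Module):
        self.module = module
        self.param2name = {p: n for n, p in module.named_parameters()}
        for p in module.parameters():
            if p.requires_grad:
                # Bind everything the static hook needs and keep the handle on the parameter.
                p._grad_handle = p.register_hook(
                    partial(TinyDDP.grad_handle, param2name=self.param2name, p=p)
                )

ddp = TinyDDP(nn.Linear(4, 2))
ddp.module(torch.randn(1, 4)).sum().backward()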
@@ -513,11 +544,11 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):

         # get copies of fp32 parameters in CPU
         # as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16
-        params = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
+        params = self.fp32_params if self.chunk_manager.reuse_fp16_chunk else self.fp16_params
         param_to_save_data = self._get_param_to_save_data(params, only_rank_0)
         # get the mapping between copies and fp16 parameters
         p_mapping = dict()
-        if self.reuse_fp16_chunk:
+        if self.chunk_manager.reuse_fp16_chunk:
             for p, fp32_p in zip(self.fp16_params, self.fp32_params):
                 name = self.param2name[p]
                 assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
@@ -713,7 +744,7 @@ def load_parameter(chunk_slice, data):
                 name = self.param2name[p]
                 fp32_to_name[fp32_p] = name

-        params_to_load = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
+        params_to_load = self.fp32_params if self.chunk_manager.reuse_fp16_chunk else self.fp16_params
         chunk_list = self.chunk_manager.get_chunks(params_to_load)
         for chunk in chunk_list:
             temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
@@ -728,7 +759,9 @@ def load_parameter(chunk_slice, data):
                 shard_fn = tensor.shard_fn
                 gather_fn = tensor.gather_fn

-                parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
+                parameter_name = (
+                    fp32_to_name[tensor] if self.chunk_manager.reuse_fp16_chunk else self.param2name[tensor]
+                )
                 parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
                 load(
                     parameter_name,
@@ -900,7 +933,7 @@ def state_dict_shard(
                     gathered_param = param if keep_vars else param.detach()
                 else:
                     # as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16
-                    param_to_save = fp16_to_fp32[param] if self.reuse_fp16_chunk else param
+                    param_to_save = fp16_to_fp32[param] if self.chunk_manager.reuse_fp16_chunk else param
                     if param_to_save not in gathered_param_buffer:
                         chunk = self.chunk_manager.get_chunk(param_to_save)
                         gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))