Skip to content

Commit d4c5ef4

Browse files
authored
[gemini]remove registered gradients hooks (#5696)
* fix gemini fix gemini * fix fix
1 parent 2229778 commit d4c5ef4

File tree

5 files changed

+93
-46
lines changed

5 files changed

+93
-46
lines changed

colossalai/zero/gemini/chunk/manager.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ class ChunkManager:
2020
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
2121
"""
2222

23-
def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
23+
def __init__(
24+
self,
25+
chunk_configuration,
26+
init_device: Optional[torch.device] = None,
27+
reuse_fp16_chunk: bool = True,
28+
) -> None:
2429
self.device = init_device or get_accelerator().get_current_device()
2530
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
2631
self.kwargs_config = chunk_configuration
@@ -33,6 +38,10 @@ def __init__(self, chunk_configuration, init_device: Optional[torch.device] = No
3338
self.accessed_chunks: Set[Chunk] = set()
3439
self.accessed_mem: int = 0
3540
self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0}
41+
self.reuse_fp16_chunk = reuse_fp16_chunk
42+
# Whether model is accumulating gradients,
43+
self.accumulating_grads = False
44+
self.overflow_counter = 0
3645

3746
def register_tensor(
3847
self,

colossalai/zero/gemini/chunk/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def init_chunk_manager(
1919
model: nn.Module,
2020
init_device: Optional[torch.device] = None,
2121
hidden_dim: Optional[int] = None,
22+
reuse_fp16_chunk: bool = True,
2223
verbose: bool = False,
2324
**kwargs,
2425
) -> ChunkManager:
@@ -50,5 +51,9 @@ def init_chunk_manager(
5051
)
5152
dist.barrier()
5253

53-
chunk_manager = ChunkManager(config_dict, init_device)
54+
chunk_manager = ChunkManager(
55+
config_dict,
56+
init_device,
57+
reuse_fp16_chunk=reuse_fp16_chunk,
58+
)
5459
return chunk_manager

colossalai/zero/gemini/gemini_ddp.py

Lines changed: 69 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,14 @@ def __init__(
9898
verbose: bool = False,
9999
) -> None:
100100
assert mixed_precision in (torch.float16, torch.bfloat16)
101+
reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
102+
self.enable_gradient_accumulation = enable_gradient_accumulation
101103
if chunk_config_dict is not None:
102-
self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
104+
self.chunk_manager = ChunkManager(
105+
chunk_config_dict,
106+
chunk_init_device,
107+
reuse_fp16_chunk=reuse_fp16_chunk,
108+
)
103109
else:
104110
# some ugly hotfix for the compatibility with Lightning
105111
if search_range_m is None:
@@ -112,6 +118,7 @@ def __init__(
112118
min_chunk_size_m=min_chunk_size_m,
113119
strict_ddp_flag=strict_ddp_mode,
114120
process_group=zero_group,
121+
reuse_fp16_chunk=reuse_fp16_chunk,
115122
verbose=verbose,
116123
)
117124
self.gemini_manager = GeminiManager(
@@ -128,7 +135,6 @@ def __init__(
128135
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
129136
self.fp32_params: List[torch.Tensor] = list()
130137
self.fp16_params: List[ColoParameter] = list()
131-
self.overflow_counter = 0
132138
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
133139
self.param2name: Dict[nn.Parameter, str] = dict()
134140
self.name2param: Dict[str, nn.Parameter] = dict()
@@ -137,14 +143,8 @@ def __init__(
137143
self.zero_group = zero_group or _get_default_group()
138144
self.extra_dp_group = extra_dp_group
139145

140-
self.reuse_fp16_chunk = master_weights
141146
self.master_weights = master_weights
142147

143-
self.enable_gradient_accumulation = enable_gradient_accumulation
144-
if self.enable_gradient_accumulation:
145-
self.reuse_fp16_chunk = False
146-
self.accumulating_grads = False # Whether model is accumulating gradients
147-
148148
self._logger = get_dist_logger()
149149

150150
if self.gemini_manager._premade_memstats_:
@@ -178,7 +178,29 @@ def __init__(
178178
if is_ddp_ignored(p):
179179
continue
180180
if p.requires_grad:
181-
p.register_hook(partial(self.grad_handle, p))
181+
p._grad_handle = p.register_hook(
182+
partial(
183+
GeminiDDP.grad_handle,
184+
chunk_manager=self.chunk_manager,
185+
param2name=self.param2name,
186+
grads_device=self.grads_device,
187+
master_weights=self.master_weights,
188+
enable_gradient_accumulation=self.enable_gradient_accumulation,
189+
p=p,
190+
)
191+
)
192+
193+
def remove_hooks(self):
194+
for p in self.module.parameters():
195+
if is_ddp_ignored(p):
196+
continue
197+
if p.requires_grad:
198+
assert hasattr(p, "_grad_handle")
199+
p._grad_handle.remove()
200+
delattr(p, "_grad_handle")
201+
202+
def __del__(self):
203+
self.remove_hooks()
182204

183205
def parameters(self, recurse: bool = True):
184206
return self.module.parameters(recurse)
@@ -324,8 +346,8 @@ def _post_backward(self):
324346
f"{error_str}",
325347
)
326348
self._setup_grads_ptr()
327-
if self.enable_gradient_accumulation and not self.accumulating_grads:
328-
self.accumulating_grads = True # Turn on the state of gradient accumulation.
349+
if self.enable_gradient_accumulation and not self.chunk_manager.accumulating_grads:
350+
self.chunk_manager.accumulating_grads = True # Turn on the state of gradient accumulation.
329351
self._logger.debug(
330352
f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
331353
)
@@ -340,25 +362,34 @@ def backward(self, loss: torch.Tensor):
340362
def backward_by_grad(self, tensor, grad):
341363
raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.")
342364

343-
def grad_handle(self, p, grad):
365+
@staticmethod
366+
def grad_handle(
367+
grad,
368+
chunk_manager: ChunkManager,
369+
param2name: Dict,
370+
grads_device: Dict,
371+
master_weights: bool,
372+
enable_gradient_accumulation: bool,
373+
p: nn.Parameter,
374+
):
344375
setattr(p, "_gemini_reduced", True)
345376
empty_grad = torch.empty_like(grad)
346377
free_storage(empty_grad)
347378
with torch._C.DisableTorchFunction():
348-
chunk = self.chunk_manager.get_chunk(p)
379+
chunk = chunk_manager.get_chunk(p)
349380
if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
350381
raise RuntimeError(
351-
f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
382+
f"Parameter `{param2name[p]}` failed at the gradient reduction. "
352383
"Some unsupported torch function is operated upon this parameter."
353384
)
354385
grad_chunk = chunk
355-
if not self.reuse_fp16_chunk:
356-
if not self.accumulating_grads:
357-
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
386+
if not chunk_manager.reuse_fp16_chunk:
387+
if not chunk_manager.accumulating_grads:
388+
grad_chunk = chunk_manager.init_grad_chunk(chunk)
358389
else:
359390
assert chunk.grad_chunk is not None
360-
if chunk.grad_chunk not in self.chunk_manager.accessed_chunks:
361-
grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
391+
if chunk.grad_chunk not in chunk_manager.accessed_chunks:
392+
grad_chunk = chunk_manager.rearrange_accumulated_grad_chunk(chunk)
362393
else:
363394
grad_chunk = chunk.grad_chunk
364395
chunk.grad_chunk.l2_norm = None
@@ -371,33 +402,33 @@ def grad_handle(self, p, grad):
371402
chunk.tensor_trans_state(p, TensorState.HOLD)
372403

373404
grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
374-
if not self.accumulating_grads:
375-
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
405+
if not chunk_manager.accumulating_grads:
406+
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
376407
else:
377408
grad_chunk.add_tensor_to_chunk_slice(p, grad)
378-
reduced = self.chunk_manager.reduce_chunk(grad_chunk)
409+
reduced = chunk_manager.reduce_chunk(grad_chunk)
379410
if reduced:
380-
if not self.reuse_fp16_chunk:
411+
if not chunk_manager.reuse_fp16_chunk:
381412
if chunk.keep_gathered:
382-
self.chunk_manager.fake_release_chunk(chunk)
413+
chunk_manager.fake_release_chunk(chunk)
383414
else:
384-
self.chunk_manager.release_chunk(chunk)
415+
chunk_manager.release_chunk(chunk)
385416
if grad_chunk.is_gathered:
386417
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
387-
if self.extra_dp_group is not None:
418+
if chunk.extra_dp_group is not None:
388419
grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
389420
else:
390421
grad_chunk.cuda_shard.div_(chunk.pg_size)
391-
if self.extra_dp_group is not None:
422+
if chunk.extra_dp_group is not None:
392423
grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
393424
# check overflow elements
394-
self.overflow_counter += grad_chunk.has_inf_or_nan
425+
chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
395426
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
396427
if chunk.l2_norm_flag:
397428
grad_chunk.set_l2_norm()
398-
self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
399-
if not (self.master_weights) or (self.enable_gradient_accumulation):
400-
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
429+
chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
430+
if not (master_weights) or (enable_gradient_accumulation):
431+
chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
401432
return empty_grad
402433

403434
def zero_grad(self, set_to_none: bool = False) -> None:
@@ -513,11 +544,11 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
513544

514545
# get copies of fp32 parameters in CPU
515546
# as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16
516-
params = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
547+
params = self.fp32_params if self.chunk_manager.reuse_fp16_chunk else self.fp16_params
517548
param_to_save_data = self._get_param_to_save_data(params, only_rank_0)
518549
# get the mapping between copies and fp16 parameters
519550
p_mapping = dict()
520-
if self.reuse_fp16_chunk:
551+
if self.chunk_manager.reuse_fp16_chunk:
521552
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
522553
name = self.param2name[p]
523554
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
@@ -713,7 +744,7 @@ def load_parameter(chunk_slice, data):
713744
name = self.param2name[p]
714745
fp32_to_name[fp32_p] = name
715746

716-
params_to_load = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
747+
params_to_load = self.fp32_params if self.chunk_manager.reuse_fp16_chunk else self.fp16_params
717748
chunk_list = self.chunk_manager.get_chunks(params_to_load)
718749
for chunk in chunk_list:
719750
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
@@ -728,7 +759,9 @@ def load_parameter(chunk_slice, data):
728759
shard_fn = tensor.shard_fn
729760
gather_fn = tensor.gather_fn
730761

731-
parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
762+
parameter_name = (
763+
fp32_to_name[tensor] if self.chunk_manager.reuse_fp16_chunk else self.param2name[tensor]
764+
)
732765
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
733766
load(
734767
parameter_name,
@@ -900,7 +933,7 @@ def state_dict_shard(
900933
gathered_param = param if keep_vars else param.detach()
901934
else:
902935
# as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16
903-
param_to_save = fp16_to_fp32[param] if self.reuse_fp16_chunk else param
936+
param_to_save = fp16_to_fp32[param] if self.chunk_manager.reuse_fp16_chunk else param
904937
if param_to_save not in gathered_param_buffer:
905938
chunk = self.chunk_manager.get_chunk(param_to_save)
906939
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))

colossalai/zero/gemini/gemini_optimizer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ def __init__(
6262
self.module = module
6363

6464
def check_local_overflow(self) -> bool:
65-
return self.module.overflow_counter > 0
65+
return self.module.chunk_manager.overflow_counter > 0
6666

6767
def pre_zero_grad(self) -> None:
68-
self.module.overflow_counter = 0
68+
self.module.chunk_manager.overflow_counter = 0
6969

7070

7171
class GeminiOptimizer(OptimizerWrapper):
@@ -202,7 +202,7 @@ def _set_grad_ptr(self):
202202
chunk16 = self.param_to_chunk16[fake_param]
203203
begin, end = self.param_to_range[fake_param]
204204

205-
grad_chunk16 = chunk16 if self.module.reuse_fp16_chunk else chunk16.grad_chunk
205+
grad_chunk16 = chunk16 if self.module.chunk_manager.reuse_fp16_chunk else chunk16.grad_chunk
206206
fake_param.data = grad_chunk16.payload[begin:end]
207207
fake_param.grad = fake_param.data
208208

@@ -221,14 +221,14 @@ def _update_fp16_params(self):
221221

222222
def _clear_global_norm(self) -> None:
223223
for c16 in self.chunk16_set:
224-
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
224+
grad_chunk = c16 if self.module.chunk_manager.reuse_fp16_chunk else c16.grad_chunk
225225
grad_chunk.l2_norm = None
226226

227227
def _calc_global_norm(self) -> float:
228228
norm_sqr: float = 0.0
229229
group_to_norm = dict()
230230
for c16 in self.chunk16_set:
231-
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
231+
grad_chunk = c16 if self.module.chunk_manager.reuse_fp16_chunk else c16.grad_chunk
232232
assert grad_chunk.l2_norm is not None
233233

234234
if grad_chunk.is_gathered:
@@ -275,7 +275,7 @@ def step(self, *args, **kwargs):
275275
self._logger.info(f"Found overflow. Skip step")
276276
self._clear_global_norm() # clear recorded norm
277277
self.zero_grad() # reset all gradients
278-
if self.module.reuse_fp16_chunk:
278+
if self.module.chunk_manager.reuse_fp16_chunk:
279279
self._update_fp16_params()
280280
return
281281

@@ -288,7 +288,7 @@ def step(self, *args, **kwargs):
288288
self.zero_grad()
289289
if self.module.master_weights:
290290
self._update_fp16_params()
291-
self.module.accumulating_grads = False
291+
self.module.chunk_manager.accumulating_grads = False
292292
return ret
293293

294294
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):

tests/test_zero/test_gemini/test_fwd_bwd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
2626
chunk_manager = model.chunk_manager
2727
param_list = [p for p in model.parameters()]
2828
chunk_list = chunk_manager.get_chunks(param_list)
29-
if not model.reuse_fp16_chunk:
29+
if not model.chunk_manager.reuse_fp16_chunk:
3030
chunk_list = [chunk.grad_chunk for chunk in chunk_list]
3131
for chunk in chunk_list:
3232
chunk_manager.access_chunk(chunk)

0 commit comments

Comments
 (0)