Skip to content

Commit a15ab13

Browse files
authored
[plugin] support get_grad_norm (#6115)
1 parent 13ffa08 commit a15ab13

File tree

8 files changed: +40 -2 lines changed

colossalai/amp/naive_amp/mixed_precision_optimizer.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Tuple
1+
from typing import Dict, List, Optional, Tuple
22

33
import torch
44
from torch import Tensor, inf
@@ -84,6 +84,7 @@ def __init__(
8484
self.master_to_working_map[master_p] = p
8585
master_params.append(master_p)
8686
group["params"] = master_params
87+
self._current_grad_norm: Optional[float] = None
8788

8889
def backward(self, loss: Tensor, *args, **kwargs):
8990
loss = self.mixed_precision.pre_backward(loss)
@@ -187,6 +188,7 @@ def step(self, *args, **kwargs):
187188
if p.grad is not None
188189
]
189190
total_norm = self._compute_grad_norm(param_gradient_pairs)
191+
self._current_grad_norm = total_norm
190192
self._unscale_and_clip_grads(total_norm)
191193

192194
self.optim.step(*args, **kwargs)
@@ -212,3 +214,6 @@ def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
212214

213215
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
214216
return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
217+
218+
def get_grad_norm(self, norm_type=2, **kwargs):
219+
return self._current_grad_norm

colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -293,6 +293,7 @@ def __init__(
293293
self.pp_pg = pp_process_group
294294
self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
295295
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
296+
self._current_grad_norm: Optional[float] = None
296297
super().__init__(optim)
297298

298299
def backward(self, loss: Tensor, *args, **kwargs):
@@ -364,6 +365,7 @@ def step(self, *args, **kwargs):
364365
(p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None
365366
]
366367
total_norm = self._compute_grad_norm(param_gradient_pairs)
368+
self._current_grad_norm = total_norm
367369

368370
# Clip the gradients to prevent exploding gradients.
369371
self._clip_grad_norm(total_norm)
@@ -477,6 +479,9 @@ def get_working_to_master_map(self):
477479
def get_master_to_working_map(self):
478480
return None
479481

482+
def get_grad_norm(self, norm_type=2, **kwargs):
483+
return self._current_grad_norm
484+
480485

481486
class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
482487
def __init__(

colossalai/interface/optimizer.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -135,6 +135,18 @@ def unwrap(self):
135135
"""
136136
return self.optim
137137

138+
def get_grad_norm(self, norm_type: Union[float, int] = 2.0, **kwargs) -> Optional[float]:
139+
"""
140+
Returns the gradient norm of an iterable of parameters. This method should be called after optimizer.step().
141+
142+
Args:
143+
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
144+
145+
Returns:
146+
Optional[float]: Total norm of the gradients (viewed as a single vector). If there are no valid gradients, returns None.
147+
"""
148+
raise NotImplementedError("The method get_grad_norm is not implemented yet.")
149+
138150

139151
class DistributedOptim(Optimizer):
140152
def setup_distributed(

colossalai/zero/gemini/gemini_optimizer.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
22
import copy
33
import math
4-
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
4+
from typing import Any, Dict, Iterator, Optional, OrderedDict, Set, Tuple, Union
55

66
import torch
77
import torch.distributed as dist
@@ -195,6 +195,7 @@ def __init__(
195195
self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])
196196

197197
self._register_states = disposable(self._register_states_)
198+
self._current_grad_norm: Optional[float] = None
198199

199200
def _set_grad_ptr(self):
200201
for group in self.param_groups:
@@ -255,6 +256,7 @@ def _get_combined_scale(self):
255256

256257
if self.clipping_flag:
257258
total_norm = self._calc_global_norm()
259+
self._current_grad_norm = total_norm
258260
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
259261
if clip > 1:
260262
div_scale = clip * div_scale
@@ -846,6 +848,9 @@ def clip_grad_by_norm(
846848
f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0]
847849
)
848850

851+
def get_grad_norm(self, norm_type=2, **kwargs):
852+
return self._current_grad_norm
853+
849854

850855
class GeminiAdamOptimizer(GeminiOptimizer):
851856
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:

colossalai/zero/low_level/low_level_optim.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -218,6 +218,7 @@ def __init__(
218218
)
219219
elif self._dtype is torch.bfloat16:
220220
self.mixed_precision_mixin = BF16MixedPrecisionMixin()
221+
self._current_grad_norm: Optional[float] = None
221222

222223
def __del__(self):
223224
for hook in self.grad_handles:
@@ -551,6 +552,7 @@ def step(self, closure=None):
551552

552553
# unscale and clip grads
553554
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
555+
self._current_grad_norm = global_norm
554556
self._unscale_and_clip_grads(grad_partition_groups, global_norm)
555557

556558
# update the parameters
@@ -934,3 +936,6 @@ def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) ->
934936
def _force_wait_all_gather(self):
935937
for param in self._working_param_to_padded_working_param.keys():
936938
wait_all_gather_handle(param)
939+
940+
def get_grad_norm(self, norm_type=2, **kwargs):
941+
return self._current_grad_norm

tests/test_booster/test_plugin/test_3d_plugin.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,8 @@ def _criterion(outputs, inputs):
7676

7777
booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True)
7878
optimizer.step()
79+
grad_norm = optimizer.get_grad_norm()
80+
assert grad_norm is None or isinstance(grad_norm, float)
7981

8082
except Exception as e:
8183
return repr(e)

tests/test_booster/test_plugin/test_gemini_plugin.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, t
5454

5555
booster.backward(loss, optimizer)
5656
optimizer.step()
57+
grad_norm = optimizer.get_grad_norm()
58+
assert grad_norm is None or isinstance(grad_norm, float)
5759

5860
except NotImplementedError:
5961
print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.")

tests/test_booster/test_plugin/test_low_level_zero_plugin.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,8 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None)
5050

5151
booster.backward(loss, optimizer)
5252
optimizer.step()
53+
grad_norm = optimizer.get_grad_norm()
54+
assert grad_norm is None or isinstance(grad_norm, float)
5355

5456
except Exception as e:
5557
return repr(e)

0 commit comments

Comments (0)