Skip to content
This repository was archived by the owner on Mar 12, 2024. It is now read-only.

Commit a54b778

Browse files
authored
add full-model gradient clipping to optimizer (#287)
1 parent 4e1a928 commit a54b778

File tree

1 file changed

+25
-34
lines changed

1 file changed

+25
-34
lines changed

d2/train_net.py

Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77
import os
88
import sys
9+
import itertools
910

1011
# fmt: off
1112
sys.path.insert(1, os.path.join(sys.path[0], '..'))
@@ -21,7 +22,7 @@
2122
from detectron2.checkpoint import DetectionCheckpointer
2223
from detectron2.config import get_cfg
2324
from detectron2.data import MetadataCatalog, build_detection_train_loader
24-
from detectron2.engine import AutogradProfiler, DefaultTrainer, default_argument_parser, default_setup, launch
25+
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
2526
from detectron2.evaluation import COCOEvaluator, verify_results
2627

2728
from detectron2.solver.build import maybe_add_gradient_clipping
@@ -32,37 +33,6 @@ class Trainer(DefaultTrainer):
3233
Extension of the Trainer class adapted to DETR.
3334
"""
3435

35-
def __init__(self, cfg):
36-
"""
37-
Args:
38-
cfg (CfgNode):
39-
"""
40-
self.clip_norm_val = 0.0
41-
if cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
42-
if cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
43-
self.clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
44-
super().__init__(cfg)
45-
46-
def run_step(self):
47-
assert self.model.training, "[Trainer] model was changed to eval mode!"
48-
start = time.perf_counter()
49-
data = next(self._data_loader_iter)
50-
data_time = time.perf_counter() - start
51-
52-
loss_dict = self.model(data)
53-
losses = sum(loss_dict.values())
54-
self._detect_anomaly(losses, loss_dict)
55-
56-
metrics_dict = loss_dict
57-
metrics_dict["data_time"] = data_time
58-
self._write_metrics(metrics_dict)
59-
60-
self.optimizer.zero_grad()
61-
losses.backward()
62-
if self.clip_norm_val > 0.0:
63-
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_norm_val)
64-
self.optimizer.step()
65-
6636
@classmethod
6737
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
6838
"""
@@ -100,11 +70,32 @@ def build_optimizer(cls, cfg, model):
10070
lr = lr * cfg.SOLVER.BACKBONE_MULTIPLIER
10171
params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
10272

73+
def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class
74+
# detectron2 doesn't have full model gradient clipping now
75+
clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
76+
enable = (
77+
cfg.SOLVER.CLIP_GRADIENTS.ENABLED
78+
and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
79+
and clip_norm_val > 0.0
80+
)
81+
82+
class FullModelGradientClippingOptimizer(optim):
83+
def step(self, closure=None):
84+
all_params = itertools.chain(*[x["params"] for x in self.param_groups])
85+
torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
86+
super().step(closure=closure)
87+
88+
return FullModelGradientClippingOptimizer if enable else optim
89+
10390
optimizer_type = cfg.SOLVER.OPTIMIZER
10491
if optimizer_type == "SGD":
105-
optimizer = torch.optim.SGD(params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM)
92+
optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
93+
params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
94+
)
10695
elif optimizer_type == "ADAMW":
107-
optimizer = torch.optim.AdamW(params, cfg.SOLVER.BASE_LR)
96+
optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
97+
params, cfg.SOLVER.BASE_LR
98+
)
10899
else:
109100
raise NotImplementedError(f"no optimizer type {optimizer_type}")
110101
if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":

0 commit comments

Comments
 (0)