66"""
77import os
88import sys
9+ import itertools
910
1011# fmt: off
1112sys .path .insert (1 , os .path .join (sys .path [0 ], '..' ))
2122from detectron2 .checkpoint import DetectionCheckpointer
2223from detectron2 .config import get_cfg
2324from detectron2 .data import MetadataCatalog , build_detection_train_loader
24- from detectron2 .engine import AutogradProfiler , DefaultTrainer , default_argument_parser , default_setup , launch
25+ from detectron2 .engine import DefaultTrainer , default_argument_parser , default_setup , launch
2526from detectron2 .evaluation import COCOEvaluator , verify_results
2627
2728from detectron2 .solver .build import maybe_add_gradient_clipping
@@ -32,37 +33,6 @@ class Trainer(DefaultTrainer):
3233 Extension of the Trainer class adapted to DETR.
3334 """
3435
35- def __init__ (self , cfg ):
36- """
37- Args:
38- cfg (CfgNode):
39- """
40- self .clip_norm_val = 0.0
41- if cfg .SOLVER .CLIP_GRADIENTS .ENABLED :
42- if cfg .SOLVER .CLIP_GRADIENTS .CLIP_TYPE == "full_model" :
43- self .clip_norm_val = cfg .SOLVER .CLIP_GRADIENTS .CLIP_VALUE
44- super ().__init__ (cfg )
45-
46- def run_step (self ):
47- assert self .model .training , "[Trainer] model was changed to eval mode!"
48- start = time .perf_counter ()
49- data = next (self ._data_loader_iter )
50- data_time = time .perf_counter () - start
51-
52- loss_dict = self .model (data )
53- losses = sum (loss_dict .values ())
54- self ._detect_anomaly (losses , loss_dict )
55-
56- metrics_dict = loss_dict
57- metrics_dict ["data_time" ] = data_time
58- self ._write_metrics (metrics_dict )
59-
60- self .optimizer .zero_grad ()
61- losses .backward ()
62- if self .clip_norm_val > 0.0 :
63- torch .nn .utils .clip_grad_norm_ (self .model .parameters (), self .clip_norm_val )
64- self .optimizer .step ()
65-
6636 @classmethod
6737 def build_evaluator (cls , cfg , dataset_name , output_folder = None ):
6838 """
@@ -100,11 +70,32 @@ def build_optimizer(cls, cfg, model):
10070 lr = lr * cfg .SOLVER .BACKBONE_MULTIPLIER
10171 params += [{"params" : [value ], "lr" : lr , "weight_decay" : weight_decay }]
10272
73+ def maybe_add_full_model_gradient_clipping (optim ): # optim: the optimizer class
74+ # detectron2 doesn't have full model gradient clipping now
75+ clip_norm_val = cfg .SOLVER .CLIP_GRADIENTS .CLIP_VALUE
76+ enable = (
77+ cfg .SOLVER .CLIP_GRADIENTS .ENABLED
78+ and cfg .SOLVER .CLIP_GRADIENTS .CLIP_TYPE == "full_model"
79+ and clip_norm_val > 0.0
80+ )
81+
82+ class FullModelGradientClippingOptimizer (optim ):
83+ def step (self , closure = None ):
84+ all_params = itertools .chain (* [x ["params" ] for x in self .param_groups ])
85+ torch .nn .utils .clip_grad_norm_ (all_params , clip_norm_val )
86+ super ().step (closure = closure )
87+
88+ return FullModelGradientClippingOptimizer if enable else optim
89+
10390 optimizer_type = cfg .SOLVER .OPTIMIZER
10491 if optimizer_type == "SGD" :
105- optimizer = torch .optim .SGD (params , cfg .SOLVER .BASE_LR , momentum = cfg .SOLVER .MOMENTUM )
92+ optimizer = maybe_add_full_model_gradient_clipping (torch .optim .SGD )(
93+ params , cfg .SOLVER .BASE_LR , momentum = cfg .SOLVER .MOMENTUM
94+ )
10695 elif optimizer_type == "ADAMW" :
107- optimizer = torch .optim .AdamW (params , cfg .SOLVER .BASE_LR )
96+ optimizer = maybe_add_full_model_gradient_clipping (torch .optim .AdamW )(
97+ params , cfg .SOLVER .BASE_LR
98+ )
10899 else :
109100 raise NotImplementedError (f"no optimizer type { optimizer_type } " )
110101 if not cfg .SOLVER .CLIP_GRADIENTS .CLIP_TYPE == "full_model" :