2 files changed, +21 -3 lines changed

rtdetrv2_pytorch/rtdetrv2

```diff
@@ -228,6 +228,18 @@ def reduce_dict(data, avg=True):
     return {k: v for k, v in zip(keys, values)}
 
 
+def all_reduce(data, op=tdist.ReduceOp.SUM):
+    """
+    Run all_reduce on torch.Tensor data
+    Args:
+        data: torch.Tensor
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return
+    tdist.all_reduce(data, op=op)
+
+
 def all_gather(data):
     """
     Run all_gather on arbitrary picklable data (not necessarily tensors)
```
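For context, here is a minimal, self-contained sketch of how the new helper is meant to be used; the `get_world_size` stub stands in for the repo's own `dist_utils.get_world_size` and is an assumption made only to keep the snippet runnable:

```python
import torch
import torch.distributed as tdist

def get_world_size():
    # Stand-in for the repo's dist_utils.get_world_size() (assumption).
    if tdist.is_available() and tdist.is_initialized():
        return tdist.get_world_size()
    return 1

def all_reduce(data, op=tdist.ReduceOp.SUM):
    """Reduce a torch.Tensor in place across ranks; no-op on one process."""
    world_size = get_world_size()
    if world_size == 1:
        return
    tdist.all_reduce(data, op=op)

# Each rank contributes its local value; after the call every rank holds
# the SUM over ranks (divide by world_size if you want the mean instead).
# Note: with the NCCL backend the tensor must live on the rank's GPU.
local_norm = torch.tensor(2.0)
all_reduce(local_norm)
mean_norm = local_norm / get_world_size()
print(float(mean_norm))
```

The helper reduces in place and returns nothing, so the summed value is what lands in the logs unless the caller divides by the world size afterwards.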
```diff
@@ -61,12 +61,13 @@ def train_one_epoch(
             loss = sum(loss_dict.values())
             scaler.scale(loss).backward()
 
+            scaler.unscale_(optimizer)
             if max_norm > 0:
-                scaler.unscale_(optimizer)
                 total_norm = torch.nn.utils.clip_grad_norm_(
                     model.parameters(), max_norm
                 )
-                loss_dict["grad_norm"] = total_norm
+            else:
+                total_norm = torch.nn.utils.get_total_norm([p.grad for p in model.parameters() if p.grad is not None])
 
             scaler.step(optimizer)
             scaler.update()
@@ -84,7 +85,8 @@ def train_one_epoch(
                 total_norm = torch.nn.utils.clip_grad_norm_(
                     model.parameters(), max_norm
                 )
-                loss_dict["grad_norm"] = total_norm
+            else:
+                total_norm = torch.nn.utils.get_total_norm([p.grad for p in model.parameters() if p.grad is not None])
 
             optimizer.step()
 
@@ -103,6 +105,10 @@ def train_one_epoch(
             print(loss_dict_reduced)
             sys.exit(1)
 
+        # collect other values for logging
+        dist_utils.all_reduce(total_norm)
+        loss_dict_reduced["grad_norm"] = total_norm
+
         metric_logger.update(loss=loss_value, **loss_dict_reduced)
         metric_logger.update(lr=optimizer.param_groups[0]["lr"])
 
```
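Taken together, the training-loop change makes `scaler.unscale_(optimizer)` unconditional so that `total_norm` is measured in true (unscaled) gradient units whether or not clipping is enabled, and the reduced value is logged once per step. Below is a condensed sketch of that flow with a hypothetical toy model and optimizer; note that `torch.nn.utils.get_total_norm` only exists in recent PyTorch releases, and that it takes the gradient tensors themselves rather than the parameters:

```python
import torch

# Hypothetical stand-ins for the real model/optimizer (assumptions).
model = torch.nn.Linear(8, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())
max_norm = 0.0  # 0 disables clipping

loss = model(torch.randn(2, 8)).sum()
scaler.scale(loss).backward()

# Unscale once, unconditionally, so the norm below is in true gradient units.
scaler.unscale_(optimizer)
if max_norm > 0:
    # clip_grad_norm_ clips in place and returns the pre-clip total norm.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
else:
    # Measure without clipping; get_total_norm expects the gradient tensors.
    total_norm = torch.nn.utils.get_total_norm(
        [p.grad for p in model.parameters() if p.grad is not None]
    )

scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

# In the real loop, total_norm is then all_reduce'd across ranks and logged
# as loss_dict_reduced["grad_norm"].
print(float(total_norm))
```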