Commit 9839f94

fix(rtdetrv2): Fix incorrect method to inspect total norm during training stage (#9)

1 parent 4595c4e

2 files changed: +21 −3 lines

rtdetrv2_pytorch/rtdetrv2/misc/dist_utils.py

Lines changed: 12 additions & 0 deletions
@@ -228,6 +228,18 @@ def reduce_dict(data, avg=True):
     return {k: v for k, v in zip(keys, values)}
 
 
+def all_reduce(data, op=tdist.ReduceOp.SUM):
+    """
+    Run all_reduce on torch.Tensor data
+    Args:
+        data: torch.Tensor
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return
+    tdist.all_reduce(data, op=op)
+
+
 def all_gather(data):
     """
     Run all_gather on arbitrary picklable data (not necessarily tensors)
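For reference, a minimal usage sketch of the new helper (not part of the commit; the import path is an assumption, and a process group must already be initialized, e.g. via torchrun):

    import torch
    from rtdetrv2.misc import dist_utils  # assumed package path

    # Each rank contributes a scalar tensor; the underlying tdist.all_reduce
    # sums it in place, so afterwards every rank holds world_size * 1.0.
    value = torch.tensor(1.0)
    dist_utils.all_reduce(value)  # op defaults to tdist.ReduceOp.SUM

Because the helper returns early when get_world_size() == 1, the same call site works unchanged in single-process runs.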

rtdetrv2_pytorch/rtdetrv2/solver/det_engine.py

Lines changed: 9 additions & 3 deletions
@@ -61,12 +61,13 @@ def train_one_epoch(
         loss = sum(loss_dict.values())
         scaler.scale(loss).backward()
 
+        scaler.unscale_(optimizer)
         if max_norm > 0:
-            scaler.unscale_(optimizer)
             total_norm = torch.nn.utils.clip_grad_norm_(
                 model.parameters(), max_norm
             )
-            loss_dict["grad_norm"] = total_norm
+        else:
+            total_norm = torch.nn.utils.get_total_norm(model.parameters())
 
         scaler.step(optimizer)
         scaler.update()
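Standalone, the pattern this hunk converges on looks like the sketch below (a sketch under assumptions, not the repo's exact code: model, optimizer, scaler, max_norm, and loss stand in for the engine's locals, and torch.nn.utils.get_total_norm requires PyTorch 2.6+). One caveat worth flagging: get_total_norm norms exactly the tensors it receives, so a variant that measures the gradient norm rather than the parameter norm would pass the .grad tensors, as shown here:

    import torch

    scaler.scale(loss).backward()

    # Unscale before any inspection so the norms below are computed on the
    # true (unscaled) gradients, whether or not clipping is enabled.
    scaler.unscale_(optimizer)
    if max_norm > 0:
        # clip_grad_norm_ clips in place and returns the pre-clip total norm.
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    else:
        # Gradient-norm variant: the commit itself passes model.parameters(),
        # which norms the parameter values instead of the gradients.
        grads = [p.grad for p in model.parameters() if p.grad is not None]
        total_norm = torch.nn.utils.get_total_norm(grads)

    scaler.step(optimizer)
    scaler.update()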
@@ -84,7 +85,8 @@ def train_one_epoch(
             total_norm = torch.nn.utils.clip_grad_norm_(
                 model.parameters(), max_norm
             )
-            loss_dict["grad_norm"] = total_norm
+        else:
+            total_norm = torch.nn.utils.get_total_norm(model.parameters())
 
         optimizer.step()
 
@@ -103,6 +105,10 @@ def train_one_epoch(
             print(loss_dict_reduced)
             sys.exit(1)
 
+        # collect other values for logging
+        dist_utils.all_reduce(total_norm)
+        loss_dict_reduced["grad_norm"] = total_norm
+
         metric_logger.update(loss=loss_value, **loss_dict_reduced)
         metric_logger.update(lr=optimizer.param_groups[0]["lr"])
 
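A note on the reduction above: dist_utils.all_reduce defaults to ReduceOp.SUM, so the logged grad_norm is the sum of the per-rank total norms. A variant that logs the cross-rank mean instead would divide by the world size (a hedged sketch reusing the names from the hunk above):

    dist_utils.all_reduce(total_norm)           # in-place SUM across ranks
    total_norm /= dist_utils.get_world_size()   # convert the sum to a mean
    loss_dict_reduced["grad_norm"] = total_norm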