Skip to content

Commit 592128b

Browse files
ppwwyyxxfacebook-github-bot
authored andcommitted
Add find_unused_parameters
Summary: can be called when DDP fails to check what params cause the failure Differential Revision: D31463117 fbshipit-source-id: c12b6c82f16916d4ff9175f2d93d951d4f313350
1 parent 0e6702d commit 592128b

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

detectron2/utils/analysis.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# -*- coding: utf-8 -*-
33

44
import typing
5+
from typing import Any, List
56
import fvcore
67
from fvcore.nn import activation_count, flop_count, parameter_count, parameter_count_table
78
from torch import nn
@@ -151,3 +152,36 @@ def _wrapper_count_operators(
151152
ret = ret[0]
152153
model.train(old_train)
153154
return ret
155+
156+
157+
def find_unused_parameters(model: nn.Module, inputs: Any) -> List[str]:
158+
"""
159+
Given a model, find parameters that do not contribute
160+
to the loss.
161+
162+
Args:
163+
model: a model in training mode that returns losses
164+
inputs: argument or a tuple of arguments. Inputs of the model
165+
166+
Returns:
167+
list[str]: the name of unused parameters
168+
"""
169+
assert model.training
170+
for _, prm in model.named_parameters():
171+
prm.grad = None
172+
173+
if isinstance(inputs, tuple):
174+
losses = model(*inputs)
175+
else:
176+
losses = model(inputs)
177+
178+
if isinstance(losses, dict):
179+
losses = sum(losses.values())
180+
losses.backward()
181+
182+
unused: List[str] = []
183+
for name, prm in model.named_parameters():
184+
if prm.grad is None:
185+
unused.append(name)
186+
prm.grad = None
187+
return unused

tests/test_model_analysis.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33

44
import unittest
55
import torch
6+
from torch import nn
67

7-
from detectron2.utils.analysis import flop_count_operators, parameter_count
8+
from detectron2.utils.analysis import find_unused_parameters, flop_count_operators, parameter_count
89
from detectron2.utils.testing import get_model_no_weights
910

1011

@@ -42,3 +43,19 @@ def test_param_count(self):
4243
res = parameter_count(self.model)
4344
self.assertTrue(res[""], 41699936)
4445
self.assertTrue(res["backbone"], 26799296)
46+
47+
48+
class UnusedParamTest(unittest.TestCase):
49+
def test_unused(self):
50+
class TestMod(nn.Module):
51+
def __init__(self):
52+
super().__init__()
53+
self.fc1 = nn.Linear(10, 10)
54+
self.t = nn.Linear(10, 10)
55+
56+
def forward(self, x):
57+
return self.fc1(x).mean()
58+
59+
m = TestMod()
60+
ret = find_unused_parameters(m, torch.randn(10, 10))
61+
self.assertEqual(set(ret), {"t.weight", "t.bias"})

0 commit comments

Comments
 (0)