
Commit 4b539e4

ppwwyyxx authored and facebook-github-bot committed
Use ParamScheduler from fvcore
Summary: Pull Request resolved: #2585. Can replace our existing LR schedulers.
Reviewed By: theschnitz
Differential Revision: D26220618
fbshipit-source-id: e3fd7a4427bcd3506554292764bb362a39618a9f
1 parent 81b9cad commit 4b539e4

File tree: 7 files changed, +204 −19 lines


detectron2/engine/hooks.py

Lines changed: 4 additions & 1 deletion
@@ -206,7 +206,10 @@ class LRScheduler(HookBase):
     def __init__(self, optimizer=None, scheduler=None):
         """
         Args:
-            No args needed. Will obtain optimizer and scheduler from trainer.
+            optimizer (torch.optim.Optimizer):
+            scheduler (torch.optim.LRScheduler):
+
+            If any argument is not given, will try to obtain it from the trainer.
         """
         self._optimizer = optimizer
         self._scheduler = scheduler
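
With this change the hook can take both objects explicitly instead of always pulling them from the trainer. A minimal sketch (not part of the commit; the throwaway parameter, optimizer, and stock PyTorch scheduler below are illustrative only):

import torch
from detectron2.engine.hooks import LRScheduler

# Throwaway parameter/optimizer plus a stock PyTorch scheduler,
# just to exercise the new constructor signature.
p = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([p], lr=0.01)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60000, 80000])

# Both objects can now be passed explicitly...
hook = LRScheduler(optimizer=optimizer, scheduler=scheduler)

# ...or omitted, in which case the hook still obtains them from the trainer,
# as the updated docstring says.
hook = LRScheduler()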

detectron2/solver/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 from .build import build_lr_scheduler, build_optimizer, get_default_optimizer_params
-from .lr_scheduler import WarmupCosineLR, WarmupMultiStepLR
+from .lr_scheduler import WarmupCosineLR, WarmupMultiStepLR, LRMultiplier

 __all__ = [k for k in globals().keys() if not k.startswith("_")]

detectron2/solver/build.py

Lines changed: 31 additions & 16 deletions
@@ -3,10 +3,17 @@
 from enum import Enum
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
 import torch
+from fvcore.common.param_scheduler import (
+    CompositeParamScheduler,
+    ConstantParamScheduler,
+    CosineParamScheduler,
+    LinearParamScheduler,
+    MultiStepParamScheduler,
+)

 from detectron2.config import CfgNode

-from .lr_scheduler import WarmupCosineLR, WarmupMultiStepLR
+from .lr_scheduler import LRMultiplier

 _GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
 _GradientClipper = Callable[[_GradientClipperInput], None]

@@ -41,7 +48,7 @@ def _generate_optimizer_class_with_gradient_clipping(
     optimizer: Type[torch.optim.Optimizer],
     *,
     per_param_clipper: Optional[_GradientClipper] = None,
-    global_clipper: Optional[_GradientClipper] = None
+    global_clipper: Optional[_GradientClipper] = None,
 ) -> Type[torch.optim.Optimizer]:
     """
     Dynamically creates a new type that inherits the type of a given instance

@@ -202,22 +209,30 @@ def build_lr_scheduler(
     Build a LR scheduler from config.
     """
     name = cfg.SOLVER.LR_SCHEDULER_NAME
+
     if name == "WarmupMultiStepLR":
-        return WarmupMultiStepLR(
-            optimizer,
-            cfg.SOLVER.STEPS,
-            cfg.SOLVER.GAMMA,
-            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
-            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
-            warmup_method=cfg.SOLVER.WARMUP_METHOD,
+        sched = MultiStepParamScheduler(
+            values=[cfg.SOLVER.GAMMA ** k for k in range(len(cfg.SOLVER.STEPS) + 1)],
+            milestones=cfg.SOLVER.STEPS,
+            num_updates=cfg.SOLVER.MAX_ITER,
         )
     elif name == "WarmupCosineLR":
-        return WarmupCosineLR(
-            optimizer,
-            cfg.SOLVER.MAX_ITER,
-            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
-            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
-            warmup_method=cfg.SOLVER.WARMUP_METHOD,
-        )
+        sched = CosineParamScheduler(1, 0)
     else:
         raise ValueError("Unknown LR scheduler: {}".format(name))
+
+    # Add warmup
+    warmup_method = cfg.SOLVER.WARMUP_METHOD
+    if warmup_method == "constant":
+        warmup = ConstantParamScheduler(cfg.SOLVER.WARMUP_FACTOR)
+    elif warmup_method == "linear":
+        warmup = LinearParamScheduler(cfg.SOLVER.WARMUP_FACTOR, 1.0)
+    else:
+        raise ValueError("Unknown warmup method: {}".format(warmup_method))
+    warmup_ratio = cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER
+    sched = CompositeParamScheduler(
+        [warmup, sched],
+        interval_scaling=["rescaled", "fixed"],
+        lengths=[warmup_ratio, 1 - warmup_ratio],
+    )
+    return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER)
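
For intuition (a minimal sketch, not part of the commit), the multiplier that build_lr_scheduler now composes can be evaluated by hand. The values below (max_iter=90000, STEPS=(60000, 80000), GAMMA=0.1, WARMUP_ITERS=1000, WARMUP_FACTOR=0.001, linear warmup) are assumed illustrative defaults, not read from this diff:

from fvcore.common.param_scheduler import (
    CompositeParamScheduler,
    LinearParamScheduler,
    MultiStepParamScheduler,
)

# Assumed config values for illustration only.
max_iter = 90000
warmup_iters = 1000
warmup_ratio = warmup_iters / max_iter

sched = CompositeParamScheduler(
    [
        LinearParamScheduler(0.001, 1.0),  # linear warmup, as in the "linear" branch above
        MultiStepParamScheduler(
            values=[1.0, 0.1, 0.01],  # GAMMA ** k for k = 0, 1, 2 with two STEPS
            milestones=[60000, 80000],
            num_updates=max_iter,
        ),
    ],
    interval_scaling=["rescaled", "fixed"],
    lengths=[warmup_ratio, 1 - warmup_ratio],
)

# A ParamScheduler maps training progress in [0, 1) to an LR multiplier.
for it in [0, 500, 2000, 30000, 70000, 85000]:
    print(it, sched(it / max_iter))
# Roughly: 0.001 at iter 0, ~0.5 mid-warmup, 1.0 once warmup ends,
# then 0.1 past 60k and 0.01 past 80k.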

detectron2/solver/lr_scheduler.py

Lines changed: 86 additions & 0 deletions
@@ -1,8 +1,88 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
+import logging
 import math
 from bisect import bisect_right
 from typing import List
 import torch
+from fvcore.common.param_scheduler import ParamScheduler
+
+logger = logging.getLogger(__name__)
+
+
+class LRMultiplier(torch.optim.lr_scheduler._LRScheduler):
+    """
+    A LRScheduler which uses fvcore :class:`ParamScheduler` to multiply the
+    learning rate of each param in the optimizer.
+    Every step, the learning rate of each parameter becomes its initial value
+    multiplied by the output of the given :class:`ParamScheduler`.
+
+    The absolute learning rate value of each parameter can be different.
+    This scheduler can be used as long as the relative scale among them does
+    not change during training.
+
+    Examples:
+
+    ::
+        LRMultiplier(
+            opt,
+            CompositeParamScheduler([
+                LinearParamScheduler(0.001, 1),  # warmup
+                MultiStepParamScheduler(
+                    [1, 0.1, 0.01],
+                    milestones=[60000, 80000],
+                    num_updates=90000,
+                )],
+                interval_scaling=["rescaled", "fixed"],
+                lengths=[100 / 90000, 89900 / 90000],
+            ),
+            max_iter=90000
+        )
+    """
+
+    # NOTES: in the most general case, every LR can use its own scheduler.
+    # Supporting this requires interaction with the optimizer when its parameter
+    # group is initialized. For example, classyvision implements its own optimizer
+    # that allows different schedulers for every parameter group.
+    # To avoid this complexity, we use this class to support the most common cases
+    # where the relative scale among all LRs stays unchanged during training. In this
+    # case we only need a total of one scheduler that defines the relative LR multiplier.
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        multiplier: ParamScheduler,
+        max_iter: int,
+        last_iter: int = -1,
+    ):
+        """
+        Args:
+            optimizer, last_iter: See ``torch.optim.lr_scheduler._LRScheduler``.
+                ``last_iter`` is the same as ``last_epoch``.
+            multiplier: a fvcore ParamScheduler that defines the multiplier on
+                every LR of the optimizer.
+            max_iter: the total number of training iterations.
+        """
+        if not isinstance(multiplier, ParamScheduler):
+            raise ValueError(
+                "_LRMultiplier(multiplier=) must be an instance of fvcore "
+                f"ParamScheduler. Got {multiplier} instead."
+            )
+        self._multiplier = multiplier
+        self._max_iter = max_iter
+        super().__init__(optimizer, last_epoch=last_iter)
+
+    def state_dict(self):
+        # fvcore schedulers are stateless. Only keep pytorch scheduler states
+        return {"base_lrs": self.base_lrs, "last_epoch": self.last_epoch}
+
+    def get_lr(self) -> List[float]:
+        multiplier = self._multiplier(self.last_epoch / self._max_iter)
+        return [base_lr * multiplier for base_lr in self.base_lrs]
+
+
+"""
+Content below is no longer needed!
+"""

 # NOTE: PyTorch's LR scheduler interface uses names that assume the LR changes
 # only on epoch boundaries. We typically use iteration based schedules instead.

@@ -24,6 +104,9 @@ def __init__(
         warmup_method: str = "linear",
         last_epoch: int = -1,
     ):
+        logger.warning(
+            "WarmupMultiStepLR is deprecated! Use LRMultiplier with fvcore ParamScheduler instead!"
+        )
         if not list(milestones) == sorted(milestones):
             raise ValueError(
                 "Milestones should be a list of" " increasing integers. Got {}", milestones

@@ -59,6 +142,9 @@ def __init__(
         warmup_method: str = "linear",
         last_epoch: int = -1,
     ):
+        logger.warning(
+            "WarmupCosineLR is deprecated! Use LRMultiplier with fvcore ParamScheduler instead!"
+        )
        self.max_iters = max_iters
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
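
A quick usage sketch of the new class (not part of the commit; the throwaway parameter, base LR, and milestones are arbitrary). LRMultiplier steps like any PyTorch scheduler, and the multiplier it applies is read from the fvcore scheduler at last_epoch / max_iter:

import torch
from fvcore.common.param_scheduler import MultiStepParamScheduler

from detectron2.solver import LRMultiplier

# Throwaway parameter and optimizer; 0.02 is the base LR the multiplier scales.
p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([p], lr=0.02)

# Multiplier drops 10x at iterations 60 and 80 of a 100-iteration run.
multiplier = MultiStepParamScheduler(
    values=[1.0, 0.1, 0.01], milestones=[60, 80], num_updates=100
)
sched = LRMultiplier(opt, multiplier=multiplier, max_iter=100)

for _ in range(100):
    p.sum().backward()
    opt.step()
    sched.step()  # next LR = 0.02 * multiplier(last_epoch / max_iter)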

docs/conf.py

Lines changed: 2 additions & 0 deletions
@@ -277,6 +277,8 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
         "StandardAugInput",
         "build_batch_data_loader",
         "draw_panoptic_seg_predictions",
+        "WarmupCosineLR",
+        "WarmupMultiStepLR",
     }
     try:
         if name in HIDDEN or (

setup.py

Lines changed: 1 addition & 1 deletion
@@ -211,7 +211,7 @@ def get_model_zoo_configs() -> List[str]:
         "matplotlib",
         "tqdm>4.29.0",
         "tensorboard",
-        "fvcore>=0.1.2,<0.1.3",  # required like this to make it pip installable
+        "fvcore>=0.1.3,<0.1.4",  # required like this to make it pip installable
         "iopath>=0.1.2",
         "pycocotools>=2.0.2",  # corresponds to https://github.com/ppwwyyxx/cocoapi
         "future",  # used by caffe2

tests/test_scheduler.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import math
+import numpy as np
+from unittest import TestCase
+import torch
+from fvcore.common.param_scheduler import (
+    CompositeParamScheduler,
+    CosineParamScheduler,
+    LinearParamScheduler,
+    MultiStepParamScheduler,
+)
+from torch import nn
+
+from detectron2.solver import LRMultiplier
+
+
+class TestScheduler(TestCase):
+    def test_warmup_multistep(self):
+        p = nn.Parameter(torch.zeros(0))
+        opt = torch.optim.SGD([p], lr=5)
+
+        multiplier = CompositeParamScheduler(
+            [
+                LinearParamScheduler(0.001, 1),  # warmup
+                MultiStepParamScheduler(
+                    [1, 0.1, 0.01, 0.001],
+                    milestones=[10, 15, 20],
+                    num_updates=30,
+                ),
+            ],
+            interval_scaling=["rescaled", "fixed"],
+            lengths=[5 / 30, 25 / 30],
+        )
+        sched = LRMultiplier(opt, multiplier, 30)
+        # This is an equivalent of:
+        # sched = WarmupMultiStepLR(
+        #     opt, milestones=[10, 15, 20], gamma=0.1, warmup_factor=0.001, warmup_iters=5)
+
+        p.sum().backward()
+        opt.step()
+
+        lrs = [0.005]
+        for _ in range(30):
+            sched.step()
+            lrs.append(opt.param_groups[0]["lr"])
+        self.assertTrue(np.allclose(lrs[:5], [0.005, 1.004, 2.003, 3.002, 4.001]))
+        self.assertTrue(np.allclose(lrs[5:10], 5.0))
+        self.assertTrue(np.allclose(lrs[10:15], 0.5))
+        self.assertTrue(np.allclose(lrs[15:20], 0.05))
+        self.assertTrue(np.allclose(lrs[20:], 0.005))
+
+    def test_warmup_cosine(self):
+        p = nn.Parameter(torch.zeros(0))
+        opt = torch.optim.SGD([p], lr=5)
+        multiplier = CompositeParamScheduler(
+            [
+                LinearParamScheduler(0.001, 1),  # warmup
+                CosineParamScheduler(1, 0),
+            ],
+            interval_scaling=["rescaled", "fixed"],
+            lengths=[5 / 30, 25 / 30],
+        )
+        sched = LRMultiplier(opt, multiplier, 30)
+
+        p.sum().backward()
+        opt.step()
+        self.assertEqual(opt.param_groups[0]["lr"], 0.005)
+        lrs = [0.005]
+
+        for _ in range(30):
+            sched.step()
+            lrs.append(opt.param_groups[0]["lr"])
+        for idx, lr in enumerate(lrs):
+            expected_cosine = 2.5 * (1.0 + math.cos(math.pi * idx / 30))
+            if idx >= 5:
+                self.assertAlmostEqual(lr, expected_cosine)
+            else:
+                self.assertNotAlmostEqual(lr, expected_cosine)
