
Commit a2a998a

ppwwyyxx authored and facebook-github-bot committed
add a convenient scheduler just for warmup
Summary: it's a special case of CompositeParamScheduler

Reviewed By: theschnitz

Differential Revision: D26246409

fbshipit-source-id: afb5d49d99c8be237c59a464c27184763b4db150
1 parent 4506882 commit a2a998a

File tree

4 files changed: +65 −53 lines


detectron2/solver/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 from .build import build_lr_scheduler, build_optimizer, get_default_optimizer_params
-from .lr_scheduler import WarmupCosineLR, WarmupMultiStepLR, LRMultiplier
+from .lr_scheduler import WarmupCosineLR, WarmupMultiStepLR, LRMultiplier, WarmupParamScheduler

 __all__ = [k for k in globals().keys() if not k.startswith("_")]
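With this export, WarmupParamScheduler becomes part of the public detectron2.solver API. A minimal usage sketch (the wrapped scheduler and values here are illustrative, not taken from this commit):

from fvcore.common.param_scheduler import CosineParamScheduler
from detectron2.solver import WarmupParamScheduler

# illustrative: a 1% linear warmup in front of a cosine decay from 1 to 0
sched = WarmupParamScheduler(
    CosineParamScheduler(1.0, 0.0),
    warmup_factor=0.001,
    warmup_length=0.01,
)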

detectron2/solver/build.py

Lines changed: 7 additions & 21 deletions
@@ -3,17 +3,11 @@
 from enum import Enum
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
 import torch
-from fvcore.common.param_scheduler import (
-    CompositeParamScheduler,
-    ConstantParamScheduler,
-    CosineParamScheduler,
-    LinearParamScheduler,
-    MultiStepParamScheduler,
-)
+from fvcore.common.param_scheduler import CosineParamScheduler, MultiStepParamScheduler

 from detectron2.config import CfgNode

-from .lr_scheduler import LRMultiplier
+from .lr_scheduler import LRMultiplier, WarmupParamScheduler

 _GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
 _GradientClipper = Callable[[_GradientClipperInput], None]
@@ -221,18 +215,10 @@ def build_lr_scheduler(
     else:
         raise ValueError("Unknown LR scheduler: {}".format(name))

-    # Add warmup
-    warmup_method = cfg.SOLVER.WARMUP_METHOD
-    if warmup_method == "constant":
-        warmup = ConstantParamScheduler(cfg.SOLVER.WARMUP_FACTOR)
-    elif warmup_method == "linear":
-        warmup = LinearParamScheduler(cfg.SOLVER.WARMUP_FACTOR, 1.0)
-    else:
-        raise ValueError("Unknown warmup method: {}".format(warmup_method))
-    warmup_ratio = cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER
-    sched = CompositeParamScheduler(
-        [warmup, sched],
-        interval_scaling=["rescaled", "fixed"],
-        lengths=[warmup_ratio, 1 - warmup_ratio],
+    sched = WarmupParamScheduler(
+        sched,
+        cfg.SOLVER.WARMUP_FACTOR,
+        cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER,
+        cfg.SOLVER.WARMUP_METHOD,
     )
     return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER)
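The new call reads the same config keys the removed block did. A sketch with typical values (assumed illustrative defaults, not quoted from this diff):

# hypothetical cfg values, shown only to illustrate the arguments above
cfg.SOLVER.WARMUP_FACTOR = 0.001     # warmup starts at 0.1% of the schedule's initial value
cfg.SOLVER.WARMUP_ITERS = 1000       # so warmup_length = 1000 / 90000 of training
cfg.SOLVER.MAX_ITER = 90000
cfg.SOLVER.WARMUP_METHOD = "linear"  # or "constant"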

detectron2/solver/lr_scheduler.py

Lines changed: 43 additions & 6 deletions
@@ -4,11 +4,51 @@
 from bisect import bisect_right
 from typing import List
 import torch
-from fvcore.common.param_scheduler import ParamScheduler
+from fvcore.common.param_scheduler import (
+    CompositeParamScheduler,
+    ConstantParamScheduler,
+    LinearParamScheduler,
+    ParamScheduler,
+)

 logger = logging.getLogger(__name__)


+class WarmupParamScheduler(CompositeParamScheduler):
+    """
+    Add an initial warmup stage to another scheduler.
+    """
+
+    def __init__(
+        self,
+        scheduler: ParamScheduler,
+        warmup_factor: float,
+        warmup_length: float,
+        warmup_method: str = "linear",
+    ):
+        """
+        Args:
+            scheduler: warmup will be added at the beginning of this scheduler
+            warmup_factor: the factor w.r.t the initial value of ``scheduler``, e.g. 0.001
+            warmup_length: the relative length (in [0, 1]) of warmup steps w.r.t the entire
+                training, e.g. 0.01
+            warmup_method: one of "linear" or "constant"
+        """
+        end_value = scheduler(warmup_length)  # the value to reach when warmup ends
+        start_value = warmup_factor * scheduler(0.0)
+        if warmup_method == "constant":
+            warmup = ConstantParamScheduler(start_value)
+        elif warmup_method == "linear":
+            warmup = LinearParamScheduler(start_value, end_value)
+        else:
+            raise ValueError("Unknown warmup method: {}".format(warmup_method))
+        super().__init__(
+            [warmup, scheduler],
+            interval_scaling=["rescaled", "fixed"],
+            lengths=[warmup_length, 1 - warmup_length],
+        )
+
+
 class LRMultiplier(torch.optim.lr_scheduler._LRScheduler):
     """
     A LRScheduler which uses fvcore :class:`ParamScheduler` to multiply the
@@ -25,15 +65,12 @@ class LRMultiplier(torch.optim.lr_scheduler._LRScheduler):
     ::
         LRMultiplier(
             opt,
-            CompositeParamScheduler([
-                LinearParamScheduler(0.001, 1),  # warmup
+            WarmupParamScheduler(
                 MultiStepParamScheduler(
                     [1, 0.1, 0.01],
                     milestones=[60000, 80000],
                     num_updates=90000,
-                )],
-                interval_scaling=["rescaled", "fixed"],
-                lengths=[100 / 90000, 89900 / 90000],
+                ), 0.001, 100 / 90000
             ),
             max_iter=90000
         )
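As a quick sanity check of the class above (a sketch assuming fvcore and detectron2 are installed; the printed values follow from the definition, which ramps a linear warmup from warmup_factor * scheduler(0.0) up to scheduler(warmup_length) over the first warmup_length fraction of training):

from fvcore.common.param_scheduler import MultiStepParamScheduler
from detectron2.solver import WarmupParamScheduler

sched = WarmupParamScheduler(
    MultiStepParamScheduler([1, 0.1, 0.01], milestones=[60000, 80000], num_updates=90000),
    warmup_factor=0.001,
    warmup_length=100 / 90000,  # 100 warmup steps out of 90k total
)
print(sched(0.0))            # 0.001 = warmup_factor * scheduler(0.0)
print(sched(100 / 90000))    # 1.0, the value the warmup ramps up to
print(sched(70000 / 90000))  # 0.1, the wrapped multi-step value past the first milestone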

tests/test_scheduler.py

Lines changed: 14 additions & 25 deletions
@@ -4,33 +4,25 @@
 import numpy as np
 from unittest import TestCase
 import torch
-from fvcore.common.param_scheduler import (
-    CompositeParamScheduler,
-    CosineParamScheduler,
-    LinearParamScheduler,
-    MultiStepParamScheduler,
-)
+from fvcore.common.param_scheduler import CosineParamScheduler, MultiStepParamScheduler
 from torch import nn

-from detectron2.solver import LRMultiplier
+from detectron2.solver import LRMultiplier, WarmupParamScheduler


 class TestScheduler(TestCase):
     def test_warmup_multistep(self):
         p = nn.Parameter(torch.zeros(0))
         opt = torch.optim.SGD([p], lr=5)

-        multiplier = CompositeParamScheduler(
-            [
-                LinearParamScheduler(0.001, 1),  # warmup
-                MultiStepParamScheduler(
-                    [1, 0.1, 0.01, 0.001],
-                    milestones=[10, 15, 20],
-                    num_updates=30,
-                ),
-            ],
-            interval_scaling=["rescaled", "fixed"],
-            lengths=[5 / 30, 25 / 30],
+        multiplier = WarmupParamScheduler(
+            MultiStepParamScheduler(
+                [1, 0.1, 0.01, 0.001],
+                milestones=[10, 15, 20],
+                num_updates=30,
+            ),
+            0.001,
+            5 / 30,
         )
         sched = LRMultiplier(opt, multiplier, 30)
         # This is an equivalent of:
@@ -53,13 +45,10 @@ def test_warmup_multistep(self):
     def test_warmup_cosine(self):
         p = nn.Parameter(torch.zeros(0))
         opt = torch.optim.SGD([p], lr=5)
-        multiplier = CompositeParamScheduler(
-            [
-                LinearParamScheduler(0.001, 1),  # warmup
-                CosineParamScheduler(1, 0),
-            ],
-            interval_scaling=["rescaled", "fixed"],
-            lengths=[5 / 30, 25 / 30],
+        multiplier = WarmupParamScheduler(
+            CosineParamScheduler(1, 0),
+            0.001,
+            5 / 30,
        )
        sched = LRMultiplier(opt, multiplier, 30)
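The assertions of test_warmup_multistep are elided in this diff; for intuition, here is a hedged, self-contained reconstruction of the learning rates that test should observe (values computed from the schedule above with base lr 5, not quoted from the file):

import numpy as np
import torch
from torch import nn
from fvcore.common.param_scheduler import MultiStepParamScheduler
from detectron2.solver import LRMultiplier, WarmupParamScheduler

p = nn.Parameter(torch.zeros(0))
opt = torch.optim.SGD([p], lr=5)
multiplier = WarmupParamScheduler(
    MultiStepParamScheduler([1, 0.1, 0.01, 0.001], milestones=[10, 15, 20], num_updates=30),
    0.001,
    5 / 30,
)
sched = LRMultiplier(opt, multiplier, 30)
lrs = []
for _ in range(30):
    lrs.append(opt.param_groups[0]["lr"])
    opt.step()
    sched.step()
# linear warmup from 5 * 0.001 toward 5 over the first 5 iterations ...
assert np.allclose(lrs[:5], [0.005, 1.004, 2.003, 3.002, 4.001])
# ... then the multi-step decay: 5, 0.5, 0.05, 0.005 at the milestones
assert np.allclose(lrs[5:10], 5.0)
assert np.allclose(lrs[10:15], 0.5)
assert np.allclose(lrs[15:20], 0.05)
assert np.allclose(lrs[20:], 0.005)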
