
Commit 98c0a11

rohitgr7 and awaelchli authored
Update docs for GradientAccumulationScheduler (#9891)
* update docs and add tests
* update docs and add tests
* Update pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
Co-authored-by: Rohit Gupta <[email protected]>
* Apply suggestions from code review
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent: 7eff003 · commit: 98c0a11

File tree

3 files changed: +130 −68 lines changed


pytorch_lightning/callbacks/gradient_accumulation_scheduler.py

Lines changed: 22 additions & 6 deletions
@@ -23,6 +23,7 @@
 from typing import Dict

 from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class GradientAccumulationScheduler(Callback):
@@ -32,6 +33,14 @@ class GradientAccumulationScheduler(Callback):
     Args:
         scheduling: scheduling in format {epoch: accumulation_factor}

+    Note:
+        The argument scheduling is a dictionary. Each key represent an epoch and
+        its associated accumulation factor value.
+        Warning: Epoch are zero-indexed c.f it means if you want to change
+        the accumulation factor after 4 epochs, set ``Trainer(accumulate_grad_batches={4: factor})``
+        or ``GradientAccumulationScheduler(scheduling={4: factor})``.
+        For more info check the example below.
+
     Raises:
         TypeError:
             If ``scheduling`` is an empty ``dict``,
@@ -44,12 +53,13 @@
         >>> from pytorch_lightning import Trainer
         >>> from pytorch_lightning.callbacks import GradientAccumulationScheduler

-        # at epoch 5 start accumulating every 2 batches
-        >>> accumulator = GradientAccumulationScheduler(scheduling={5: 2})
+        # from epoch 5, it starts accumulating every 2 batches. Here we have 4 instead of 5
+        # because epoch (key) should be zero-indexed.
+        >>> accumulator = GradientAccumulationScheduler(scheduling={4: 2})
         >>> trainer = Trainer(callbacks=[accumulator])

         # alternatively, pass the scheduling dict directly to the Trainer
-        >>> trainer = Trainer(accumulate_grad_batches={5: 2})
+        >>> trainer = Trainer(accumulate_grad_batches={4: 2})
     """

     def __init__(self, scheduling: Dict[int, int]):
@@ -58,9 +68,15 @@ def __init__(self, scheduling: Dict[int, int]):
         if not scheduling:  # empty dict error
             raise TypeError("Empty dict cannot be interpreted correct")

-        for key in scheduling:
-            if not isinstance(key, int) or not isinstance(scheduling[key], int):
-                raise TypeError("All epoches and accumulation factor must be integers")
+        if any(not isinstance(key, int) or key < 0 for key in scheduling):
+            raise MisconfigurationException(
+                f"Epoch should be an int greater than or equal to 0. Got {list(scheduling.keys())}."
+            )
+
+        if any(not isinstance(value, int) or value < 1 for value in scheduling.values()):
+            raise MisconfigurationException(
+                f"Accumulation factor should be an int greater than 0. Got {list(scheduling.values())}."
+            )

         minimal_epoch = min(scheduling.keys())
         if minimal_epoch < 0:
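For readers of the updated docstring: the scheduling dict is read as "from this zero-indexed epoch onward, accumulate over this many batches", with a default factor of 1 before the first scheduled epoch. A minimal sketch of that lookup (a hypothetical standalone helper, not the callback's actual code) could look like this:

    from typing import Dict


    def effective_accumulation_factor(scheduling: Dict[int, int], current_epoch: int) -> int:
        # Hypothetical helper: return the factor of the largest scheduled epoch that is
        # <= current_epoch, defaulting to 1 before the first scheduled epoch.
        factor = 1
        for epoch in sorted(scheduling):
            if epoch <= current_epoch:
                factor = scheduling[epoch]
        return factor


    # With scheduling={4: 2}: epochs 0-3 keep factor 1, epoch 4 onwards use factor 2.
    assert effective_accumulation_factor({4: 2}, current_epoch=3) == 1
    assert effective_accumulation_factor({4: 2}, current_epoch=4) == 2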
Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from unittest.mock import patch
+
+import pytest
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import GradientAccumulationScheduler
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.helpers import BoringModel
+
+
+@pytest.mark.parametrize("accumulate_grad_batches", (1, 2, 3))
+def test_trainer_accumulate_grad_batches_zero_grad(tmpdir, accumulate_grad_batches):
+    with patch("torch.optim.SGD.zero_grad") as sgd_zero_grad:
+        model = BoringModel()
+        trainer = Trainer(
+            default_root_dir=tmpdir,
+            limit_train_batches=20,
+            limit_val_batches=1,
+            max_epochs=1,
+            weights_summary=None,
+            accumulate_grad_batches=accumulate_grad_batches,
+        )
+        assert trainer.accumulate_grad_batches == accumulate_grad_batches
+        trainer.fit(model)
+
+    assert sum(isinstance(cb, GradientAccumulationScheduler) for cb in trainer.callbacks) == 1
+    assert sgd_zero_grad.call_count == math.ceil(trainer.limit_train_batches / accumulate_grad_batches)
+
+
+@pytest.mark.parametrize(
+    ["accumulate_grad_batches", "expected_call_count"],
+    [
+        ({1: 2, 3: 4}, 10 + 5 + 5 + 3),
+        ({0: 2, 2: 1}, 5 + 5 + 10 + 10),
+    ],
+)
+def test_trainer_accumulate_grad_batches_dict_zero_grad(tmpdir, accumulate_grad_batches, expected_call_count):
+    with patch("torch.optim.SGD.zero_grad") as sgd_zero_grad:
+        model = BoringModel()
+        trainer = Trainer(
+            default_root_dir=tmpdir,
+            limit_train_batches=10,
+            limit_val_batches=1,
+            max_epochs=4,
+            weights_summary=None,
+            accumulate_grad_batches=accumulate_grad_batches,
+        )
+        assert trainer.accumulate_grad_batches == accumulate_grad_batches.get(0, 1)
+        trainer.fit(model)
+
+    assert sum(isinstance(cb, GradientAccumulationScheduler) for cb in trainer.callbacks) == 1
+    assert sgd_zero_grad.call_count == expected_call_count
+
+
+def test_trainer_accumulate_grad_batches_with_callback(tmpdir):
+    with patch("torch.optim.SGD.zero_grad") as sgd_zero_grad:
+        model = BoringModel()
+        trainer = Trainer(
+            default_root_dir=tmpdir,
+            limit_train_batches=10,
+            limit_val_batches=1,
+            max_epochs=4,
+            weights_summary=None,
+            callbacks=[GradientAccumulationScheduler({1: 2, 3: 4})],
+        )
+        assert trainer.accumulate_grad_batches == 1
+        trainer.fit(model)
+
+    assert sum(isinstance(cb, GradientAccumulationScheduler) for cb in trainer.callbacks) == 1
+    assert sgd_zero_grad.call_count == 10 + 5 + 5 + 3
+
+
+@pytest.mark.parametrize(
+    "scheduling",
+    [
+        {1: 2, -3: 4},
+        {0: 2, "2": 1},
+    ],
+)
+def test_invalid_keys_for_grad_accum_scheduler(scheduling):
+    with pytest.raises(MisconfigurationException, match="Epoch should be an int"):
+        _ = GradientAccumulationScheduler(scheduling=scheduling)
+
+
+@pytest.mark.parametrize(
+    "scheduling",
+    [
+        {1: 0, 3: 4},
+        {0: 2, 2: "2"},
+    ],
+)
+def test_invalid_values_for_grad_accum_scheduler(scheduling):
+    with pytest.raises(MisconfigurationException, match="Accumulation factor should be an int"):
+        _ = GradientAccumulationScheduler(scheduling=scheduling)
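The expected call counts in test_trainer_accumulate_grad_batches_dict_zero_grad follow from ceil(limit_train_batches / factor) per epoch: with limit_train_batches=10, max_epochs=4 and a schedule of {1: 2, 3: 4}, epoch 0 keeps the default factor 1 (10 calls), epochs 1-2 use factor 2 (5 calls each), and epoch 3 uses factor 4 (ceil(10 / 4) = 3 calls), giving 10 + 5 + 5 + 3. A small standalone sketch of that arithmetic (an illustrative helper, not part of the test suite) could be:

    import math
    from typing import Dict


    def expected_zero_grad_calls(scheduling: Dict[int, int], limit_train_batches: int, max_epochs: int) -> int:
        # Sum the optimizer.zero_grad() calls over all epochs for a zero-indexed
        # {epoch: accumulation_factor} schedule, starting from a default factor of 1.
        total = 0
        factor = 1
        for epoch in range(max_epochs):
            factor = scheduling.get(epoch, factor)
            total += math.ceil(limit_train_batches / factor)
        return total


    assert expected_zero_grad_calls({1: 2, 3: 4}, limit_train_batches=10, max_epochs=4) == 10 + 5 + 5 + 3
    assert expected_zero_grad_calls({0: 2, 2: 1}, limit_train_batches=10, max_epochs=4) == 5 + 5 + 10 + 10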

tests/trainer/test_trainer.py

Lines changed: 0 additions & 62 deletions
@@ -174,68 +174,6 @@ def test_strict_model_load(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
     assert not failed, "Model should be loaded due to strict=False."


-@pytest.mark.parametrize("accumulate_grad_batches", (1, 2, 3))
-def test_trainer_accumulate_grad_batches_zero_grad(tmpdir, accumulate_grad_batches):
-    with patch("torch.optim.SGD.zero_grad") as sgd_zero_grad:
-        model = BoringModel()
-        trainer = Trainer(
-            default_root_dir=tmpdir,
-            limit_train_batches=20,
-            limit_val_batches=1,
-            max_epochs=1,
-            weights_summary=None,
-            accumulate_grad_batches=accumulate_grad_batches,
-        )
-        assert trainer.accumulate_grad_batches == accumulate_grad_batches
-        trainer.fit(model)
-
-    assert sum(isinstance(cb, GradientAccumulationScheduler) for cb in trainer.callbacks) == 1
-    assert sgd_zero_grad.call_count == math.ceil(trainer.limit_train_batches / accumulate_grad_batches)
-
-
-@pytest.mark.parametrize(
-    ["accumulate_grad_batches", "expected_call_count"],
-    [
-        ({1: 2, 3: 4}, 10 + 5 + 5 + 3),
-        ({0: 2, 2: 1}, 5 + 5 + 10 + 10),
-    ],
-)
-def test_trainer_accumulate_grad_batches_dict_zero_grad(tmpdir, accumulate_grad_batches, expected_call_count):
-    with patch("torch.optim.SGD.zero_grad") as sgd_zero_grad:
-        model = BoringModel()
-        trainer = Trainer(
-            default_root_dir=tmpdir,
-            limit_train_batches=10,
-            limit_val_batches=1,
-            max_epochs=4,
-            weights_summary=None,
-            accumulate_grad_batches=accumulate_grad_batches,
-        )
-        assert trainer.accumulate_grad_batches == accumulate_grad_batches.get(0, 1)
-        trainer.fit(model)
-
-    assert sum(isinstance(cb, GradientAccumulationScheduler) for cb in trainer.callbacks) == 1
-    assert sgd_zero_grad.call_count == expected_call_count
-
-
-def test_trainer_accumulate_grad_batches_with_callback(tmpdir):
-    with patch("torch.optim.SGD.zero_grad") as sgd_zero_grad:
-        model = BoringModel()
-        trainer = Trainer(
-            default_root_dir=tmpdir,
-            limit_train_batches=10,
-            limit_val_batches=1,
-            max_epochs=4,
-            weights_summary=None,
-            callbacks=[GradientAccumulationScheduler({1: 2, 3: 4})],
-        )
-        assert trainer.accumulate_grad_batches == 1
-        trainer.fit(model)
-
-    assert sum(isinstance(cb, GradientAccumulationScheduler) for cb in trainer.callbacks) == 1
-    assert sgd_zero_grad.call_count == 10 + 5 + 5 + 3
-
-
 def test_trainer_accumulate_grad_batches_incorrect_value(tmpdir):
     with pytest.raises(MisconfigurationException, match=".*should be an int or a dict.*"):
         Trainer(default_root_dir=tmpdir, accumulate_grad_batches=(2, 5))
