
Commit 2acff1c

carmocca authored and lexierule committed
Avoid changing the current cudnn.benchmark value (#13154)
1 parent: 3c06cd8 · commit: 2acff1c

File tree: 5 files changed (+50, -28 lines)

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -10,6 +10,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added

 - Added all DDP params to be exposed through hpu parallel strategy ([#13067](https://github.com/PyTorchLightning/pytorch-lightning/pull/13067))
+
+### Changed
+
+- Keep `torch.backends.cudnn.benchmark=False` by default (unlike in v1.6.{0-4}) after speed and memory problems depending on the data used. Please consider tuning `Trainer(benchmark)` manually. ([#13154](https://github.com/PyTorchLightning/pytorch-lightning/pull/13154))
+- Prevent modification of `torch.backends.cudnn.benchmark` when `Trainer(benchmark=...)` is not set ([#13154](https://github.com/PyTorchLightning/pytorch-lightning/pull/13154))
+
 ### Fixed

 - Fixed an issue causing zero-division error for empty dataloaders ([#12885](https://github.com/PyTorchLightning/pytorch-lightning/pull/12885))
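
In practice, the two entries above amount to the following observable change. A minimal sketch, assuming a fresh process in which nothing else has touched the flag:

    import torch
    from pytorch_lightning import Trainer

    # PyTorch's own default for the flag is False.
    assert torch.backends.cudnn.benchmark is False

    # benchmark=None (the default) now leaves the session value untouched;
    # in v1.6.{0-4} this constructor call would have flipped it to True.
    Trainer(benchmark=None)
    assert torch.backends.cudnn.benchmark is False

    # An explicit value is still applied globally.
    Trainer(benchmark=True)
    assert torch.backends.cudnn.benchmark is True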

docs/source/common/trainer.rst

Lines changed: 11 additions & 8 deletions
@@ -437,21 +437,24 @@ benchmark

 |

-Defaults to ``True`` if :paramref:`~pytorch_lightning.trainer.Trainer.deterministic` is not set.
-This flag sets the ``torch.backends.cudnn.deterministic`` flag. You can read more about its impact
+The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to. The value for
+``torch.backends.cudnn.benchmark`` set in the current session will be used (``False`` if not manually set).
+If :paramref:`~pytorch_lightning.trainer.Trainer.deterministic` is set to ``True``, this will default to ``False``.
+You can read more about the interaction of ``torch.backends.cudnn.benchmark`` and ``torch.backends.cudnn.deterministic``
 `here <https://pytorch.org/docs/stable/notes/randomness.html#cuda-convolution-benchmarking>`__

-This is likely to increase the speed of your system if your input sizes don't change. However, if they do, then it
-might make your system slower. The CUDNN auto-tuner will try to find the best algorithm for the hardware when a new
-input size is encountered. Read more about it `here <https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936>`__.
+Setting this flag to ``True`` can increase the speed of your system if your input sizes don't
+change. However, if they do, then it might make your system slower. The CUDNN auto-tuner will try to find the best
+algorithm for the hardware when a new input size is encountered. This might also increase the memory usage.
+Read more about it `here <https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936>`__.

 Example::

-    # defaults to True if not deterministic (which is False by default)
-    trainer = Trainer()
+    # Will use whatever the current value for torch.backends.cudnn.benchmark is, normally False
+    trainer = Trainer(benchmark=None)  # default

     # you can overwrite the value
-    trainer = Trainer(benchmark=False)
+    trainer = Trainer(benchmark=True)

 deterministic
 ^^^^^^^^^^^^^
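
The documented precedence can be condensed into a short sketch. The `_accelerator_connector` attribute is private and appears here only because the test further below asserts against it:

    import torch
    from pytorch_lightning import Trainer

    # A value set manually in the session is picked up, not overwritten.
    torch.backends.cudnn.benchmark = True
    trainer = Trainer()  # benchmark=None by default
    assert trainer._accelerator_connector.benchmark is True

    # deterministic=True resolves an unset benchmark flag to False.
    trainer = Trainer(deterministic=True)
    assert torch.backends.cudnn.benchmark is False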

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 16 additions & 11 deletions
@@ -143,14 +143,19 @@ def __init__(
             B. Strategy > Accelerator/precision/plugins
             C. TODO When multiple flag set to the same thing
         """
-        if benchmark and deterministic:
-            rank_zero_warn(
-                "You passed `deterministic=True` and `benchmark=True`. Note that PyTorch ignores"
-                " torch.backends.cudnn.deterministic=True when torch.backends.cudnn.benchmark=True.",
-            )
-        self.benchmark = not deterministic if benchmark is None else benchmark
+        if deterministic:
+            if benchmark is None:
+                # Set benchmark to False to ensure determinism
+                benchmark = False
+            elif benchmark:
+                rank_zero_warn(
+                    "You passed `deterministic=True` and `benchmark=True`. Note that PyTorch ignores"
+                    " torch.backends.cudnn.deterministic=True when torch.backends.cudnn.benchmark=True.",
+                )
         # TODO: move to gpu accelerator
-        torch.backends.cudnn.benchmark = self.benchmark
+        if benchmark is not None:
+            torch.backends.cudnn.benchmark = benchmark
+        self.benchmark = torch.backends.cudnn.benchmark
         self.replace_sampler_ddp = replace_sampler_ddp
         self._init_deterministic(deterministic)

@@ -211,10 +216,10 @@ def __init__(
         # 6. Instantiate Strategy - Part 2
         self._lazy_init_strategy()

-    def _init_deterministic(self, deterministic: bool) -> None:
-        self.deterministic = deterministic
-        torch.use_deterministic_algorithms(deterministic)
-        if deterministic:
+    def _init_deterministic(self, deterministic: Optional[bool]) -> None:
+        self.deterministic = deterministic or False  # default to False if not set
+        torch.use_deterministic_algorithms(self.deterministic)
+        if self.deterministic:
             # fixing non-deterministic part of horovod
             # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
             os.environ["HOROVOD_FUSION_THRESHOLD"] = "0"
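
Condensed into a standalone helper, the new control flow of the first hunk reads as below. `resolve_cudnn_benchmark` is a hypothetical name used for illustration only, and the warning branch is omitted:

    from typing import Optional

    import torch

    def resolve_cudnn_benchmark(benchmark: Optional[bool], deterministic: Optional[bool]) -> bool:
        if deterministic and benchmark is None:
            # Determinism requested and benchmark unset: force it off.
            benchmark = False
        if benchmark is not None:
            # Only touch the global flag when an explicit value was resolved.
            torch.backends.cudnn.benchmark = benchmark
        # In the remaining None case, report whatever the session already had.
        return torch.backends.cudnn.benchmark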

pytorch_lightning/trainer/trainer.py

Lines changed: 8 additions & 5 deletions
@@ -174,7 +174,7 @@ def __init__(
         resume_from_checkpoint: Optional[Union[Path, str]] = None,
         profiler: Optional[Union[Profiler, str]] = None,
         benchmark: Optional[bool] = None,
-        deterministic: bool = False,
+        deterministic: Optional[bool] = None,
         reload_dataloaders_every_n_epochs: int = 0,
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,

@@ -229,9 +229,11 @@ def __init__(
                 that only one process at a time can access them.
                 Default: ``False``.

-            benchmark: Sets ``torch.backends.cudnn.benchmark``.
-                Defaults to ``True`` if :paramref:`~pytorch_lightning.trainer.trainer.Trainer.deterministic`
-                is ``False``. Overwrite to manually set a different value. Default: ``None``.
+            benchmark: The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to.
+                The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used
+                (``False`` if not manually set). If :paramref:`~pytorch_lightning.trainer.Trainer.deterministic` is set
+                to ``True``, this will default to ``False``. Override to manually set a different value.
+                Default: ``None``.

             callbacks: Add a callback or list of callbacks.
                 Default: ``None``.

@@ -260,7 +262,8 @@ def __init__(
                 Default: ``False``.

             deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms.
-                Default: ``False``.
+                If not set, defaults to ``False``.
+                Default: ``None``.

             devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`,
                 based on the accelerator type.
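
From the caller's side, the two reworked parameters behave as sketched here; this only restates the docstrings above, not additional API:

    from pytorch_lightning import Trainer

    # Both flags now default to None, meaning "leave the global state alone".
    Trainer()                    # deterministic resolves to False; benchmark keeps the session value
    Trainer(deterministic=True)  # an unset benchmark additionally resolves to False
    # benchmark=True together with deterministic=True still emits the UserWarning
    # shown in the accelerator_connector diff.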

tests/trainer/test_trainer.py

Lines changed: 9 additions & 4 deletions
@@ -637,27 +637,32 @@ def test_trainer_max_steps_accumulate_batches(tmpdir):
     assert trainer.global_step == trainer.max_steps, "Model did not stop at max_steps"


+@pytest.mark.parametrize("cudnn_benchmark", (False, True))
 @pytest.mark.parametrize(
     ["benchmark_", "deterministic", "expected"],
     [
-        (None, False, True),
+        (None, False, None),
         (None, True, False),
+        (None, None, None),
         (True, False, True),
         (True, True, True),
-        (False, True, False),
+        (True, None, True),
         (False, False, False),
+        (False, True, False),
+        (False, None, False),
     ],
 )
-def test_benchmark_option(benchmark_, deterministic, expected):
+def test_benchmark_option(cudnn_benchmark, benchmark_, deterministic, expected):
     """Verify benchmark option."""
-
     original_val = torch.backends.cudnn.benchmark

+    torch.backends.cudnn.benchmark = cudnn_benchmark
     if benchmark_ and deterministic:
         with pytest.warns(UserWarning, match="You passed `deterministic=True` and `benchmark=True`"):
             trainer = Trainer(benchmark=benchmark_, deterministic=deterministic)
     else:
         trainer = Trainer(benchmark=benchmark_, deterministic=deterministic)
+    expected = cudnn_benchmark if expected is None else expected
     assert torch.backends.cudnn.benchmark == expected
     assert trainer._accelerator_connector.benchmark == expected
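
The test mutates process-global state and stashes `original_val` up front. A hypothetical context-manager form of that save/restore pattern (not part of the test suite) could look like:

    import contextlib

    import torch

    @contextlib.contextmanager
    def cudnn_benchmark(value: bool):
        # Temporarily set torch.backends.cudnn.benchmark, restoring the previous value on exit.
        original = torch.backends.cudnn.benchmark
        torch.backends.cudnn.benchmark = value
        try:
            yield
        finally:
            torch.backends.cudnn.benchmark = original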
