🐛 Bug
Found a discrepancy between a run continued after checkpointing and a run restored from that checkpoint.
Observation:
The training-batch / validation-loop ordering upon checkpoint restoration is not the same as in the original run after checkpoint saving. There are still the same number of train steps, but the validation loops are interleaved one step later, which can cause the restored run to end up with one fewer validation loop (see colab).
Assumption / expectation:
Zero difference between a training run continued past a checkpoint and a run restored from said checkpoint.
Investigation so far:
I'm new to some of this Lightning code, but IIUC:
Key points:

1. `TrainingEpochLoop`'s `self.batch_progress.increment_completed()` is called after the `on_train_batch_end` hooks, the latter kicking off checkpoint saving.
2. Upon restoring, `TrainingEpochLoop.batch_progress.current.reset_on_restart()` will reset the `ready` count back to `completed`.
3. Yet the `global_step`, which refers to `TrainingEpochLoop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed`, has already been `increment_completed()` (called within `TrainingEpochLoop.batch_loop.run`), and thus upon restoring, `..optimizer.step.total.ready` is set to an up-to-date `optimizer.step.total.completed`, out of sync with the above.
4. [simplification] In "`val_check_interval` mode", validation is triggered when `TrainingEpochLoop.batch_progress.current.ready % val_check_interval == 0` (through `TrainingEpochLoop.on_advance_end` -> `TrainingEpochLoop._should_check_val_fx`).
5. Combining the three points above with the trigger in 4: the same `batch_progress.current` `ready`/`completed` counters in the continued and restored runs end up aligned with different `global_step`s, and hence validation triggers at different `global_step`s (see the toy simulation right after this list).
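To make the off-by-one concrete, here is a toy pure-Python simulation of the bookkeeping in points 1-4. This is not Lightning code: the variables merely mimic `batch_progress.current.ready`/`completed` and the optimizer step counter behind `global_step`, and the save/restore step is my simplified reading of `reset_on_restart()`.

```python
VAL_CHECK_INTERVAL = 3


def run(n_steps, restore_after=None):
    """Simulate the loop counters; optionally save/restore after `restore_after` steps."""
    ready = completed = opt_completed = 0
    val_triggers = []  # global_step at which validation fires
    restored = False
    while opt_completed < n_steps:
        ready += 1          # batch starts: batch_progress.current.ready
        opt_completed += 1  # optimizer.step completed inside batch_loop.run
        if restore_after is not None and opt_completed == restore_after and not restored:
            # the checkpoint is saved in on_train_batch_end, i.e. *before*
            # batch_progress.increment_completed() runs for this batch ...
            saved_completed, saved_opt = completed, opt_completed
            # ... so on restore, reset_on_restart() pulls `ready` back to the
            # stale `completed`, while the optimizer counters come back up to date
            ready, completed, opt_completed = saved_completed, saved_completed, saved_opt
            restored = True
            continue
        completed += 1      # batch_progress.increment_completed()
        if ready % VAL_CHECK_INTERVAL == 0:  # simplified _should_check_val_fx
            val_triggers.append(opt_completed)
    return val_triggers


print(run(9))                   # continued run: validations at global_steps [3, 6, 9]
print(run(9, restore_after=5))  # restored run:  [3, 7] -- shifted by one, last one lost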
Another observation:
The following if statement seems to allow for a zero-difference restart, except that, just like 4. above, `_should_check_val_fx` wouldn't trigger where it did in the original run on the checkpointing step (although there it is called in `on_advance_end`). Not sure if the original intention of this snippet included the current scope.
```python
class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
    ...

    def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
        ...
        if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch):
            # skip training and run validation in `on_advance_end`
            return
```
PRs relevant to this line:
Potential impact:
Assuming this is not too worrisome for the more default Lightning use cases:

- with `val_check_interval` >> 3 (colab example = 3), or with that turned off, relying instead on `check_val_every_n_epoch`

However, in theory it can influence all of the following:

- no 1:1 deterministic reproducibility
- affects the latest/best validation loss
- affects any code flow / decision making based on that
- causes a "different usage order" of RNGs (<- how I initially caught the issue: even with correctly restored RNG states, if both validation and training steps use one, they'll each end up with different random numbers compared to the continued run; see the sketch right after this list)
- other
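A minimal illustration of the RNG point above (plain PyTorch, independent of Lightning): with identical seeds, changing only the train/val interleaving hands each kind of step different numbers from the shared generator.

```python
import torch


def draw(order):
    # draw one number per step from a freshly seeded shared generator;
    # `order` is a hypothetical schedule of train ("T") and val ("V") steps
    torch.manual_seed(0)
    return [(kind, round(torch.rand(1).item(), 4)) for kind in order]


print(draw(["T", "T", "V", "T"]))  # continued run: validation after the 2nd train step
print(draw(["T", "T", "T", "V"]))  # restored run: validation interleaved one step later
# the overall stream is identical, but train and val steps each receive
# different values once the interleaving shifts
```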
To Reproduce
Customized Google Colab `bug_report_model.ipynb`, with the same observation on BoringModel.
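For reference, a rough standalone sketch along the lines of the colab. This is a reconstruction under assumptions, not the notebook itself: `TinyModel`, `RandomDataset`, the step counts, and the checkpoint filename are mine; only the `Trainer`/`ModelCheckpoint` arguments and the `ckpt_path=` resume mechanism are standard Lightning API.

```python
import os

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size=32, length=64):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.val_steps_seen = []  # global_step at each validation batch

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        self.val_steps_seen.append(self.trainer.global_step)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def loaders():
    return DataLoader(RandomDataset(), batch_size=4), DataLoader(RandomDataset(), batch_size=4)


common = dict(max_steps=12, val_check_interval=3, limit_val_batches=1,
              num_sanity_val_steps=0, enable_progress_bar=False)

# run 1: train straight through, checkpointing every 6 steps
pl.seed_everything(1)
model = TinyModel()
ckpt_cb = ModelCheckpoint(dirpath="ckpts", filename="{step}",
                          every_n_train_steps=6, save_top_k=-1)
pl.Trainer(callbacks=[ckpt_cb], **common).fit(model, *loaders())
print("continued run, validations at global_steps:", model.val_steps_seen)

# run 2: restore from the step-6 checkpoint and train to the same max_steps
# (assumes the auto-formatted filename "step=6.ckpt"; adjust if it differs)
pl.seed_everything(1)
model2 = TinyModel()
pl.Trainer(**common).fit(model2, *loaders(), ckpt_path=os.path.join("ckpts", "step=6.ckpt"))
print("restored run, validations at global_steps:", model2.val_steps_seen)
# per the report, the restored run's validations land one global_step later,
# and the run can end up with one fewer validation loop overall
```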
Expected behavior
Zero difference between a training run continued past a checkpoint and a run restored from said checkpoint.
Environment
Note:
- The below is from the original investigation in our own code base, with pytorch-lightning `v1.6.4`.
- The environment details from the BoringModel reproduction are listed in the colab, with pytorch-lightning `v1.7.4`.
- I also browsed through the `master` branch over the last weeks and the relevant code seems unchanged.
Details
- CUDA:
- GPU:
- NVIDIA RTX A4000
- NVIDIA RTX A4000
- NVIDIA RTX A4000
- NVIDIA RTX A4000
- available: True
- version: 11.0
- Lightning:
- efficientnet-pytorch: 0.7.1
- pytorch-lightning: 1.6.4
- torch: 1.11.0.post1103
- torchmetrics: 0.7.0
- torchvision: 0.12.0a1110.post1103
- Packages:
- absl-py: 0.15.0
- adal: 1.2.7
- adlfs: 2021.10.0
- aiohttp: 3.7.4
- applicationinsights: 0.11.10
- argcomplete: 1.12.3
- async-timeout: 3.0.1
- attrdict: 2.0.0
- attrs: 21.1.0
- av: 8.0.3
- azure-cli-core: 2.38.0
- azure-cli-telemetry: 1.0.6
- azure-common: 1.1.27
- azure-core: 1.20.0
- azure-datalake-store: 0.0.52
- azure-identity: 1.10.0
- azure-keyvault-secrets: 4.2.0
- azure-mgmt-core: 1.2.2
- azure-storage-blob: 12.11.0
- backcall: 0.2.0
- backoff: 1.10.0
- bcrypt: 3.2.0
- cachetools: 4.2.2
- certifi: 2020.12.5
- cffi: 1.14.5
- chardet: 3.0.4
- charset-normalizer: 2.0.12
- click: 7.1.2
- confluent-kafka: 1.7.0
- cryptography: 3.4.8
- cycler: 0.10.0
- datadog: 0.44.0
- decorator: 5.0.7
- deepdiff: 5.5.0
- deltalake: 0.5.8
- docker-pycreds: 0.4.0
- efficientnet-pytorch: 0.7.1
- einops: 0.4.1
- filelock: 3.7.1
- fonttools: 4.37.1
- frozendict: 2.3.2
- fsspec: 2022.1.0
- gitdb: 4.0.7
- gitpython: 3.1.14
- google-auth: 1.30.0
- google-auth-oauthlib: 0.4.4
- grpcio: 1.37.1
- htmlmin: 0.1.12
- humanfriendly: 10.0
- idna: 2.10
- imagehash: 4.2.1
- inplace-abn: 1.1.0a1110.post1103
- ipdb: 0.13.9
- ipython: 7.23.1
- isodate: 0.6.0
- jedi: 0.18.0
- jinja2: 3.1.2
- jmespath: 0.10.0
- joblib: 1.0.1
- kafka-python: 2.0.2
- kiwisolver: 1.3.1
- knack: 0.9.0
- markdown: 3.3.4
- markupsafe: 2.0.1
- matplotlib: 3.5.3
- matplotlib-inline: 0.1.2
- methodtools: 0.1.2
- missingno: 0.5.0
- msal: 1.16.0
- msal-extensions: 0.3.0
- msrest: 0.6.21
- msrestazure: 0.6.4
- multidict: 5.1.0
- multimethod: 1.6
- networkx: 2.5.1
- numpy: 1.22.4
- oauthlib: 3.1.0
- opencv-python: 4.4.0.44
- ordered-set: 4.0.2
- packaging: 21.3
- pandas: 1.4.3
- pandas-profiling: 3.1.0
- paramiko: 2.7.2
- parso: 0.8.2
- pathtools: 0.1.2
- pexpect: 4.8.0
- phik: 0.12.0
- pickleshare: 0.7.5
- pillow: 9.2.0
- pip: 22.0.3
- pkginfo: 1.7.0
- polyline: 1.4.0
- portalocker: 1.7.1
- prometheus-client: 0.8.0
- promise: 2.3
- prompt-toolkit: 2.0.10
- protobuf: 3.15.8
- psutil: 5.9.1
- psycopg2: 2.8.3
- ptyprocess: 0.7.0
- py: 1.10.0
- py3nvml: 0.2.7
- pyarrow: 9.0.0
- pyasn1: 0.4.8
- pyasn1-modules: 0.2.8
- pycparser: 2.20
- pydantic: 1.8.2
- pydeprecate: 0.3.1
- pygame: 2.1.2
- pygments: 2.9.0
- pyjwt: 1.7.1
- pynacl: 1.4.0
- pyntcloud: 0.1.6
- pyopenssl: 20.0.1
- pyparsing: 2.4.7
- pyquaternion: 0.9.9
- pysocks: 1.7.1
- python-dateutil: 2.8.2
- python-json-logger: 2.0.2
- pytorch-lightning: 1.6.4
- pytz: 2022.1
- pywavelets: 1.1.1
- pyyaml: 6.0
- qrcode: 6.1
- requests: 2.27.1
- requests-oauthlib: 1.3.0
- retry: 0.9.2
- rsa: 4.7.2
- runai: 0.3.0
- scipy: 1.6.2
- seaborn: 0.11.2
- semver: 2.13.0
- sentry-sdk: 1.9.4
- setproctitle: 1.2.2
- setuptools: 59.5.0
- shapely: 1.8.0
- shortuuid: 1.0.1
- simplejpeg: 1.4.1
- six: 1.16.0
- slackclient: 2.9.4
- smmap: 4.0.0
- sqlalchemy: 1.3.24
- tabulate: 0.8.9
- tangled-up-in-unicode: 0.1.0
- tensorboard: 2.6.0
- tensorboard-data-server: 0.6.1
- tensorboard-plugin-wit: 1.8.0
- timm: 0.4.5
- toml: 0.10.2
- torch: 1.11.0.post1103
- torchmetrics: 0.7.0
- torchvision: 0.12.0a1110.post1103
- tqdm: 4.60.0
- traitlets: 5.3.0
- transforms3d: 0.3.1
- typing-extensions: 4.1.1
- urllib3: 1.26.11
- visions: 0.7.4
- wandb: 0.12.14
- wcwidth: 0.2.5
- werkzeug: 1.0.1
- wheel: 0.36.2
- wirerope: 0.3.1
- wrapt: 1.14.1
- xmltodict: 0.12.0
- xxhash: 1.4.1
- yarl: 1.6.3
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.8.12
- version: #138~18.04.1-Ubuntu SMP Fri Jun 24 14:14:03 UTC 2022
Additional context
cc @awaelchli @ananthsub @ninginthecloud @rohitgr7 @otaj @carmocca @justusschock