
TrainingEpochLoop._should_check_val_fx discrepancy between continued run <> restore from ckpt #14579

@Anner-deJong

Description


🐛 Bug

Found a discrepancy between a run that continues after checkpointing and a run restored from that checkpoint.

Observation:

The training-batch / validation-loop ordering upon checkpoint restoration is not the same as in the original run after checkpoint saving.

The restored run still performs the same number of train steps, but the validation loops are interleaved one step later, which can cause the restored run to end up with one fewer validation loop (see Colab).

Assumption / expectation:

Zero difference between a training run that continues past a checkpoint and a run restored from said checkpoint.

Investigation so far:

I'm new to some of this Lightning code, but IIUC:

Key:

TrainingEpochLoop's self.batch_progress.increment_completed() is called after the "on_train_batch_end" hooks, and the latter kick off checkpoint saving.

  1. Upon restoring, TrainingEpochLoop.batch_progress.current.reset_on_restart() resets ready back to completed
  2. Yet the global_step, which refers to TrainingEpochLoop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed, has already had increment_completed() called (within TrainingEpochLoop.batch_loop.run), and thus upon restoring, ..optimizer.step.total.ready is set to an up-to-date optimizer.step.total.completed, out of sync with the above
  3. [simplification] In "val_check_interval mode", validation is triggered when TrainingEpochLoop.batch_progress.current.ready % val_check_interval == 0 (through TrainingEpochLoop.on_advance_end -> TrainingEpochLoop._should_check_val_fx)
  4. Combining the above three: the same batch_progress.current ready/completed counters in the continued and restored runs end up aligned with different global_steps, and hence validation triggers at different global_steps
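The interaction in steps 1-4 can be sketched with a toy model. This is illustrative only (the Counter class and exact numbers are mine, not the actual Lightning Progress classes), but it shows how the batch counter ends up aligned with a global_step one higher than in the continued run:

```python
from dataclasses import dataclass

# Toy stand-in for Lightning's Progress tracking (illustrative, not the real API)
@dataclass
class Counter:
    ready: int = 0
    completed: int = 0

    def reset_on_restart(self):
        # Lightning pulls `ready` back to `completed` when restoring
        self.ready = self.completed

VAL_CHECK_INTERVAL = 3

def should_check_val(batch_counter):
    # simplified stand-in for _should_check_val_fx: modulo on the batch counter
    return batch_counter.ready % VAL_CHECK_INTERVAL == 0

# State at the moment the checkpoint is written (inside on_train_batch_end for
# batch 3): the optimizer step is already completed, the batch is not yet.
batch = Counter(ready=3, completed=2)      # increment_completed() not called yet
opt_step = Counter(ready=3, completed=3)   # already incremented in batch_loop.run

# Continued run: increment_completed() fires right after the hook,
# then validation triggers at global_step 3.
batch.completed += 1
assert should_check_val(batch) and opt_step.completed == 3

# Restored run: reset_on_restart() pulls batch.ready back to completed,
# while the optimizer counter stays put -> the two are now out of sync.
batch2 = Counter(ready=3, completed=2)
opt2 = Counter(ready=3, completed=3)
batch2.reset_on_restart()   # ready: 3 -> 2
opt2.reset_on_restart()     # ready stays 3 (already == completed)

# The next advance() processes a batch: the batch counter catches back up to 3,
# but global_step has moved on to 4.
batch2.ready += 1
opt2.completed += 1
batch2.completed += 1
assert should_check_val(batch2)   # validation triggers now...
assert opt2.completed == 4        # ...but at global_step 4 instead of 3
```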

Another observation:

The following if statement seems to allow for a zero-difference restart, except that, just like point 4 above, _should_check_val_fx wouldn't trigger where it did in the original run on the checkpointing step (although there it is called in on_advance_end). I'm not sure whether the original intention of this snippet included the current scope:

class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
    ...
    def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
        ...
        if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch):
            # skip training and run validation in `on_advance_end`
            return
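A toy trace of why this guard does not fire in the scenario above (simplified stand-in, not the real _should_check_val_fx signature; numbers are illustrative): after reset_on_restart(), the batch ready counter has fallen back to completed, so the modulo check that passed on the checkpointing step fails on restart:

```python
VAL_CHECK_INTERVAL = 3

def should_check_val(batch_ready):
    # simplified stand-in for _should_check_val_fx
    return batch_ready % VAL_CHECK_INTERVAL == 0

# step whose on_train_batch_end saved the checkpoint
ready_at_checkpoint = 3
assert should_check_val(ready_at_checkpoint)      # original run validated here

# after reset_on_restart(): ready is pulled back to completed (= 2)
ready_after_restore = ready_at_checkpoint - 1
assert not should_check_val(ready_after_restore)  # guard never fires on restart
```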

PRs relevant to this line:

Potential impact:

Assuming it is not too worrisome for the more default Lightning use cases:

  • With val_check_interval >> 3 (the Colab example uses 3), or with it turned off, relying instead on check_val_every_n_epoch

However, in theory it can influence all of the following:

  • no 1:1 deterministic reproducibility
  • affects the latest/best validation loss
    • affects any code flow / decision making based on that
  • causes a "different usage order" of RNGs (<- how I initially caught the issue: even with correctly restored RNG states, if both validation and training steps use one, they'll each end up with different random numbers compared to the continued run)
  • other
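The RNG point above can be illustrated with plain Python (a toy example, unrelated to Lightning's actual RNG handling): shifting a single validation call by one step changes which numbers later training steps draw, even with identical seeds:

```python
import random

def run(seed, val_after_step):
    """Toy training loop: each step draws one random number; the validation
    loop (simulated as a single extra draw) runs after `val_after_step`."""
    rng = random.Random(seed)
    train_draws = []
    for step in range(1, 5):
        train_draws.append(rng.random())  # training step consumes one number
        if step == val_after_step:
            rng.random()                  # validation also consumes one
    return train_draws

# Same seed, validation interleaved after step 2 vs after step 3:
# the first two training steps match, later ones see different numbers.
a = run(0, val_after_step=2)
b = run(0, val_after_step=3)
assert a[:2] == b[:2]
assert a[2:] != b[2:]
```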

To Reproduce

Customized Google Colab bug_report_model.ipynb with the same observation on the BoringModel.

Expected behavior

Zero difference between a training run continued after a checkpoint and a run restored from said checkpoint.

Environment

Note:

  • The below is from the original investigation in our own code base, with pytorch-lightning v1.6.4.
  • The environment details from the BoringModel reproduction are listed in the Colab, with pytorch-lightning v1.7.4.
  • I also browsed through the master branch in recent weeks and the relevant code seems unchanged.
Details
  • CUDA:
    • GPU:
      • NVIDIA RTX A4000
      • NVIDIA RTX A4000
      • NVIDIA RTX A4000
      • NVIDIA RTX A4000
    • available: True
    • version: 11.0
  • Lightning:
    • efficientnet-pytorch: 0.7.1
    • pytorch-lightning: 1.6.4
    • torch: 1.11.0.post1103
    • torchmetrics: 0.7.0
    • torchvision: 0.12.0a1110.post1103
  • Packages:
    • absl-py: 0.15.0
    • adal: 1.2.7
    • adlfs: 2021.10.0
    • aiohttp: 3.7.4
    • applicationinsights: 0.11.10
    • argcomplete: 1.12.3
    • async-timeout: 3.0.1
    • attrdict: 2.0.0
    • attrs: 21.1.0
    • av: 8.0.3
    • azure-cli-core: 2.38.0
    • azure-cli-telemetry: 1.0.6
    • azure-common: 1.1.27
    • azure-core: 1.20.0
    • azure-datalake-store: 0.0.52
    • azure-identity: 1.10.0
    • azure-keyvault-secrets: 4.2.0
    • azure-mgmt-core: 1.2.2
    • azure-storage-blob: 12.11.0
    • backcall: 0.2.0
    • backoff: 1.10.0
    • bcrypt: 3.2.0
    • cachetools: 4.2.2
    • certifi: 2020.12.5
    • cffi: 1.14.5
    • chardet: 3.0.4
    • charset-normalizer: 2.0.12
    • click: 7.1.2
    • confluent-kafka: 1.7.0
    • cryptography: 3.4.8
    • cycler: 0.10.0
    • datadog: 0.44.0
    • decorator: 5.0.7
    • deepdiff: 5.5.0
    • deltalake: 0.5.8
    • docker-pycreds: 0.4.0
    • efficientnet-pytorch: 0.7.1
    • einops: 0.4.1
    • filelock: 3.7.1
    • fonttools: 4.37.1
    • frozendict: 2.3.2
    • fsspec: 2022.1.0
    • gitdb: 4.0.7
    • gitpython: 3.1.14
    • google-auth: 1.30.0
    • google-auth-oauthlib: 0.4.4
    • grpcio: 1.37.1
    • htmlmin: 0.1.12
    • humanfriendly: 10.0
    • idna: 2.10
    • imagehash: 4.2.1
    • inplace-abn: 1.1.0a1110.post1103
    • ipdb: 0.13.9
    • ipython: 7.23.1
    • isodate: 0.6.0
    • jedi: 0.18.0
    • jinja2: 3.1.2
    • jmespath: 0.10.0
    • joblib: 1.0.1
    • kafka-python: 2.0.2
    • kiwisolver: 1.3.1
    • knack: 0.9.0
    • markdown: 3.3.4
    • markupsafe: 2.0.1
    • matplotlib: 3.5.3
    • matplotlib-inline: 0.1.2
    • methodtools: 0.1.2
    • missingno: 0.5.0
    • msal: 1.16.0
    • msal-extensions: 0.3.0
    • msrest: 0.6.21
    • msrestazure: 0.6.4
    • multidict: 5.1.0
    • multimethod: 1.6
    • networkx: 2.5.1
    • numpy: 1.22.4
    • oauthlib: 3.1.0
    • opencv-python: 4.4.0.44
    • ordered-set: 4.0.2
    • packaging: 21.3
    • pandas: 1.4.3
    • pandas-profiling: 3.1.0
    • paramiko: 2.7.2
    • parso: 0.8.2
    • pathtools: 0.1.2
    • pexpect: 4.8.0
    • phik: 0.12.0
    • pickleshare: 0.7.5
    • pillow: 9.2.0
    • pip: 22.0.3
    • pkginfo: 1.7.0
    • polyline: 1.4.0
    • portalocker: 1.7.1
    • prometheus-client: 0.8.0
    • promise: 2.3
    • prompt-toolkit: 2.0.10
    • protobuf: 3.15.8
    • psutil: 5.9.1
    • psycopg2: 2.8.3
    • ptyprocess: 0.7.0
    • py: 1.10.0
    • py3nvml: 0.2.7
    • pyarrow: 9.0.0
    • pyasn1: 0.4.8
    • pyasn1-modules: 0.2.8
    • pycparser: 2.20
    • pydantic: 1.8.2
    • pydeprecate: 0.3.1
    • pygame: 2.1.2
    • pygments: 2.9.0
    • pyjwt: 1.7.1
    • pynacl: 1.4.0
    • pyntcloud: 0.1.6
    • pyopenssl: 20.0.1
    • pyparsing: 2.4.7
    • pyquaternion: 0.9.9
    • pysocks: 1.7.1
    • python-dateutil: 2.8.2
    • python-json-logger: 2.0.2
    • pytorch-lightning: 1.6.4
    • pytz: 2022.1
    • pywavelets: 1.1.1
    • pyyaml: 6.0
    • qrcode: 6.1
    • requests: 2.27.1
    • requests-oauthlib: 1.3.0
    • retry: 0.9.2
    • rsa: 4.7.2
    • runai: 0.3.0
    • scipy: 1.6.2
    • seaborn: 0.11.2
    • semver: 2.13.0
    • sentry-sdk: 1.9.4
    • setproctitle: 1.2.2
    • setuptools: 59.5.0
    • shapely: 1.8.0
    • shortuuid: 1.0.1
    • simplejpeg: 1.4.1
    • six: 1.16.0
    • slackclient: 2.9.4
    • smmap: 4.0.0
    • sqlalchemy: 1.3.24
    • tabulate: 0.8.9
    • tangled-up-in-unicode: 0.1.0
    • tensorboard: 2.6.0
    • tensorboard-data-server: 0.6.1
    • tensorboard-plugin-wit: 1.8.0
    • timm: 0.4.5
    • toml: 0.10.2
    • torch: 1.11.0.post1103
    • torchmetrics: 0.7.0
    • torchvision: 0.12.0a1110.post1103
    • tqdm: 4.60.0
    • traitlets: 5.3.0
    • transforms3d: 0.3.1
    • typing-extensions: 4.1.1
    • urllib3: 1.26.11
    • visions: 0.7.4
    • wandb: 0.12.14
    • wcwidth: 0.2.5
    • werkzeug: 1.0.1
    • wheel: 0.36.2
    • wirerope: 0.3.1
    • wrapt: 1.14.1
    • xmltodict: 0.12.0
    • xxhash: 1.4.1
    • yarl: 1.6.3
  • System:

Additional context

cc @awaelchli @ananthsub @ninginthecloud @rohitgr7 @otaj @carmocca @justusschock

    Labels

    bug: Something isn't working
    checkpointing: Related to checkpointing
    help wanted: Open to be worked on
    loops: Related to the Loop API
