Skip to content

Commit 4a884d2

Browse files
Quasar-Kim, pre-commit-ci[bot], and awaelchli
authored and committed
Fix multithreading checkpoint loading (#17678)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: awaelchli <[email protected]>
(cherry picked from commit 1307b60)
1 parent b810098 commit 4a884d2

File tree

3 files changed

+46
-5
lines changed

3 files changed

+46
-5
lines changed

src/lightning/pytorch/utilities/migration/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import logging
1515
import os
1616
import sys
17+
import threading
1718
from types import ModuleType, TracebackType
1819
from typing import Any, Dict, List, Optional, Tuple, Type
1920

@@ -28,6 +29,7 @@
2829

2930
_log = logging.getLogger(__name__)
3031
_CHECKPOINT = Dict[str, Any]
32+
_lock = threading.Lock()
3133

3234

3335
def migrate_checkpoint(
@@ -85,6 +87,7 @@ class pl_legacy_patch:
8587
"""
8688

8789
def __enter__(self) -> "pl_legacy_patch":
90+
_lock.acquire()
8891
# `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse`
8992
legacy_argparse_module = ModuleType("lightning.pytorch.utilities.argparse_utils")
9093
sys.modules["lightning.pytorch.utilities.argparse_utils"] = legacy_argparse_module
@@ -103,6 +106,7 @@ def __exit__(
103106
if hasattr(pl.utilities.argparse, "_gpus_arg_default"):
104107
delattr(pl.utilities.argparse, "_gpus_arg_default")
105108
del sys.modules["lightning.pytorch.utilities.argparse_utils"]
109+
_lock.release()
106110

107111

108112
def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT:

tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import glob
1515
import os
1616
import sys
17-
import threading
1817
from unittest.mock import patch
1918

2019
import pytest
@@ -26,6 +25,7 @@
2625
from tests_pytorch.helpers.datamodules import ClassifDataModule
2726
from tests_pytorch.helpers.runif import RunIf
2827
from tests_pytorch.helpers.simple_models import ClassificationModel
28+
from tests_pytorch.helpers.threading import ThreadExceptionHandler
2929

3030
LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, "checkpoints")
3131
CHECKPOINT_EXTENSION = ".ckpt"
@@ -68,18 +68,22 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
6868
@pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS)
6969
@RunIf(sklearn=True)
7070
def test_legacy_ckpt_threading(tmpdir, pl_version: str):
71+
PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
72+
path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}")))
73+
assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
74+
path_ckpt = path_ckpts[-1]
75+
7176
def load_model():
7277
import torch
7378

7479
from lightning.pytorch.utilities.migration import pl_legacy_patch
7580

7681
with pl_legacy_patch():
77-
_ = torch.load(PATH_LEGACY)
82+
_ = torch.load(path_ckpt)
7883

79-
PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
8084
with patch("sys.path", [PATH_LEGACY] + sys.path):
81-
t1 = threading.Thread(target=load_model)
82-
t2 = threading.Thread(target=load_model)
85+
t1 = ThreadExceptionHandler(target=load_model)
86+
t2 = ThreadExceptionHandler(target=load_model)
8387

8488
t1.start()
8589
t2.start()
tests/tests_pytorch/helpers/threading.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright The Lightning AI team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from threading import Thread
15+
16+
17+
class ThreadExceptionHandler(Thread):
    """Thread that records an exception raised by its target and re-raises it from ``join``.

    A plain ``Thread`` does not propagate exceptions raised in ``run`` to the thread
    that joins it; this subclass stores the exception so the joining thread (e.g. a
    test) can fail on it.

    Adopted from https://stackoverflow.com/a/67022927.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Exception raised inside ``run``; stays ``None`` when the target finishes cleanly.
        self.exception = None

    def run(self):
        """Run the target and capture any ``Exception`` it raises instead of letting it escape."""
        try:
            super().run()
        except Exception as e:
            self.exception = e

    def join(self, timeout=None):
        """Wait for the thread like ``Thread.join`` and re-raise the captured exception, if any.

        Args:
            timeout: Optional number of seconds to wait, forwarded to ``Thread.join``.
                ``None`` (the default) blocks until the thread terminates.

        Raises:
            Exception: Whatever exception the target raised while running.
        """
        super().join(timeout)
        # Compare against ``None`` explicitly: an exception object may define
        # ``__bool__``/``__len__`` and evaluate as falsy, which must not hide it.
        if self.exception is not None:
            raise self.exception

0 commit comments

Comments
 (0)