Skip to content

Commit f2f187f

Browse files
lightningforeverlantiga
authored andcommitted
External callback registry through entry points for Fabric (#17756)
1 parent 1ddd690 commit f2f187f

File tree

10 files changed

+127
-47
lines changed

10 files changed

+127
-47
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
77

88
## [UnReleased] - 2023-04-DD
99

10+
- Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756))
11+
12+
1013
### Changed
1114

1215
-

src/lightning/fabric/fabric.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
has_iterable_dataset,
4444
)
4545
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
46+
from lightning.fabric.utilities.registry import _load_external_callbacks
4647
from lightning.fabric.utilities.seed import seed_everything
4748
from lightning.fabric.utilities.types import ReduceOp
4849
from lightning.fabric.utilities.warnings import PossibleUserWarning
@@ -105,8 +106,7 @@ def __init__(
105106
self._strategy: Strategy = self._connector.strategy
106107
self._accelerator: Accelerator = self._connector.accelerator
107108
self._precision: Precision = self._strategy.precision
108-
callbacks = callbacks if callbacks is not None else []
109-
self._callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
109+
self._callbacks = self._configure_callbacks(callbacks)
110110
loggers = loggers if loggers is not None else []
111111
self._loggers = loggers if isinstance(loggers, list) else [loggers]
112112
self._models_setup: int = 0
@@ -846,6 +846,13 @@ def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None:
846846
if any(not isinstance(dl, DataLoader) for dl in dataloaders):
847847
raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
848848

849+
@staticmethod
850+
def _configure_callbacks(callbacks: Optional[Union[List[Any], Any]]) -> List[Any]:
851+
callbacks = callbacks if callbacks is not None else []
852+
callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
853+
callbacks.extend(_load_external_callbacks("lightning.fabric.callbacks_factory"))
854+
return callbacks
855+
849856

850857
def _is_using_cli() -> bool:
851858
return bool(int(os.environ.get("LT_CLI_USED", "0")))

src/lightning/fabric/utilities/imports.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,6 @@
3030
_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0", use_base_version=True)
3131
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
3232
_TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1
33+
34+
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
35+
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)

src/lightning/fabric/utilities/registry.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import inspect
15-
from typing import Any
15+
import logging
16+
from typing import Any, List, Union
17+
18+
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
19+
20+
_log = logging.getLogger(__name__)
1621

1722

1823
def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool:
@@ -25,3 +30,40 @@ def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> boo
2530
return False
2631

2732
return mod_attr.__code__ is not super_attr.__code__
33+
34+
35+
def _load_external_callbacks(group: str) -> List[Any]:
36+
"""Collect external callbacks registered through entry points.
37+
38+
The entry points are expected to be functions returning a list of callbacks.
39+
40+
Args:
41+
group: The entry point group name to load callbacks from.
42+
43+
Return:
44+
A list of all callbacks collected from external factories.
45+
"""
46+
if _PYTHON_GREATER_EQUAL_3_8_0:
47+
from importlib.metadata import entry_points
48+
49+
factories = (
50+
entry_points(group=group)
51+
if _PYTHON_GREATER_EQUAL_3_10_0
52+
else entry_points().get(group, {}) # type: ignore[arg-type]
53+
)
54+
else:
55+
from pkg_resources import iter_entry_points
56+
57+
factories = iter_entry_points(group) # type: ignore[assignment]
58+
59+
external_callbacks: List[Any] = []
60+
for factory in factories:
61+
callback_factory = factory.load()
62+
callbacks_list: Union[List[Any], Any] = callback_factory()
63+
callbacks_list = [callbacks_list] if not isinstance(callbacks_list, list) else callbacks_list
64+
_log.info(
65+
f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
66+
f" {', '.join(type(cb).__name__ for cb in callbacks_list)}"
67+
)
68+
external_callbacks.extend(callbacks_list)
69+
return external_callbacks

src/lightning/pytorch/trainer/connectors/callback_connector.py

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Dict, List, Optional, Sequence, Union
1919

2020
import lightning.pytorch as pl
21+
from lightning.fabric.utilities.registry import _load_external_callbacks
2122
from lightning.pytorch.callbacks import (
2223
Callback,
2324
Checkpoint,
@@ -33,7 +34,6 @@
3334
from lightning.pytorch.callbacks.timer import Timer
3435
from lightning.pytorch.trainer import call
3536
from lightning.pytorch.utilities.exceptions import MisconfigurationException
36-
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
3737
from lightning.pytorch.utilities.model_helpers import is_overridden
3838
from lightning.pytorch.utilities.rank_zero import rank_zero_info
3939

@@ -75,7 +75,7 @@ def on_trainer_init(
7575
# configure the ModelSummary callback
7676
self._configure_model_summary_callback(enable_model_summary)
7777

78-
self.trainer.callbacks.extend(_configure_external_callbacks())
78+
self.trainer.callbacks.extend(_load_external_callbacks("lightning.pytorch.callbacks_factory"))
7979
_validate_callbacks_list(self.trainer.callbacks)
8080

8181
# push all model checkpoint callbacks to the end
@@ -213,42 +213,6 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
213213
return tuner_callbacks + other_callbacks + checkpoint_callbacks
214214

215215

216-
def _configure_external_callbacks() -> List[Callback]:
217-
"""Collect external callbacks registered through entry points.
218-
219-
The entry points are expected to be functions returning a list of callbacks.
220-
221-
Return:
222-
A list of all callbacks collected from external factories.
223-
"""
224-
group = "lightning.pytorch.callbacks_factory"
225-
226-
if _PYTHON_GREATER_EQUAL_3_8_0:
227-
from importlib.metadata import entry_points
228-
229-
factories = (
230-
entry_points(group=group)
231-
if _PYTHON_GREATER_EQUAL_3_10_0
232-
else entry_points().get(group, {}) # type: ignore[arg-type]
233-
)
234-
else:
235-
from pkg_resources import iter_entry_points
236-
237-
factories = iter_entry_points(group) # type: ignore[assignment]
238-
239-
external_callbacks: List[Callback] = []
240-
for factory in factories:
241-
callback_factory = factory.load()
242-
callbacks_list: Union[List[Callback], Callback] = callback_factory()
243-
callbacks_list = [callbacks_list] if isinstance(callbacks_list, Callback) else callbacks_list
244-
_log.info(
245-
f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
246-
f" {', '.join(type(cb).__name__ for cb in callbacks_list)}"
247-
)
248-
external_callbacks.extend(callbacks_list)
249-
return external_callbacks
250-
251-
252216
def _validate_callbacks_list(callbacks: List[Callback]) -> None:
253217
stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)]
254218
seen_callbacks = set()

src/lightning/pytorch/trainer/connectors/signal_connector.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111

1212
import lightning.pytorch as pl
1313
from lightning.fabric.plugins.environments import SLURMEnvironment
14-
from lightning.fabric.utilities.imports import _IS_WINDOWS
15-
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
14+
from lightning.fabric.utilities.imports import _IS_WINDOWS, _PYTHON_GREATER_EQUAL_3_8_0
1615
from lightning.pytorch.utilities.rank_zero import rank_zero_info
1716

1817
# copied from signal.pyi

src/lightning/pytorch/utilities/imports.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
import torch
1818
from lightning_utilities.core.imports import package_available, RequirementCache
1919

20-
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
21-
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
2220
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
2321
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
2422
_TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import contextlib
2+
from unittest import mock
3+
from unittest.mock import Mock
4+
5+
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
6+
from lightning.fabric.utilities.registry import _load_external_callbacks
7+
8+
9+
class ExternalCallback:
10+
"""A callback in another library that gets registered through entry points."""
11+
12+
pass
13+
14+
15+
def test_load_external_callbacks():
16+
"""Test that the connector collects Callback instances from factories registered through entry points."""
17+
18+
def factory_no_callback():
19+
return []
20+
21+
def factory_one_callback():
22+
return ExternalCallback()
23+
24+
def factory_one_callback_list():
25+
return [ExternalCallback()]
26+
27+
def factory_multiple_callbacks_list():
28+
return [ExternalCallback(), ExternalCallback()]
29+
30+
with _make_entry_point_query_mock(factory_no_callback):
31+
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
32+
assert callbacks == []
33+
34+
with _make_entry_point_query_mock(factory_one_callback):
35+
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
36+
assert isinstance(callbacks[0], ExternalCallback)
37+
38+
with _make_entry_point_query_mock(factory_one_callback_list):
39+
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
40+
assert isinstance(callbacks[0], ExternalCallback)
41+
42+
with _make_entry_point_query_mock(factory_multiple_callbacks_list):
43+
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
44+
assert isinstance(callbacks[0], ExternalCallback)
45+
assert isinstance(callbacks[1], ExternalCallback)
46+
47+
48+
@contextlib.contextmanager
49+
def _make_entry_point_query_mock(callback_factory):
50+
query_mock = Mock()
51+
entry_point = Mock()
52+
entry_point.name = "mocked"
53+
entry_point.load.return_value = callback_factory
54+
if _PYTHON_GREATER_EQUAL_3_10_0:
55+
query_mock.return_value = [entry_point]
56+
import_path = "importlib.metadata.entry_points"
57+
elif _PYTHON_GREATER_EQUAL_3_8_0:
58+
query_mock().get.return_value = [entry_point]
59+
import_path = "importlib.metadata.entry_points"
60+
else:
61+
query_mock.return_value = [entry_point]
62+
import_path = "pkg_resources.iter_entry_points"
63+
with mock.patch(import_path, query_mock):
64+
yield

tests/tests_pytorch/trainer/connectors/test_callback_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pytest
2020
import torch
2121

22+
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
2223
from lightning.pytorch import Callback, LightningModule, Trainer
2324
from lightning.pytorch.callbacks import (
2425
EarlyStopping,
@@ -32,7 +33,6 @@
3233
from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
3334
from lightning.pytorch.demos.boring_classes import BoringModel
3435
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector
35-
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
3636

3737

3838
def test_checkpoint_callbacks_are_last(tmpdir):

tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@
2525
import torch
2626
from torch import Tensor
2727

28+
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
2829
from lightning.pytorch import callbacks, Trainer
2930
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
3031
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
3132
from lightning.pytorch.loggers import TensorBoardLogger
3233
from lightning.pytorch.loops import _EvaluationLoop
3334
from lightning.pytorch.trainer.states import RunningStage
3435
from lightning.pytorch.utilities.exceptions import MisconfigurationException
35-
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
3636
from tests_pytorch.helpers.runif import RunIf
3737

3838
if _RICH_AVAILABLE:

0 commit comments

Comments
 (0)