|
18 | 18 | from typing import Dict, List, Optional, Sequence, Union |
19 | 19 |
|
20 | 20 | import lightning.pytorch as pl |
| 21 | +from lightning.fabric.utilities.registry import _load_external_callbacks |
21 | 22 | from lightning.pytorch.callbacks import ( |
22 | 23 | Callback, |
23 | 24 | Checkpoint, |
|
33 | 34 | from lightning.pytorch.callbacks.timer import Timer |
34 | 35 | from lightning.pytorch.trainer import call |
35 | 36 | from lightning.pytorch.utilities.exceptions import MisconfigurationException |
36 | | -from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0 |
37 | 37 | from lightning.pytorch.utilities.model_helpers import is_overridden |
38 | 38 | from lightning.pytorch.utilities.rank_zero import rank_zero_info |
39 | 39 |
|
@@ -75,7 +75,7 @@ def on_trainer_init( |
75 | 75 | # configure the ModelSummary callback |
76 | 76 | self._configure_model_summary_callback(enable_model_summary) |
77 | 77 |
|
78 | | - self.trainer.callbacks.extend(_configure_external_callbacks()) |
| 78 | + self.trainer.callbacks.extend(_load_external_callbacks("lightning.pytorch.callbacks_factory")) |
79 | 79 | _validate_callbacks_list(self.trainer.callbacks) |
80 | 80 |
|
81 | 81 | # push all model checkpoint callbacks to the end |
@@ -213,42 +213,6 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: |
213 | 213 | return tuner_callbacks + other_callbacks + checkpoint_callbacks |
214 | 214 |
|
215 | 215 |
|
216 | | -def _configure_external_callbacks() -> List[Callback]: |
217 | | - """Collect external callbacks registered through entry points. |
218 | | -
|
219 | | - The entry points are expected to be functions returning a list of callbacks. |
220 | | -
|
221 | | - Return: |
222 | | - A list of all callbacks collected from external factories. |
223 | | - """ |
224 | | - group = "lightning.pytorch.callbacks_factory" |
225 | | - |
226 | | - if _PYTHON_GREATER_EQUAL_3_8_0: |
227 | | - from importlib.metadata import entry_points |
228 | | - |
229 | | - factories = ( |
230 | | - entry_points(group=group) |
231 | | - if _PYTHON_GREATER_EQUAL_3_10_0 |
232 | | - else entry_points().get(group, {}) # type: ignore[arg-type] |
233 | | - ) |
234 | | - else: |
235 | | - from pkg_resources import iter_entry_points |
236 | | - |
237 | | - factories = iter_entry_points(group) # type: ignore[assignment] |
238 | | - |
239 | | - external_callbacks: List[Callback] = [] |
240 | | - for factory in factories: |
241 | | - callback_factory = factory.load() |
242 | | - callbacks_list: Union[List[Callback], Callback] = callback_factory() |
243 | | - callbacks_list = [callbacks_list] if isinstance(callbacks_list, Callback) else callbacks_list |
244 | | - _log.info( |
245 | | - f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':" |
246 | | - f" {', '.join(type(cb).__name__ for cb in callbacks_list)}" |
247 | | - ) |
248 | | - external_callbacks.extend(callbacks_list) |
249 | | - return external_callbacks |
250 | | - |
251 | | - |
252 | 216 | def _validate_callbacks_list(callbacks: List[Callback]) -> None: |
253 | 217 | stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)] |
254 | 218 | seen_callbacks = set() |
|
0 commit comments