Commit fb7dbc8

Merge branch 'bugfix/20350_grad_acc_fix' of https://github.com/Sohaib-Ahmed21/pytorch-lightning into bugfix/20350_grad_acc_fix
2 parents: 95d467d + 7f5f88c

73 files changed: +4344 additions, -1109 deletions


.github/workflows/docs-build.yml

Lines changed: 1 addition & 4 deletions
@@ -125,10 +125,7 @@ jobs:
       working-directory: ./docs/source-${{ matrix.pkg-name }}
       # allow failing link check and doctest if you run with dispatch
       continue-on-error: ${{ (matrix.target == 'doctest' || matrix.target == 'linkcheck') && github.event_name == 'workflow_dispatch' }}
-      run: |
-        # temp fix: https://github.com/Lightning-AI/pytorch-lightning/actions/runs/19440502586/job/55622388642?pr=21354#step:11:4596
-        uv pip install -U fastapi
-        make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="$BUILD_SPHINX_OPTS"
+      run: make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="$BUILD_SPHINX_OPTS"

     - name: Keep artifact
       if: github.event_name == 'pull_request'

docs/source-fabric/guide/callbacks.rst

Lines changed: 24 additions & 0 deletions
@@ -83,6 +83,30 @@ The :meth:`~lightning.fabric.fabric.Fabric.call` calls the callback objects in t
 Not all objects registered via ``Fabric(callbacks=...)`` must implement a method with the given name.
 The ones that have a matching method name will get called.

+The different callbacks can have different method signatures. Fabric automatically filters keyword arguments based on
+each callback's function signature, allowing callbacks with different signatures to work together seamlessly.
+
+.. code-block:: python
+
+    class TrainingMetricsCallback:
+        def on_train_epoch_end(self, train_loss):
+            print(f"Training loss: {train_loss:.4f}")
+
+    class ValidationMetricsCallback:
+        def on_train_epoch_end(self, val_accuracy):
+            print(f"Validation accuracy: {val_accuracy:.4f}")
+
+    class ComprehensiveCallback:
+        def on_train_epoch_end(self, epoch, **kwargs):
+            print(f"Epoch {epoch} complete with metrics: {kwargs}")
+
+    fabric = Fabric(
+        callbacks=[TrainingMetricsCallback(), ValidationMetricsCallback(), ComprehensiveCallback()]
+    )
+
+    # Each callback receives only the arguments it can handle
+    fabric.call("on_train_epoch_end", epoch=5, train_loss=0.1, val_accuracy=0.95, learning_rate=0.001)
+

 ----

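The mechanism documented above can be reproduced outside Fabric with plain `inspect`. The sketch below is illustrative only: the callback and helper names are made up, and the real logic lives in `Fabric._filter_kwargs_for_callback`, added further down in this commit.

import inspect


class LossCallback:  # hypothetical callback that only declares `train_loss`
    def on_train_epoch_end(self, train_loss):
        print(f"loss={train_loss:.2f}")


def call_with_filtered_kwargs(method, **kwargs):
    params = inspect.signature(method).parameters
    # A **kwargs parameter means the method accepts everything unchanged.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return method(**kwargs)
    # Otherwise keep only the keyword arguments the signature actually names.
    return method(**{k: v for k, v in kwargs.items() if k in params})


call_with_filtered_kwargs(LossCallback().on_train_epoch_end, train_loss=0.1, val_accuracy=0.95)
# prints "loss=0.10"; val_accuracy is dropped because the signature does not name it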

docs/source-pytorch/conf.py

Lines changed: 4 additions & 1 deletion
@@ -604,7 +604,10 @@ def package_list_from_file(file):
 from lightning.pytorch.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE
 from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
 from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
-from lightning.fabric.utilities.imports import _COMET_AVAILABLE, _MLFLOW_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE
+from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
+from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
+from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE
+from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
 """
 coverage_skip_undoc_in_source = True

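These `_*_AVAILABLE` names are `RequirementCache`-style flags that evaluate truthy only when the optional package can be imported. A minimal sketch of how such a flag is typically consumed; only the import path comes from the diff, the guarded logger example is hypothetical.

from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE

if _WANDB_AVAILABLE:
    # Only construct the logger when the optional `wandb` package is installed.
    from lightning.pytorch.loggers import WandbLogger

    logger = WandbLogger(project="demo", offline=True)
else:
    logger = None  # fall back gracefully when the extra dependency is missing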

requirements/fabric/base.txt

Lines changed: 0 additions & 1 deletion
@@ -6,4 +6,3 @@ fsspec[http] >=2022.5.0, <2025.11.0
 packaging >=20.0, <=25.0
 typing-extensions >4.5.0, <4.16.0
 lightning-utilities >=0.10.0, <0.16.0
-pytorch-lightning-enterprise >=2.6.0

requirements/pytorch/base.txt

Lines changed: 0 additions & 1 deletion
@@ -9,4 +9,3 @@ torchmetrics >0.7.0, <1.9.0
 packaging >=20.0, <=25.0
 typing-extensions >4.5.0, <4.16.0
 lightning-utilities >=0.10.0, <0.16.0
-pytorch-lightning-enterprise >=2.6.0

requirements/pytorch/extra.txt

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 matplotlib>3.1, <3.11.0
 omegaconf >=2.2.3, <2.4.0
 hydra-core >=1.2.0, <1.4.0
-jsonargparse[signatures,jsonnet] >=4.39.0, <4.43.0
+jsonargparse[signatures,jsonnet] >=4.39.0, <4.44.0
 rich >=12.3.0, <14.3.0
 tensorboardX >=2.2, <2.7.0  # min version is set by torch.onnx missing attribute
 bitsandbytes >=0.45.2,<0.47.0; platform_system != "Darwin"

src/lightning/fabric/CHANGELOG.md

Lines changed: 15 additions & 0 deletions
@@ -4,6 +4,21 @@ All notable changes to this project will be documented in this file.

 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

+
+## [unreleased] - YYYY-MM-DD
+
+### Added
+
+- Added kwargs-filtering for `Fabric.call` to support different callback method signatures ([#21258](https://github.com/Lightning-AI/pytorch-lightning/pull/21258))
+
+
+### Removed
+
+-
+
+
+---
+
 ## [2.6.0] - 2025-11-21

 ### Changed

src/lightning/fabric/accelerators/xla.py

Lines changed: 96 additions & 27 deletions
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
-import warnings
 from typing import Any, Union

 import torch
@@ -21,11 +20,7 @@

 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
-
-_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
-_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
-_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
+from lightning.fabric.utilities.device_parser import _check_data_type


 class XLAAccelerator(Accelerator):
@@ -36,38 +31,38 @@ class XLAAccelerator(Accelerator):
     """

     def __init__(self, *args: Any, **kwargs: Any) -> None:
-        _raise_enterprise_not_available()
+        if not _XLA_AVAILABLE:
+            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
+        if not _using_pjrt():
+            raise RuntimeError("The XLA XRT runtime is not supported anymore.")
         super().__init__(*args, **kwargs)

-        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
-
-        self.accelerator_impl = EnterpriseXLAAccelerator(*args, **kwargs)
-
     @override
     def setup_device(self, device: torch.device) -> None:
-        return self.accelerator_impl.setup_device(device)
+        pass

     @override
     def teardown(self) -> None:
-        return self.accelerator_impl.teardown()
+        pass

     @staticmethod
     @override
     def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
         """Accelerator device parsing logic."""
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
-
-        return EnterpriseXLAAccelerator.parse_devices(devices)
+        return _parse_tpu_devices(devices)

     @staticmethod
     @override
     def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
         """Gets parallel devices for the Accelerator."""
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
-
-        return EnterpriseXLAAccelerator.get_parallel_devices(devices)
+        devices = _parse_tpu_devices(devices)
+        if isinstance(devices, int):
+            return [torch.device("xla", i) for i in range(devices)]
+        # list of devices is not supported, just a specific index, fine to access [0]
+        return [torch.device("xla", devices[0])]
+        # we cannot create `xla_device` here because processes have not been spawned yet (this is called in the
+        # accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
+        # it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy

     @staticmethod
     @override
@@ -76,10 +71,16 @@ def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
     @functools.lru_cache(maxsize=1)
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+        if not _XLA_AVAILABLE:
+            return 0
+        if _XLA_GREATER_EQUAL_2_1:
+            from torch_xla._internal import tpu
+
+            return tpu.num_available_devices()
+        from torch_xla.experimental import tpu

-        return EnterpriseXLAAccelerator.auto_device_count()
+        device_count_on_version = {2: 8, 3: 8, 4: 4}
+        return device_count_on_version.get(tpu.version(), 8)

     @staticmethod
     @override
@override
@@ -91,9 +92,6 @@ def is_available() -> bool:
9192
# XLA may raise these exceptions if it's not properly configured. This needs to be avoided for the cases
9293
# when `torch_xla` is imported but not used
9394
return False
94-
except ModuleNotFoundError as e:
95-
warnings.warn(str(e))
96-
return False
9795

9896
@staticmethod
9997
@override
@@ -108,3 +106,74 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
108106
cls,
109107
description=cls.__name__,
110108
)
109+
110+
111+
# PJRT support requires this minimum version
112+
_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
113+
_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
114+
_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
115+
116+
117+
def _using_pjrt() -> bool:
118+
# `using_pjrt` is removed in torch_xla 2.5
119+
if _XLA_GREATER_EQUAL_2_5:
120+
from torch_xla import runtime as xr
121+
122+
return xr.device_type() is not None
123+
# delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
124+
if _XLA_GREATER_EQUAL_2_1:
125+
from torch_xla import runtime as xr
126+
127+
return xr.using_pjrt()
128+
129+
from torch_xla.experimental import pjrt
130+
131+
return pjrt.using_pjrt()
132+
133+
134+
def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
135+
"""Parses the TPU devices given in the format as accepted by the
136+
:class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`.
137+
138+
Args:
139+
devices: An int of 1 or string '1' indicates that 1 core with multi-processing should be used
140+
An int 8 or string '8' indicates that all 8 cores with multi-processing should be used
141+
A single element list of int or string can be used to indicate the specific TPU core to use.
142+
143+
Returns:
144+
A list of tpu cores to be used.
145+
146+
"""
147+
_check_data_type(devices)
148+
if isinstance(devices, str):
149+
devices = _parse_tpu_devices_str(devices)
150+
_check_tpu_devices_valid(devices)
151+
return devices
152+
153+
154+
def _check_tpu_devices_valid(devices: object) -> None:
155+
device_count = XLAAccelerator.auto_device_count()
156+
if (
157+
# support number of devices
158+
isinstance(devices, int)
159+
and devices in {1, device_count}
160+
# support picking a specific device
161+
or isinstance(devices, (list, tuple))
162+
and len(devices) == 1
163+
and 0 <= devices[0] <= device_count - 1
164+
):
165+
return
166+
raise ValueError(
167+
f"`devices` can only be 'auto', 1, {device_count} or [<0-{device_count - 1}>] for TPUs. Got {devices!r}"
168+
)
169+
170+
171+
def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]:
172+
devices = devices.strip()
173+
try:
174+
return int(devices)
175+
except ValueError:
176+
try:
177+
return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
178+
except ValueError:
179+
raise ValueError(f"Could not parse the selected TPU devices: {devices!r}")

src/lightning/fabric/fabric.py

Lines changed: 33 additions & 8 deletions
@@ -985,6 +985,34 @@ def train_function(fabric):
         )
         return self._wrap_and_launch(function, self, *args, **kwargs)

+    def _filter_kwargs_for_callback(self, method: Callable, kwargs: dict[str, Any]) -> dict[str, Any]:
+        """Filter keyword arguments to only include those that match the callback method's signature.
+
+        Args:
+            method: The callback method to inspect
+            kwargs: The keyword arguments to filter
+
+        Returns:
+            A filtered dictionary of keyword arguments that match the method's signature
+
+        """
+        try:
+            sig = inspect.signature(method)
+        except (ValueError, TypeError):
+            # If we can't inspect the signature, pass all kwargs to maintain backward compatibility
+            return kwargs
+
+        filtered_kwargs = {}
+        for name, param in sig.parameters.items():
+            # If the method accepts **kwargs, pass all original kwargs directly
+            if param.kind == inspect.Parameter.VAR_KEYWORD:
+                return kwargs
+            # If the parameter exists in the incoming kwargs, add it to filtered_kwargs
+            if name in kwargs:
+                filtered_kwargs[name] = kwargs[name]
+
+        return filtered_kwargs
+
     def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
         r"""Trigger the callback methods with the given name and arguments.

@@ -994,7 +1022,9 @@ def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
         Args:
             hook_name: The name of the callback method.
             *args: Optional positional arguments that get passed down to the callback method.
-            **kwargs: Optional keyword arguments that get passed down to the callback method.
+            **kwargs: Optional keyword arguments that get passed down to the callback method. Keyword arguments
+                that are not present in the callback's signature will be filtered out automatically, allowing
+                callbacks to have different signatures for the same hook.

         Example::

@@ -1016,13 +1046,8 @@ def on_train_epoch_end(self, results):
                 )
                 continue

-            method(*args, **kwargs)
-
-            # TODO(fabric): handle the following signatures
-            # method(self, fabric|trainer, x, y=1)
-            # method(self, fabric|trainer, *args, x, y=1)
-            # method(self, *args, y=1)
-            # method(self, *args, **kwargs)
+            filtered_kwargs = self._filter_kwargs_for_callback(method, kwargs)
+            method(*args, **filtered_kwargs)

     def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
         """Log a scalar to all loggers that were added to Fabric.
