
Commit 44ee516

Merge branch 'master' into fix/rich_progressbar
2 parents 21bae14 + 79ffe50

26 files changed (+632, -118 lines)


.github/CODEOWNERS

Lines changed: 4 additions & 4 deletions
@@ -5,19 +5,19 @@
 # the repo. Unless a later match takes precedence,
 # @global-owner1 and @global-owner2 will be requested for
 # review when someone opens a pull request.
-* @lantiga @borda @tchaton @justusschock @ethanwharris
+* @lantiga @tchaton @justusschock @ethanwharris

 # Docs
-/.github/*.md @williamfalcon @lantiga @borda
+/.github/*.md @williamfalcon @lantiga
 /docs/source-fabric/index.rst @williamfalcon @lantiga
 /docs/source-pytorch/index.rst @williamfalcon @lantiga
 /docs/source-pytorch/levels @williamfalcon @lantiga

 /.github/CODEOWNERS @williamfalcon
 /SECURITY.md @williamfalcon @lantiga
 /README.md @williamfalcon @lantiga
-/src/pytorch_lightning/__about__.py @williamfalcon @lantiga @borda
-/src/lightning_fabric/__about__.py @williamfalcon @lantiga @borda
+/src/pytorch_lightning/__about__.py @williamfalcon @lantiga
+/src/lightning_fabric/__about__.py @williamfalcon @lantiga

 /src/lightning/fabric/loggers @williamfalcon
 /src/lightning/pytorch/loggers @williamfalcon

.github/CONTRIBUTING.md

Lines changed: 2 additions & 0 deletions
@@ -212,6 +212,8 @@ We welcome any useful contribution! For your convenience here's a recommended wo
 - [Test README](https://github.com/Lightning-AI/pytorch-lightning/blob/master/tests/README.md)
 - [CI/CD README](https://github.com/Lightning-AI/pytorch-lightning/tree/master/.github/workflows#readme)

+1. Once you have a PR opened (and thereby a PR number), please update the respective changelog for [fabric](https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/CHANGELOG.md) or [pytorch](https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/CHANGELOG.md) subpackage depending on where you made your changes.
+
 1. When you feel ready for integrating your work, mark your PR "Ready for review".

    - Your code should be readable and follow the project's design principles.

Makefile

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ export PACKAGE_NAME=pytorch

 # In Lightning Studio, the `lightning` package comes pre-installed.
 # Uninstall it first to ensure the editable install works correctly.
-setup:
+setup: update
 	uv pip uninstall lightning pytorch-lightning lightning-fabric || true
 	uv pip install -r requirements.txt \
 		-r requirements/pytorch/base.txt \

docs/source-fabric/guide/callbacks.rst

Lines changed: 24 additions & 0 deletions
@@ -83,6 +83,30 @@ The :meth:`~lightning.fabric.fabric.Fabric.call` calls the callback objects in t
 Not all objects registered via ``Fabric(callbacks=...)`` must implement a method with the given name.
 The ones that have a matching method name will get called.

+The different callbacks can have different method signatures. Fabric automatically filters keyword arguments based on
+each callback's function signature, allowing callbacks with different signatures to work together seamlessly.
+
+.. code-block:: python
+
+    class TrainingMetricsCallback:
+        def on_train_epoch_end(self, train_loss):
+            print(f"Training loss: {train_loss:.4f}")
+
+    class ValidationMetricsCallback:
+        def on_train_epoch_end(self, val_accuracy):
+            print(f"Validation accuracy: {val_accuracy:.4f}")
+
+    class ComprehensiveCallback:
+        def on_train_epoch_end(self, epoch, **kwargs):
+            print(f"Epoch {epoch} complete with metrics: {kwargs}")
+
+    fabric = Fabric(
+        callbacks=[TrainingMetricsCallback(), ValidationMetricsCallback(), ComprehensiveCallback()]
+    )
+
+    # Each callback receives only the arguments it can handle
+    fabric.call("on_train_epoch_end", epoch=5, train_loss=0.1, val_accuracy=0.95, learning_rate=0.001)
+

 ----

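A minimal runnable sketch of the behavior documented above, reusing one of the callback classes from the docs example. The `accelerator="cpu", devices=1` arguments are an assumption added only to keep the snippet self-contained; they are not part of the original docs change.

from lightning.fabric import Fabric

class TrainingMetricsCallback:
    def on_train_epoch_end(self, train_loss):
        print(f"Training loss: {train_loss:.4f}")

fabric = Fabric(accelerator="cpu", devices=1, callbacks=[TrainingMetricsCallback()])

# Only `train_loss` matches the callback's signature; the other keyword
# arguments are filtered out instead of raising a TypeError.
fabric.call("on_train_epoch_end", train_loss=0.1, val_accuracy=0.95, learning_rate=0.001)
# prints: Training loss: 0.1000
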
docs/source-pytorch/community/governance.rst

Lines changed: 2 additions & 2 deletions
@@ -19,15 +19,15 @@ Role: All final decisions related to Lightning.
 Maintainers
 -----------
 - Luca Antiga (`lantiga <https://github.com/lantiga>`_)
-- Jirka Borovec (`Borda <https://github.com/Borda>`_)
+- Ethan Harris (`ethanwharris <https://github.com/ethanwharris>`_) (Torchbearer founder)
 - Justus Schock (`justusschock <https://github.com/justusschock>`_)


 Emeritus Maintainers
 --------------------
-- Ethan Harris (`ethanwharris <https://github.com/ethanwharris>`_) (Torchbearer founder)
 - Nicki Skafte (`SkafteNicki <https://github.com/SkafteNicki>`_)
 - Thomas Chaton (`tchaton <https://github.com/tchaton>`_)
+- Jirka Borovec (`Borda <https://github.com/Borda>`_)


 Alumni

requirements/pytorch/extra.txt

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 matplotlib>3.1, <3.11.0
 omegaconf >=2.2.3, <2.4.0
 hydra-core >=1.2.0, <1.4.0
-jsonargparse[signatures,jsonnet] >=4.39.0, <4.43.0
+jsonargparse[signatures,jsonnet] >=4.39.0, <4.44.0
 rich >=12.3.0, <14.3.0
 tensorboardX >=2.2, <2.7.0  # min version is set by torch.onnx missing attribute
 bitsandbytes >=0.45.2,<0.47.0; platform_system != "Darwin"

src/lightning/fabric/CHANGELOG.md

Lines changed: 17 additions & 2 deletions
@@ -5,18 +5,28 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


-## [unreleased] - YYYY-MM-DD
+## [Unreleased] - YYYY-MM-DD

 ### Added

 -

+### Changed
+
+-

 ### Removed

 -


+## [2.6.0] - 2025-11-28
+
+### Added
+
+- Added kwargs-filtering for `Fabric.call` to support different callback method signatures ([#21258](https://github.com/Lightning-AI/pytorch-lightning/pull/21258))
+
+
 ### Changed

 - Expose `weights_only` argument for `Trainer.{fit,validate,test,predict}` and let `torch` handle default value ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072))

@@ -25,7 +35,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

--
+- Fixed issue in detecting MPIEnvironment with partial mpi4py installation ([#21353](https://github.com/Lightning-AI/pytorch-lightning/pull/21353))
+
+- Learning rate scheduler is stepped at the end of epoch when `on_train_batch_start` returns -1 ([#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)).
+
+
+- Fixed FSDP mixed precision semantics and added user warning ([#21361](https://github.com/Lightning-AI/pytorch-lightning/pull/21361))


 ---

src/lightning/fabric/fabric.py

Lines changed: 33 additions & 8 deletions
@@ -985,6 +985,34 @@ def train_function(fabric):
             )
         return self._wrap_and_launch(function, self, *args, **kwargs)

+    def _filter_kwargs_for_callback(self, method: Callable, kwargs: dict[str, Any]) -> dict[str, Any]:
+        """Filter keyword arguments to only include those that match the callback method's signature.
+
+        Args:
+            method: The callback method to inspect
+            kwargs: The keyword arguments to filter
+
+        Returns:
+            A filtered dictionary of keyword arguments that match the method's signature
+
+        """
+        try:
+            sig = inspect.signature(method)
+        except (ValueError, TypeError):
+            # If we can't inspect the signature, pass all kwargs to maintain backward compatibility
+            return kwargs
+
+        filtered_kwargs = {}
+        for name, param in sig.parameters.items():
+            # If the method accepts **kwargs, pass all original kwargs directly
+            if param.kind == inspect.Parameter.VAR_KEYWORD:
+                return kwargs
+            # If the parameter exists in the incoming kwargs, add it to filtered_kwargs
+            if name in kwargs:
+                filtered_kwargs[name] = kwargs[name]
+
+        return filtered_kwargs
+
     def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
         r"""Trigger the callback methods with the given name and arguments.

@@ -994,7 +1022,9 @@ def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
         Args:
             hook_name: The name of the callback method.
             *args: Optional positional arguments that get passed down to the callback method.
-            **kwargs: Optional keyword arguments that get passed down to the callback method.
+            **kwargs: Optional keyword arguments that get passed down to the callback method. Keyword arguments
+                that are not present in the callback's signature will be filtered out automatically, allowing
+                callbacks to have different signatures for the same hook.

         Example::

@@ -1016,13 +1046,8 @@ def on_train_epoch_end(self, results):
                 )
                 continue

-            method(*args, **kwargs)
-
-            # TODO(fabric): handle the following signatures
-            # method(self, fabric|trainer, x, y=1)
-            # method(self, fabric|trainer, *args, x, y=1)
-            # method(self, *args, y=1)
-            # method(self, *args, **kwargs)
+            filtered_kwargs = self._filter_kwargs_for_callback(method, kwargs)
+            method(*args, **filtered_kwargs)

     def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
         """Log a scalar to all loggers that were added to Fabric.

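To see how the new helper handles the signature shapes listed in the removed TODO comment, here is a standalone sketch of the same filtering logic. It is a copy adapted from the diff above with hypothetical hook functions for illustration; it is not an additional Fabric API.

import inspect
from typing import Any, Callable

def filter_kwargs(method: Callable, kwargs: dict[str, Any]) -> dict[str, Any]:
    # Standalone copy of the filtering logic added in this commit.
    try:
        sig = inspect.signature(method)
    except (ValueError, TypeError):
        return kwargs
    filtered = {}
    for name, param in sig.parameters.items():
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            return kwargs  # a **kwargs parameter accepts everything
        if name in kwargs:
            filtered[name] = kwargs[name]
    return filtered

# Hypothetical hooks matching the signature shapes from the removed TODO:
def hook_a(fabric, x, y=1): ...
def hook_b(*args, y=1): ...
def hook_c(*args, **kwargs): ...

print(filter_kwargs(hook_a, {"x": 1, "y": 2, "z": 3}))  # {'x': 1, 'y': 2}
print(filter_kwargs(hook_b, {"y": 2, "z": 3}))          # {'y': 2}
print(filter_kwargs(hook_c, {"y": 2, "z": 3}))          # {'y': 2, 'z': 3}
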
src/lightning/fabric/plugins/environments/mpi.py

Lines changed: 5 additions & 1 deletion
@@ -73,7 +73,11 @@ def detect() -> bool:
         if not _MPI4PY_AVAILABLE:
             return False

-        from mpi4py import MPI
+        try:
+            # mpi4py may be installed without MPI being present
+            from mpi4py import MPI
+        except ImportError:
+            return False

         return MPI.COMM_WORLD.Get_size() > 1

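A small usage sketch of the hardened detection, assuming the public `MPIEnvironment` import path; whether it returns True depends on how the script is launched.

from lightning.fabric.plugins.environments import MPIEnvironment

# After this change, detect() returns False instead of raising ImportError when
# mpi4py is installed but no MPI runtime/library is present on the machine.
if MPIEnvironment.detect():
    print("Launched under MPI with more than one process.")
else:
    print("No usable MPI environment detected; Fabric can fall back to another cluster environment.")
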
src/lightning/fabric/plugins/precision/fsdp.py

Lines changed: 10 additions & 10 deletions
@@ -24,6 +24,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.types import Optimizable

 if TYPE_CHECKING:

@@ -84,19 +85,18 @@ def convert_module(self, module: Module) -> Module:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

-        if self.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.precision == "16-true":
+        if self.precision in ("16-true", "bf16-true"):
+            rank_zero_warn(
+                f"FSDP with `{self.precision}` enables computation in lower precision. "
+                "FSDP will always retain a full-precision copy of the model parameters for sharding."
+            )
+
+        if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-true":
+        elif self.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
            raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")


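For reference, a sketch of the configs this property now produces, assuming the standard FSDP `MixedPrecision` dataclass and a PyTorch build with distributed support. With this change, the `*-mixed` and `*-true` settings of the same width resolve to the same dtypes, and the new warning reminds users that FSDP still keeps a full-precision copy of the parameters for sharding.

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision

# "16-mixed" and "16-true" now both resolve to:
fp16_config = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)

# "bf16-mixed" and "bf16-true" now both resolve to:
bf16_config = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)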