Commit 91aaa53

lightningforever, awaelchli, and carmocca authored
Lite: Support self.log from a LightningModule (#16311)
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent a56c12c commit 91aaa53
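
In practice, this commit lets LightningModule code call `self.log` and `self.log_dict` unchanged while the training loop is driven by Fabric. A minimal sketch, assuming the `lightning_fabric`/`pytorch_lightning` import paths of this commit and using a `Mock` logger the same way the commit's own tests do (`TinyModel` is illustrative, not part of the commit):

```python
from unittest.mock import Mock

import torch
from lightning_fabric import Fabric  # import path as of this commit
from pytorch_lightning import LightningModule


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.log("train_loss", loss)  # forwarded to Fabric's loggers after this commit
        return loss


logger = Mock()  # stands in for any logger passed to Fabric
fabric = Fabric(loggers=[logger])
model = TinyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

loss = model.training_step(torch.randn(2, 4), batch_idx=0)
fabric.backward(loss)
optimizer.step()
print(logger.log_metrics.call_args)  # metrics={'train_loss': ...}, step=None
```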

File tree

10 files changed: 194 additions, 44 deletions


src/lightning_fabric/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
@@ -35,7 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Added `Fabric.log` for logging scalars using multiple loggers
 * Added `Fabric.log_dict` for logging a dictionary of multiple metrics at once
 * Added `Fabric.loggers` and `Fabric.logger` attributes to access the individual logger instances
-
+* Added support for calling `self.log` and `self.log_dict` in a LightningModule when using Fabric
+* Added access to `self.logger` and `self.loggers` in a LightningModule when using Fabric
 
 - Added support for a consistent `.zero_grad(set_to_none=...)` on the wrapped optimizer regardless of which strategy is used ([#16275](https://github.com/Lightning-AI/lightning/issues/16275))
 

src/lightning_fabric/fabric.py

Lines changed: 5 additions & 2 deletions
@@ -42,7 +42,7 @@
 )
 from lightning_fabric.strategies.strategy import _Sharded, TBroadcast
 from lightning_fabric.utilities import move_data_to_device
-from lightning_fabric.utilities.apply_func import convert_to_tensors
+from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars, convert_to_tensors
 from lightning_fabric.utilities.data import (
     _auto_add_worker_init_fn,
     _replace_dunder_methods,
@@ -597,7 +597,8 @@ def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
 
         Args:
             name: The name of the metric to log.
-            value: The metric value to collect.
+            value: The metric value to collect. If the value is a :class:`torch.Tensor`, it gets detached from the
+                graph automatically.
             step: Optional step number. Most Logger implementations auto-increment the step value by one with every
                 log call. You can specify your own value here.
         """
@@ -608,9 +609,11 @@ def log_dict(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
 
         Args:
             metrics: A dictionary where the key is the name of the metric and the value the scalar to be logged.
+                Any :class:`torch.Tensor` in the dictionary gets detached from the graph automatically.
            step: Optional step number. Most Logger implementations auto-increment this value by one with every
                 log call. You can specify your own value here.
         """
+        metrics = convert_tensors_to_scalars(metrics)
         for logger in self._loggers:
             logger.log_metrics(metrics=metrics, step=step)
 
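
The effect of the `log_dict` change above, as a small sketch (a `Mock` logger again, matching this commit's tests; the metric values are illustrative): tensors are converted to plain Python scalars, and therefore detached, before they reach any logger.

```python
from unittest.mock import Mock

import torch
from lightning_fabric import Fabric

logger = Mock()
fabric = Fabric(loggers=[logger])

loss = torch.tensor(0.25, requires_grad=True) * 2  # part of an autograd graph
fabric.log_dict({"loss": loss, "lr": 1e-3}, step=10)

# the logger receives detached Python scalars, not tensors
logger.log_metrics.assert_called_with(metrics={"loss": 0.5, "lr": 1e-3}, step=10)
```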

src/lightning_fabric/utilities/apply_func.py

Lines changed: 18 additions & 0 deletions
@@ -111,3 +111,21 @@ def _move_to_device_and_make_contiguous(t: Tensor, device: _DEVICE) -> Tensor:
 
     # make sure existing tensors are in the correct device, also contiguous
     return apply_to_collection(data, Tensor, _move_to_device_and_make_contiguous, device=device)
+
+
+def convert_tensors_to_scalars(data: Any) -> Any:
+    """Recursively walk through a collection and convert single-item tensors to scalar values.
+
+    Raises:
+        ValueError:
+            If tensors inside ``data`` contain multiple elements, hence preventing conversion to a scalar.
+    """
+
+    def to_item(value: Tensor) -> Union[int, float, bool]:
+        if value.numel() != 1:
+            raise ValueError(
+                f"The metric `{value}` does not contain a single element, thus it cannot be converted to a scalar."
+            )
+        return value.item()
+
+    return apply_to_collection(data, Tensor, to_item)
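
A quick behavior sketch for the new helper (the nested structure is illustrative; exact-representable floats are used for clarity). It recurses through arbitrary collections via `apply_to_collection`:

```python
import torch

from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars

metrics = {"loss": torch.tensor(0.5), "inner": {"acc": torch.tensor([0.25])}, "epoch": 3}
print(convert_tensors_to_scalars(metrics))
# {'loss': 0.5, 'inner': {'acc': 0.25}, 'epoch': 3}

# multi-element tensors cannot be reduced to a scalar and raise ValueError
try:
    convert_tensors_to_scalars({"vec": torch.tensor([1.0, 2.0])})
except ValueError as err:
    print(err)
```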

src/pytorch_lightning/core/module.py

Lines changed: 37 additions & 16 deletions
@@ -403,6 +403,10 @@ def log(
             rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
                 would produce a deadlock as not all processes would perform this log call.
         """
+        if self._fabric is not None:
+            self._log_dict_through_fabric(dictionary={name: value}, logger=logger)
+            return
+
         # check for invalid values
         apply_to_collection(value, dict, self.__check_not_nested, name)
         apply_to_collection(
@@ -554,22 +558,39 @@ def log_dict(
             rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
                 would produce a deadlock as not all processes would perform this log call.
         """
-        for k, v in dictionary.items():
-            self.log(
-                name=k,
-                value=v,
-                prog_bar=prog_bar,
-                logger=logger,
-                on_step=on_step,
-                on_epoch=on_epoch,
-                reduce_fx=reduce_fx,
-                enable_graph=enable_graph,
-                sync_dist=sync_dist,
-                sync_dist_group=sync_dist_group,
-                add_dataloader_idx=add_dataloader_idx,
-                batch_size=batch_size,
-                rank_zero_only=rank_zero_only,
-            )
+        if self._fabric is not None:
+            self._log_dict_through_fabric(dictionary=dictionary, logger=logger)  # type: ignore[arg-type]
+        else:
+            for k, v in dictionary.items():
+                self.log(
+                    name=k,
+                    value=v,
+                    prog_bar=prog_bar,
+                    logger=logger,
+                    on_step=on_step,
+                    on_epoch=on_epoch,
+                    reduce_fx=reduce_fx,
+                    enable_graph=enable_graph,
+                    sync_dist=sync_dist,
+                    sync_dist_group=sync_dist_group,
+                    add_dataloader_idx=add_dataloader_idx,
+                    batch_size=batch_size,
+                    rank_zero_only=rank_zero_only,
+                )
+
+    def _log_dict_through_fabric(self, dictionary: Dict[str, Any], logger: Optional[bool] = None) -> None:
+        if logger is False:
+            # Passing `logger=False` with Fabric does not make much sense because there is no other destination to
+            # log to, but we support it in case the original code was written for Trainer use
+            return
+
+        if any(isinstance(v, dict) for v in dictionary.values()):
+            raise ValueError(f"`self.log_dict({dictionary})` was called, but nested dictionaries cannot be logged")
+        for name, value in dictionary.items():
+            apply_to_collection(value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Tensor))
+
+        assert self._fabric is not None
+        self._fabric.log_dict(metrics=dictionary)
 
     @staticmethod
     def __check_not_nested(value: dict, name: str) -> None:
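
To see the new routing end to end, here is a sketch (not part of the commit) using `BoringModel` from the test helpers: once `fabric.setup` has attached the Fabric instance, `self.log_dict` bypasses the Trainer logging path entirely and lands on Fabric's loggers, and nested dictionaries are rejected.

```python
from unittest.mock import Mock

import torch
from lightning_fabric import Fabric
from pytorch_lightning.demos.boring_classes import BoringModel

logger = Mock()
wrapped = Fabric(loggers=[logger]).setup(BoringModel())

wrapped.log_dict({"a": 1, "b": torch.tensor(2.0)})
logger.log_metrics.assert_called_with(metrics={"a": 1, "b": 2.0}, step=None)

try:
    wrapped.log_dict({"outer": {"inner": 1}})  # nested dicts are not loggable here
except ValueError as err:
    print(err)
```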

src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 2 additions & 2 deletions
@@ -19,9 +19,9 @@
 import pytorch_lightning as pl
 from lightning_fabric.plugins.environments import SLURMEnvironment
 from lightning_fabric.utilities import move_data_to_device
+from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
 from pytorch_lightning.loggers import Logger, TensorBoardLogger
 from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
-from pytorch_lightning.utilities.metrics import metrics_to_scalars
 
 
 class LoggerConnector:
@@ -80,7 +80,7 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
         self._logged_metrics.update(metrics)
 
         # turn all tensors to scalars
-        scalar_metrics = metrics_to_scalars(metrics)
+        scalar_metrics = convert_tensors_to_scalars(metrics)
 
         if step is None:
             step = scalar_metrics.pop("step", None)

src/pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 2 additions & 2 deletions
@@ -22,13 +22,13 @@
 from typing_extensions import TypedDict
 
 from lightning_fabric.utilities import move_data_to_device
+from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
 from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning_fabric.utilities.distributed import _distributed_available
 from pytorch_lightning.utilities.data import extract_batch_size
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.memory import recursive_detach
-from pytorch_lightning.utilities.metrics import metrics_to_scalars
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache
 from pytorch_lightning.utilities.warnings import PossibleUserWarning
 
@@ -610,7 +610,7 @@ def any_tensor(_: Any) -> None:
 
             # populate progress_bar metrics. convert tensors to numbers
             if result_metric.meta.prog_bar:
-                metrics["pbar"][forked_name] = metrics_to_scalars(value)
+                metrics["pbar"][forked_name] = convert_tensors_to_scalars(value)
 
         return metrics
 

src/pytorch_lightning/utilities/metrics.py

Lines changed: 3 additions & 14 deletions
@@ -12,29 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Helper functions to operate on metric values."""
-from __future__ import annotations
 
 from typing import Any
 
-from lightning_utilities.core.apply_func import apply_to_collection
-from torch import Tensor
-
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
 
 
 def metrics_to_scalars(metrics: Any) -> Any:
     """Recursively walk through a collection and convert single-item tensors to scalar values.
 
     Raises:
-        MisconfigurationException:
+        ValueError:
             If tensors inside ``metrics`` contain multiple elements, hence preventing conversion to a scalar.
     """
 
-    def to_item(value: Tensor) -> int | float | bool:
-        if value.numel() != 1:
-            raise MisconfigurationException(
-                f"The metric `{value}` does not contain a single element, thus it cannot be converted to a scalar."
-            )
-        return value.item()
-
-    return apply_to_collection(metrics, Tensor, to_item)
+    return convert_tensors_to_scalars(metrics)
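
One behavioral consequence of this simplification: callers that previously caught `MisconfigurationException` from `metrics_to_scalars` now need to catch `ValueError`, since the shared helper raises the built-in type. A short sketch:

```python
import torch

from pytorch_lightning.utilities.metrics import metrics_to_scalars

print(metrics_to_scalars({"acc": torch.tensor(0.75)}))  # {'acc': 0.75}

try:
    metrics_to_scalars({"acc": torch.tensor([0.1, 0.9])})
except ValueError as err:  # was MisconfigurationException before this commit
    print(err)
```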

tests/tests_fabric/test_fabric.py

Lines changed: 25 additions & 0 deletions
@@ -801,3 +801,28 @@ def test_log_dict():
     fabric.log_dict({"foo": 3, "bar": 4}, step=15)
     logger0.log_metrics.assert_called_with(metrics={"foo": 3, "bar": 4}, step=15)
     logger1.log_metrics.assert_called_with(metrics={"foo": 3, "bar": 4}, step=15)
+
+
+def test_log_dict_input_parsing():
+    """Test validation of input data types and preprocessing."""
+    logger = Mock()
+    fabric = Fabric(loggers=[logger])
+
+    # Tensor scalar, 0 dims
+    fabric.log("log", torch.tensor(1))
+    logger.log_metrics.assert_called_with(metrics={"log": 1}, step=None)
+    fabric.log_dict({"log_dict": torch.tensor(1)})
+    logger.log_metrics.assert_called_with(metrics={"log_dict": 1}, step=None)
+
+    # Tensor scalar, 1 dim
+    fabric.log("log", torch.tensor([2]))
+    logger.log_metrics.assert_called_with(metrics={"log": 2}, step=None)
+    fabric.log_dict({"log_dict": torch.tensor([2])})
+    logger.log_metrics.assert_called_with(metrics={"log_dict": 2}, step=None)
+
+    # Tensor, multiple dims
+    with pytest.raises(ValueError, match="it cannot be converted to a scalar."):
+        fabric.log("log", torch.tensor([3, 4]))
+
+    with pytest.raises(ValueError, match="it cannot be converted to a scalar."):
+        fabric.log_dict({"log_dict": torch.tensor([3, 4])})

tests/tests_fabric/utilities/test_apply_func.py

Lines changed: 18 additions & 1 deletion
@@ -15,7 +15,7 @@
 import torch
 from torch import Tensor
 
-from lightning_fabric.utilities.apply_func import move_data_to_device
+from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars, move_data_to_device
 
 
 @pytest.mark.parametrize("should_return", [False, True])
@@ -34,3 +34,20 @@ def to(self, device):
     tensor = torch.tensor(0.1)
     obj = TensorObject(tensor, should_return)
     assert obj == move_data_to_device(obj, torch.device("cpu"))
+
+
+def test_convert_tensors_to_scalars():
+    assert convert_tensors_to_scalars("string") == "string"
+    assert convert_tensors_to_scalars(1) == 1
+    assert convert_tensors_to_scalars(True) is True
+    assert convert_tensors_to_scalars({"scalar": 1.0}) == {"scalar": 1.0}
+
+    result = convert_tensors_to_scalars({"tensor": torch.tensor(2.0)})
+    # note: an `==` comparison as above is not sufficient on its own, since `torch.tensor(x) == x` is truthy
+    assert not isinstance(result["tensor"], Tensor) and result["tensor"] == 2.0
+
+    result = convert_tensors_to_scalars({"tensor": torch.tensor([2.0])})
+    assert not isinstance(result["tensor"], Tensor) and result["tensor"] == 2.0
+
+    with pytest.raises(ValueError, match="does not contain a single element"):
+        convert_tensors_to_scalars({"tensor": torch.tensor([1, 2, 3])})

tests/tests_pytorch/core/test_lightning_module.py

Lines changed: 82 additions & 6 deletions
@@ -512,13 +512,89 @@ def test_fabric_attributes():
 
     fabric = Fabric()
     wrapped_module, wrapped_optimizer = fabric.setup(module, optimizer)
-    assert module.fabric is fabric
-    assert module._fabric_optimizers == [wrapped_optimizer]
+    assert wrapped_module.fabric is fabric
+    assert wrapped_module._fabric_optimizers == [wrapped_optimizer]
 
     # Attribute access on LightningModule.trainer gets redirected to Fabric
-    assert isinstance(module.trainer, _TrainerFabricShim)
-    assert module.trainer.global_rank == 0
+    assert isinstance(wrapped_module.trainer, _TrainerFabricShim)
+    assert wrapped_module.trainer.global_rank == 0
     with pytest.raises(AttributeError, match="Your LightningModule code tried to access `self.trainer.current_epoch`"):
-        _ = module.trainer.current_epoch
+        _ = wrapped_module.trainer.current_epoch
 
-    assert module.optimizers() == wrapped_optimizer
+    assert wrapped_module.optimizers() == wrapped_optimizer
+
+
+def test_fabric_logger_access():
+    """Test that the logger attribute can be accessed when the LightningModule is used together with Fabric."""
+    # No logger
+    module = BoringModel()
+    fabric = Fabric()
+    wrapped_module = fabric.setup(module)
+    assert wrapped_module.loggers == []
+    with pytest.raises(IndexError):
+        _ = wrapped_module.logger
+
+    # Single logger
+    logger = Mock()
+    module = BoringModel()
+    fabric = Fabric(loggers=logger)
+    wrapped_module = fabric.setup(module)
+    assert wrapped_module.logger == logger
+    assert wrapped_module.loggers == [logger]
+
+    # Multiple loggers
+    logger1 = Mock()
+    logger2 = Mock()
+    module = BoringModel()
+    fabric = Fabric(loggers=[logger1, logger2])
+    wrapped_module = fabric.setup(module)
+    assert wrapped_module.logger == logger1
+    assert wrapped_module.loggers == [logger1, logger2]
+
+
+def test_fabric_log():
+    logger = Mock()
+    module = BoringModel()
+    fabric = Fabric(loggers=[logger])
+    wrapped_module = fabric.setup(module)
+
+    # unsupported data type
+    with pytest.raises(ValueError, match="`list` values cannot be logged"):
+        wrapped_module.log("invalid", list())
+
+    # supported data types
+    wrapped_module.log("int", 1)
+    logger.log_metrics.assert_called_with(metrics={"int": 1}, step=None)
+    wrapped_module.log("float", 0.1)
+    logger.log_metrics.assert_called_with(metrics={"float": 0.1}, step=None)
+    wrapped_module.log("tensor", torch.tensor(0.1))
+    logger.log_metrics.assert_called_with(metrics={"tensor": torch.tensor(0.1)}, step=None)
+
+    # logger=False
+    logger.reset_mock()
+    wrapped_module.log("nothing", 1, logger=False)
+    logger.log_metrics.assert_not_called()
+
+
+def test_fabric_log_dict():
+    logger = Mock()
+    module = BoringModel()
+    fabric = Fabric(loggers=[logger])
+    wrapped_module = fabric.setup(module)
+
+    # unsupported data type
+    with pytest.raises(ValueError, match="`list` values cannot be logged"):
+        wrapped_module.log_dict({"invalid": [1, 2, 3]})
+
+    # nested dicts
+    with pytest.raises(ValueError, match="nested dictionaries cannot be logged"):
+        wrapped_module.log_dict({"nested": {"nested": 1}})
+
+    # supported data types
+    wrapped_module.log_dict({"int": 1, "float": 0.1, "tensor": torch.tensor(0.1)})
+    logger.log_metrics.assert_called_with(metrics={"int": 1, "float": 0.1, "tensor": torch.tensor(0.1)}, step=None)
+
+    # logger=False
+    logger.reset_mock()
+    wrapped_module.log_dict({"nothing": 1}, logger=False)
+    logger.log_metrics.assert_not_called()
