Skip to content

Commit 929c530

Browse files
committed
fix(RichProgressBar): Convert tensor metrics to float
The RichProgressBar was failing in multi-GPU environments because it could not handle tensor metrics residing on different devices. This commit overrides the get_metrics method to convert all tensor metrics to floats before they are rendered, preventing these errors. An accompanying test verifies the fix.
1 parent 63f9009 commit 929c530

File tree

3 files changed

+155
-103
lines changed

3 files changed

+155
-103
lines changed

src/lightning/pytorch/callbacks/progress/rich_progress.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from datetime import timedelta
1818
from typing import Any, Optional, Union, cast
1919

20+
import torch
2021
from lightning_utilities.core.imports import RequirementCache
2122
from typing_extensions import override
2223

@@ -612,6 +613,17 @@ def _reset_progress_bar_ids(self) -> None:
612613
self.test_progress_bar_id = None
613614
self.predict_progress_bar_id = None
614615

616+
@override
617+
def get_metrics(
618+
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
619+
) -> dict[str, Union[int, str, float, dict[str, float]]]:
620+
items = super().get_metrics(trainer, pl_module)
621+
# convert all metrics to float before sending to rich
622+
for k, v in items.items():
623+
if isinstance(v, torch.Tensor):
624+
items[k] = v.item()
625+
return items
626+
615627
def _update_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
616628
metrics = self.get_metrics(trainer, pl_module)
617629
if self._metric_component:

tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py

Lines changed: 109 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -109,98 +109,98 @@ def test_tqdm_progress_bar_misconfiguration():
109109
Trainer(callbacks=TQDMProgressBar(), enable_progress_bar=False)
110110

111111

112+
@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
112113
@pytest.mark.parametrize("num_dl", [1, 2])
113114
def test_tqdm_progress_bar_totals(tmp_path, num_dl):
114115
"""Test that the progress finishes with the correct total steps processed."""
115-
with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False):
116-
117-
class CustomModel(BoringModel):
118-
def _get_dataloaders(self):
119-
dls = [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]
120-
return dls[0] if num_dl == 1 else dls
121-
122-
def val_dataloader(self):
123-
return self._get_dataloaders()
124-
125-
def test_dataloader(self):
126-
return self._get_dataloaders()
127-
128-
def predict_dataloader(self):
129-
return self._get_dataloaders()
130-
131-
def validation_step(self, batch, batch_idx, dataloader_idx=0):
132-
return
133-
134-
def test_step(self, batch, batch_idx, dataloader_idx=0):
135-
return
136-
137-
def predict_step(self, batch, batch_idx, dataloader_idx=0):
138-
return
139-
140-
model = CustomModel()
141-
142-
# check the sanity dataloaders
143-
num_sanity_val_steps = 4
144-
trainer = Trainer(
145-
default_root_dir=tmp_path, max_epochs=1, limit_train_batches=0, num_sanity_val_steps=num_sanity_val_steps
146-
)
147-
pbar = trainer.progress_bar_callback
148-
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
149-
trainer.fit(model)
150-
151-
expected_sanity_steps = [num_sanity_val_steps] * num_dl
152-
assert not pbar.val_progress_bar.leave
153-
assert trainer.num_sanity_val_batches == expected_sanity_steps
154-
assert pbar.val_progress_bar.total_values == expected_sanity_steps
155-
assert pbar.val_progress_bar.n_values == list(range(num_sanity_val_steps + 1)) * num_dl
156-
assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)]
157-
158-
# fit
159-
trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
160-
pbar = trainer.progress_bar_callback
161-
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
162-
trainer.fit(model)
163-
164-
n = trainer.num_training_batches
165-
m = trainer.num_val_batches
166-
assert len(trainer.train_dataloader) == n
167-
# train progress bar should have reached the end
168-
assert pbar.train_progress_bar.total == n
169-
assert pbar.train_progress_bar.n == n
170-
assert pbar.train_progress_bar.leave
171-
172-
# check val progress bar total
173-
assert pbar.val_progress_bar.total_values == m
174-
assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl
175-
assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)]
176-
assert not pbar.val_progress_bar.leave
177-
178-
# validate
179-
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
180-
trainer.validate(model)
181-
assert trainer.num_val_batches == m
182-
assert pbar.val_progress_bar.total_values == m
183-
assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl
184-
assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)]
185-
186-
# test
187-
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
188-
trainer.test(model)
189-
assert pbar.test_progress_bar.leave
190-
k = trainer.num_test_batches
191-
assert pbar.test_progress_bar.total_values == k
192-
assert pbar.test_progress_bar.n_values == list(range(k[0] + 1)) * num_dl
193-
assert pbar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)]
194-
assert pbar.test_progress_bar.leave
195-
196-
# predict
197-
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
198-
trainer.predict(model)
199-
assert pbar.predict_progress_bar.leave
200-
k = trainer.num_predict_batches
201-
assert pbar.predict_progress_bar.total_values == k
202-
assert pbar.predict_progress_bar.n_values == list(range(k[0] + 1)) * num_dl
203-
assert pbar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)]
116+
117+
class CustomModel(BoringModel):
118+
def _get_dataloaders(self):
119+
dls = [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]
120+
return dls[0] if num_dl == 1 else dls
121+
122+
def val_dataloader(self):
123+
return self._get_dataloaders()
124+
125+
def test_dataloader(self):
126+
return self._get_dataloaders()
127+
128+
def predict_dataloader(self):
129+
return self._get_dataloaders()
130+
131+
def validation_step(self, batch, batch_idx, dataloader_idx=0):
132+
return
133+
134+
def test_step(self, batch, batch_idx, dataloader_idx=0):
135+
return
136+
137+
def predict_step(self, batch, batch_idx, dataloader_idx=0):
138+
return
139+
140+
model = CustomModel()
141+
142+
# check the sanity dataloaders
143+
num_sanity_val_steps = 4
144+
trainer = Trainer(
145+
default_root_dir=tmp_path, max_epochs=1, limit_train_batches=0, num_sanity_val_steps=num_sanity_val_steps
146+
)
147+
pbar = trainer.progress_bar_callback
148+
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
149+
trainer.fit(model)
150+
151+
expected_sanity_steps = [num_sanity_val_steps] * num_dl
152+
assert not pbar.val_progress_bar.leave
153+
assert trainer.num_sanity_val_batches == expected_sanity_steps
154+
assert pbar.val_progress_bar.total_values == expected_sanity_steps
155+
assert pbar.val_progress_bar.n_values == list(range(num_sanity_val_steps + 1)) * num_dl
156+
assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)]
157+
158+
# fit
159+
trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
160+
pbar = trainer.progress_bar_callback
161+
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
162+
trainer.fit(model)
163+
164+
n = trainer.num_training_batches
165+
m = trainer.num_val_batches
166+
assert len(trainer.train_dataloader) == n
167+
# train progress bar should have reached the end
168+
assert pbar.train_progress_bar.total == n
169+
assert pbar.train_progress_bar.n == n
170+
assert pbar.train_progress_bar.leave
171+
172+
# check val progress bar total
173+
assert pbar.val_progress_bar.total_values == m
174+
assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl
175+
assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)]
176+
assert not pbar.val_progress_bar.leave
177+
178+
# validate
179+
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
180+
trainer.validate(model)
181+
assert trainer.num_val_batches == m
182+
assert pbar.val_progress_bar.total_values == m
183+
assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl
184+
assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)]
185+
186+
# test
187+
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
188+
trainer.test(model)
189+
assert pbar.test_progress_bar.leave
190+
k = trainer.num_test_batches
191+
assert pbar.test_progress_bar.total_values == k
192+
assert pbar.test_progress_bar.n_values == list(range(k[0] + 1)) * num_dl
193+
assert pbar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)]
194+
assert pbar.test_progress_bar.leave
195+
196+
# predict
197+
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
198+
trainer.predict(model)
199+
assert pbar.predict_progress_bar.leave
200+
k = trainer.num_predict_batches
201+
assert pbar.predict_progress_bar.total_values == k
202+
assert pbar.predict_progress_bar.n_values == list(range(k[0] + 1)) * num_dl
203+
assert pbar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)]
204204
assert pbar.predict_progress_bar.leave
205205

206206

@@ -414,24 +414,30 @@ def test_test_progress_bar_update_amount(tmp_path, test_batches: int, refresh_ra
414414
assert progress_bar.test_progress_bar.n_values == updates
415415

416416

417+
@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
417418
def test_tensor_to_float_conversion(tmp_path):
418419
"""Check tensor gets converted to float."""
419-
with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False):
420-
421-
class TestModel(BoringModel):
422-
def training_step(self, batch, batch_idx):
423-
self.log("a", torch.tensor(0.123), prog_bar=True, on_epoch=False)
424-
self.log("b", torch.tensor([1]), prog_bar=True, on_epoch=False)
425-
self.log("c", 2, prog_bar=True, on_epoch=False)
426-
return super().training_step(batch, batch_idx)
427-
428-
trainer = Trainer(
429-
default_root_dir=tmp_path, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False
430-
)
420+
421+
class TestModel(BoringModel):
422+
def training_step(self, batch, batch_idx):
423+
self.log("a", torch.tensor(0.123), prog_bar=True, on_epoch=False)
424+
self.log("b", torch.tensor([1]), prog_bar=True, on_epoch=False)
425+
self.log("c", 2, prog_bar=True, on_epoch=False)
426+
return super().training_step(batch, batch_idx)
427+
428+
trainer = Trainer(
429+
default_root_dir=tmp_path, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False
430+
)
431+
432+
with mock.patch.object(sys.stdout, "write") as mock_write:
431433
trainer.fit(TestModel())
434+
bar_updates = "".join(call.args[0] for call in mock_write.call_args_list)
435+
assert "a=0.123" in bar_updates
436+
assert "b=1.000" in bar_updates
437+
assert "c=2.000" in bar_updates
432438

433-
torch.testing.assert_close(trainer.progress_bar_metrics["a"], 0.123)
434-
assert trainer.progress_bar_metrics["b"] == 1.0
439+
torch.testing.assert_close(trainer.progress_bar_metrics["a"], 0.123)
440+
assert trainer.progress_bar_metrics["b"] == 1.0
435441
assert trainer.progress_bar_metrics["c"] == 2.0
436442
pbar = trainer.progress_bar_callback.train_progress_bar
437443
actual = str(pbar.postfix)

tests/tests_pytorch/trainer/connectors/test_rich_integration.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414

1515
from unittest.mock import patch
1616

17+
import pytest
18+
import torch
19+
1720
from lightning.pytorch import Trainer
1821
from lightning.pytorch.callbacks import ModelSummary, ProgressBar, RichModelSummary, RichProgressBar, TQDMProgressBar
22+
from lightning.pytorch.demos.boring_classes import BoringModel
1923

2024

2125
class TestRichIntegration:
@@ -133,3 +137,33 @@ def test_model_summary_disabled_with_rich(self, tmp_path):
133137
default_root_dir=tmp_path, enable_model_summary=False, logger=False, enable_checkpointing=False
134138
)
135139
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)
140+
141+
@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", True)
142+
def test_rich_progress_bar_tensor_metric(self, tmp_path):
143+
"""Test that tensor metrics are converted to float for RichProgressBar."""
144+
145+
class MyModel(BoringModel):
146+
def training_step(self, batch, batch_idx):
147+
self.log("my_tensor_metric", torch.tensor(1.23), prog_bar=True)
148+
return super().training_step(batch, batch_idx)
149+
150+
model = MyModel()
151+
trainer = Trainer(
152+
default_root_dir=tmp_path,
153+
limit_train_batches=1,
154+
limit_val_batches=0,
155+
max_epochs=1,
156+
logger=False,
157+
enable_checkpointing=False,
158+
)
159+
160+
with patch("lightning.pytorch.callbacks.progress.rich_progress.MetricsTextColumn.update") as mock_update:
161+
trainer.fit(model)
162+
163+
assert mock_update.call_count > 0
164+
# The metrics are updated multiple times, check the last call
165+
last_call_metrics = mock_update.call_args[0][0]
166+
assert "my_tensor_metric" in last_call_metrics
167+
metric_val = last_call_metrics["my_tensor_metric"]
168+
assert isinstance(metric_val, float)
169+
assert metric_val == pytest.approx(1.23)

0 commit comments

Comments (0)