Skip to content

Commit f5badec

Browse files
committed
fix mis-alignment column while using rich model summary in DeepSpeed strategy.
1 parent ad54bc1 commit f5badec

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

src/lightning/pytorch/callbacks/rich_model_summary.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,21 @@ def summarize(
7979
from rich.table import Table
8080

8181
console = get_console()
82+
column_names = list(zip(*summary_data))[0]
8283

8384
header_style: str = summarize_kwargs.get("header_style", "bold magenta")
8485
table = Table(header_style=header_style)
8586
table.add_column(" ", style="dim")
8687
table.add_column("Name", justify="left", no_wrap=True)
8788
table.add_column("Type")
8889
table.add_column("Params", justify="right")
90+
91+
if "Params per Device" in column_names:
92+
table.add_column("Params per Device", justify="right")
93+
8994
table.add_column("Mode")
9095
table.add_column("FLOPs", justify="right")
9196

92-
column_names = list(zip(*summary_data))[0]
93-
9497
for column_name in ["In sizes", "Out sizes"]:
9598
if column_name in column_names:
9699
table.add_column(column_name, justify="right", style="white")

src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def _get_summary_data(self) -> list[tuple[str, list[str]]]:
9999
("Params", list(map(get_human_readable_count, self.param_nums))),
100100
("Params per Device", list(map(get_human_readable_count, self.parameters_per_layer))),
101101
("Mode", ["train" if mode else "eval" for mode in self.training_modes]),
102+
("FLOPs", list(map(get_human_readable_count, (sum(x.values()) for x in self.flop_counts.values())))),
102103
]
103104
if self._model.example_input_array is not None:
104105
arrays.append(("In sizes", [str(x) for x in self.in_sizes]))

tests/tests_pytorch/utilities/test_deepspeed_model_summary.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from unittest import mock
16+
17+
import torch
18+
1519
import lightning.pytorch as pl
1620
from lightning.pytorch import Callback, Trainer
1721
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -51,3 +55,38 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
5155
)
5256

5357
trainer.fit(model)
58+
59+
60+
@RunIf(deepspeed=True, rich=True)
61+
@mock.patch("rich.table.Table.add_row", autospec=True)
62+
def test_deepspeed_summary_with_rich_model_summary(mock_table_add_row, tmp_path):
63+
from lightning.pytorch.callbacks import RichModelSummary
64+
65+
model = BoringModel()
66+
model.example_input_array = torch.randn(4, 32)
67+
68+
trainer = Trainer(
69+
strategy=DeepSpeedStrategy(stage=3),
70+
default_root_dir=tmp_path,
71+
accelerator="gpu",
72+
fast_dev_run=True,
73+
devices=1,
74+
enable_model_summary=True,
75+
callbacks=[RichModelSummary()],
76+
)
77+
78+
trainer.fit(model)
79+
80+
# assert that the input summary data was converted correctly
81+
args, _ = mock_table_add_row.call_args_list[0]
82+
assert args[1:] == (
83+
"0",
84+
"layer",
85+
"Linear",
86+
"66 ",
87+
"66 ",
88+
"train",
89+
"512 ",
90+
"[4, 32]",
91+
"[4, 2]",
92+
)

0 commit comments

Comments
 (0)