
Commit ec9fd02

colehawkins, carmocca, pre-commit-ci[bot], and awaelchli authored and committed
Add check for bf16 in deepspeed inference (#16973)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Cole Hawkins <colehawk>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: awaelchli <[email protected]>
(cherry picked from commit c271d4c)
1 parent 0d33813 commit ec9fd02

File tree

3 files changed: +31 −0 lines changed

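For readers skimming the diff: the change mirrors the existing fp16 handling for bf16 when the DeepSpeed inference config is assembled inside DeepSpeedStrategy. A minimal sketch of that logic as a standalone helper (the helper name is hypothetical; the real code is the two-line method change shown below):

def _collect_precision_sections(config: dict) -> dict:
    """Copy the fp16/bf16 sections of a user DeepSpeed config into the inference config."""
    inference_config = {"train_micro_batch_size_per_gpu": 1}
    for key in ("fp16", "bf16"):  # before this commit, only "fp16" was forwarded
        if key in config:
            inference_config[key] = config[key]
    return inference_config

# The bf16 section now survives into the inference config:
assert _collect_precision_sections({"bf16": {"enabled": True}})["bf16"] == {"enabled": True}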

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `num_nodes` not being set for `DDPFullyShardedNativeStrategy` ([#17160](https://github.com/Lightning-AI/lightning/pull/17160))
 
+- Fixed parsing the precision config for inference in `DeepSpeedStrategy` ([#16973](https://github.com/Lightning-AI/lightning/pull/16973))
+
+
 - Fixed the availability check for `rich` that prevented Lightning to be imported in Google Colab ([#17156](https://github.com/Lightning-AI/lightning/pull/17156))
 
 

src/pytorch_lightning/strategies/deepspeed.py

Lines changed: 2 additions & 0 deletions
@@ -553,6 +553,8 @@ def _initialize_deepspeed_inference(self, model: Module) -> None:
         inference_config = {"train_micro_batch_size_per_gpu": 1}
         if "fp16" in self.config:
             inference_config.update({"fp16": self.config["fp16"]})
+        if "bf16" in self.config:
+            inference_config.update({"bf16": self.config["bf16"]})
         if self.zero_stage_3:
             inference_config.update(
                 {
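The user-facing effect, sketched under the assumption of a single CUDA GPU with deepspeed installed (BoringModel from the demos module stands in for any LightningModule; the new test below drives the same path through a callback):

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.strategies import DeepSpeedStrategy

# A bf16-only DeepSpeed config is now forwarded to the inference config used by validate/test/predict.
model = BoringModel()
strategy = DeepSpeedStrategy(config={"bf16": {"enabled": True}})
trainer = Trainer(accelerator="cuda", devices=1, strategy=strategy)
trainer.validate(model)
assert "bf16" in trainer.strategy.config  # mirrors the new test's assertion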

tests/tests_pytorch/strategies/test_deepspeed_strategy.py

Lines changed: 26 additions & 0 deletions
@@ -371,6 +371,32 @@ def on_train_start(self, trainer, pl_module) -> None:
         trainer.fit(model)
 
 
+@RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
+@pytest.mark.parametrize("precision", ["fp16", "bf16"])
+def test_deepspeed_inference_precision_during_inference(precision, tmpdir):
+    """Ensure if we modify the precision for deepspeed and execute inference-only, the deepspeed config contains
+    these changes."""
+
+    class TestCB(Callback):
+        def on_validation_start(self, trainer, pl_module) -> None:
+            assert trainer.strategy.config[precision]
+            raise SystemExit()
+
+    model = BoringModel()
+    strategy = DeepSpeedStrategy(config={precision: {"enabled": True}})
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        strategy=strategy,
+        accelerator="cuda",
+        devices=1,
+        callbacks=[TestCB()],
+        barebones=True,
+    )
+    with pytest.raises(SystemExit):
+        trainer.validate(model)
+
+
 @RunIf(deepspeed=True)
 def test_deepspeed_custom_activation_checkpointing_params(tmpdir):
     """Ensure if we modify the activation checkpointing parameters, the deepspeed config contains these changes."""

0 commit comments