Skip to content

Commit b438fa5

Browse files
awaelchli authored and lexierule committed
Cast to fp16 before moving to device with deepspeed (#14000)
Co-authored-by: Rohit Gupta <[email protected]>
1 parent 0bdbf4d commit b438fa5

File tree

8 files changed, with 69 additions and 22 deletions.

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99
### Fixed
1010

1111
- Casted only floating point tensors to fp16 with IPUs ([#13983](https://github.com/Lightning-AI/lightning/pull/13983))
12-
13-
12+
- Casted tensors to fp16 before moving them to device with `DeepSpeedStrategy` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000))
1413
- Fixed the `NeptuneLogger` dependency being unrecognized ([#13988](https://github.com/Lightning-AI/lightning/pull/13988))
1514
- Fixed an issue where users would be warned about unset `max_epochs` even when `fast_dev_run` was set ([#13262](https://github.com/Lightning-AI/lightning/pull/13262))
1615
- Fixed MPS device being unrecognized ([#13992](https://github.com/Lightning-AI/lightning/pull/13992))

src/pytorch_lightning/strategies/deepspeed.py

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
3434
from pytorch_lightning.plugins.precision import PrecisionPlugin
3535
from pytorch_lightning.strategies.ddp import DDPStrategy
36+
from pytorch_lightning.strategies.utils import _fp_to_half
3637
from pytorch_lightning.trainer.states import TrainerFn
3738
from pytorch_lightning.utilities import GradClipAlgorithmType
3839
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -46,10 +47,10 @@
4647
from pytorch_lightning.utilities.imports import _RequirementAvailable
4748
from pytorch_lightning.utilities.model_helpers import is_overridden
4849
from pytorch_lightning.utilities.optimizer import optimizers_to_device
49-
from pytorch_lightning.utilities.rank_zero import rank_zero_info
50+
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
5051
from pytorch_lightning.utilities.seed import reset_seed
5152
from pytorch_lightning.utilities.types import _LRScheduler, _PATH, LRSchedulerConfig, ReduceLROnPlateau, STEP_OUTPUT
52-
from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache
53+
from pytorch_lightning.utilities.warnings import WarningCache
5354

5455
warning_cache = WarningCache()
5556

@@ -70,9 +71,15 @@ def remove_module_hooks(model: torch.nn.Module) -> None:
7071

7172

7273
class LightningDeepSpeedModule(_LightningModuleWrapperBase):
74+
"""
75+
.. deprecated:: v1.7.1
76+
``LightningDeepSpeedModule`` has been deprecated in v1.7.1 and will be removed in v1.9.0.
77+
"""
78+
7379
def __init__(
7480
self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int]
7581
) -> None:
82+
rank_zero_deprecation("`LightningDeepSpeedModule` has been deprecated in v1.7.1 and will be removed in v1.9.0")
7683
super().__init__(pl_module)
7784
self.precision = precision
7885

@@ -477,7 +484,7 @@ def init_deepspeed(self) -> None:
477484
)
478485

479486
assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
480-
model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision_plugin.precision)
487+
model = _LightningModuleWrapperBase(pl_module=self.model)
481488

482489
if self.lightning_module.trainer and self.lightning_module.trainer.training:
483490
self._initialize_deepspeed_train(model)
@@ -605,9 +612,9 @@ def _initialize_deepspeed_inference(self, model: Module) -> None:
605612

606613
@property
607614
def lightning_module(self) -> Optional["pl.LightningModule"]:
608-
# the model may not be wrapped with DeepEngine & LightningDeepSpeedModule if calling this too early
615+
# the model may not be wrapped with DeepEngine & _LightningModuleWrapperBase if calling this too early
609616
module = getattr(self.model, "module", self.model)
610-
module = module.module if isinstance(module, LightningDeepSpeedModule) else module
617+
module = module.module if isinstance(module, _LightningModuleWrapperBase) else module
611618
assert isinstance(module, pl.LightningModule) or module is None
612619
return module
613620

@@ -943,6 +950,10 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
943950
offload_optimizer_device="nvme",
944951
)
945952

953+
def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
954+
batch = apply_to_collection(batch, Tensor, function=_fp_to_half, precision=self.precision_plugin.precision)
955+
return super().batch_to_device(batch, device, dataloader_idx)
956+
946957
def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
947958
assert self.model is not None
948959
with self.precision_plugin.val_step_context():

src/pytorch_lightning/strategies/ipu.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
2626
from pytorch_lightning.plugins.precision import PrecisionPlugin
2727
from pytorch_lightning.strategies.parallel import ParallelStrategy
28+
from pytorch_lightning.strategies.utils import _fp_to_half
2829
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
2930
from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
3031
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -43,6 +44,11 @@
4344

4445

4546
class LightningIPUModule(_LightningModuleWrapperBase):
47+
"""
48+
.. deprecated:: v1.7.0
49+
``LightningIPUModule`` has been deprecated in v1.7.0 and will be removed in v1.9.0.
50+
"""
51+
4652
def __init__(
4753
self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int]
4854
) -> None:
@@ -274,13 +280,7 @@ def to_tensor(x):
274280
def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
275281
# This override is necessary because the cast must occur before the data
276282
# is moved to the device to prevent wasteful host->device copies.
277-
def fp_to_half(tensor: Tensor) -> Tensor:
278-
if torch.is_floating_point(tensor):
279-
return tensor.half()
280-
return tensor
281-
282-
if self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF):
283-
batch = apply_to_collection(batch, Tensor, function=fp_to_half)
283+
batch = apply_to_collection(batch, Tensor, function=_fp_to_half, precision=self.precision_plugin.precision)
284284
# We don't call `super().batch_to_device` because `data.to(device)` is not
285285
# currently necessary for IPUs. The movement of data from host<->IPU is
286286
# currently handled by PopTorch.

src/pytorch_lightning/strategies/utils.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,20 @@
1313
# limitations under the License.
1414
import os
1515

16+
import torch
17+
18+
from pytorch_lightning.utilities.enums import PrecisionType
19+
1620

1721
def on_colab_kaggle() -> bool:
1822
return bool(os.getenv("COLAB_GPU") or os.getenv("KAGGLE_URL_BASE"))
23+
24+
25+
def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
26+
if torch.is_floating_point(tensor):
27+
if precision in (PrecisionType.MIXED, PrecisionType.HALF):
28+
return tensor.half()
29+
if precision == PrecisionType.BFLOAT:
30+
return tensor.bfloat16()
31+
32+
return tensor

src/pytorch_lightning/utilities/deepspeed.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,7 @@ def convert_zero_checkpoint_to_fp32_state_dict(
9898
model_file = get_model_state_file(checkpoint_dir, zero_stage)
9999
client_state = torch.load(model_file, map_location=CPU_DEVICE)
100100
client_state = {key: value for key, value in client_state.items() if key not in deepspeed_states}
101-
# State dict keys will include reference to wrapper LightningDeepSpeedModule
101+
# State dict keys will include reference to wrapper _LightningModuleWrapperBase
102102
# Delete `module` prefix before saving.
103103
state_dict = {k.partition("module.")[2]: state_dict[k] for k in state_dict.keys()}
104104
client_state["state_dict"] = state_dict

tests/tests_pytorch/deprecated_api/test_remove_1-8.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1007,9 +1007,7 @@ def test_trainer_config_ipus(monkeypatch, trainer_kwargs, expected_ipus):
10071007
trainer.ipus == expected_ipus
10081008

10091009

1010-
@mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True)
1011-
def test_v1_8_0_deprecated_lightning_ipu_module(_, monkeypatch):
1012-
monkeypatch.setattr(pytorch_lightning.strategies.ipu, "_IPU_AVAILABLE", True)
1010+
def test_v1_8_0_deprecated_lightning_ipu_module():
10131011
with pytest.deprecated_call(match=r"has been deprecated in v1.7.0 and will be removed in v1.8."):
10141012
_ = LightningIPUModule(BoringModel(), 32)
10151013

tests/tests_pytorch/deprecated_api/test_remove_1-9.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
from pytorch_lightning.profiler.pytorch import PyTorchProfiler, RegisterRecordFunction, ScheduleWrapper
3131
from pytorch_lightning.profiler.simple import SimpleProfiler
3232
from pytorch_lightning.profiler.xla import XLAProfiler
33+
from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule
3334
from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
3435
from pytorch_lightning.utilities.rank_zero import rank_zero_only
3536
from tests_pytorch.helpers.runif import RunIf
@@ -217,3 +218,8 @@ def test_gpu_accelerator_deprecation_warning():
217218
)
218219
):
219220
GPUAccelerator()
221+
222+
223+
def test_v1_9_0_deprecated_lightning_deepspeed_module():
224+
with pytest.deprecated_call(match=r"has been deprecated in v1.7.1 and will be removed in v1.9."):
225+
_ = LightningDeepSpeedModule(BoringModel(), 32)

tests/tests_pytorch/strategies/test_deepspeed_strategy.py

Lines changed: 23 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -85,11 +85,12 @@ def automatic_optimization(self) -> bool:
8585
return False
8686

8787

88-
def test_deepspeed_lightning_module(tmpdir):
88+
def test_deepspeed_lightning_module():
8989
"""Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves types and device correctly."""
9090

9191
model = BoringModel()
92-
module = LightningDeepSpeedModule(model, precision=16)
92+
with pytest.deprecated_call(match="`LightningDeepSpeedModule` has been deprecated in v1.7.1"):
93+
module = LightningDeepSpeedModule(model, precision=16)
9394

9495
module.half()
9596
assert module.dtype == torch.half
@@ -101,12 +102,13 @@ def test_deepspeed_lightning_module(tmpdir):
101102

102103

103104
@RunIf(min_cuda_gpus=1)
104-
def test_deepspeed_lightning_module_precision(tmpdir):
105+
def test_deepspeed_lightning_module_precision():
105106
"""Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves tensors to half when precision
106107
16."""
107108

108109
model = BoringModel()
109-
module = LightningDeepSpeedModule(model, precision=16)
110+
with pytest.deprecated_call(match="`LightningDeepSpeedModule` has been deprecated in v1.7.1"):
111+
module = LightningDeepSpeedModule(model, precision=16)
110112

111113
module.cuda().half()
112114
assert module.dtype == torch.half
@@ -1306,6 +1308,7 @@ def test_deepspeed_with_bfloat16_precision(tmpdir):
13061308
assert isinstance(trainer.strategy.precision_plugin, DeepSpeedPrecisionPlugin)
13071309
assert trainer.strategy.precision_plugin.precision == "bf16"
13081310
assert trainer.strategy.config["zero_optimization"]["stage"] == 3
1311+
assert trainer.strategy.config["bf16"]["enabled"]
13091312
assert model.layer.weight.dtype == torch.bfloat16
13101313

13111314

@@ -1344,3 +1347,19 @@ def configure_optimizers(self):
13441347
)
13451348
with pytest.raises(SystemExit):
13461349
trainer.fit(model)
1350+
1351+
1352+
@RunIf(min_cuda_gpus=1, deepspeed=True)
1353+
def test_deepspeed_tensors_cast_to_fp16_before_hosted_on_device():
1354+
class CustomBoringModel(BoringModel):
1355+
def transfer_batch_to_device(self, batch, *args, **kwargs):
1356+
assert batch.dtype is torch.float16
1357+
return super().transfer_batch_to_device(batch, *args, **kwargs)
1358+
1359+
model = CustomBoringModel()
1360+
trainer = Trainer(strategy="deepspeed", devices=1, accelerator="cuda", precision=16)
1361+
trainer.strategy.connect(model)
1362+
batch = torch.zeros((1), dtype=torch.float32)
1363+
batch = trainer.strategy.batch_to_device(batch)
1364+
assert batch.is_cuda
1365+
assert batch.dtype is torch.float16

0 commit comments

Comments
 (0)