Commit ed8f618

update pytorch deepspeed precision
1 parent 467b935 commit ed8f618

File tree

2 files changed: 97 additions & 637 deletions

src/lightning/pytorch/plugins/precision/deepspeed.py

Lines changed: 22 additions & 60 deletions
@@ -11,30 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import AbstractContextManager, nullcontext
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from contextlib import AbstractContextManager
+from typing import Any, Callable, Optional, Union
 
-import torch
-from lightning_utilities import apply_to_collection
 from torch import Tensor
 from torch.nn import Module
-from torch.optim import LBFGS, Optimizer
-from typing_extensions import get_args, override
+from torch.optim import Optimizer
+from typing_extensions import override
 
 import lightning.pytorch as pl
 from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT
-from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
 from lightning.fabric.utilities.types import Steppable
 from lightning.pytorch.plugins.precision.precision import Precision
 from lightning.pytorch.utilities import GradClipAlgorithmType
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.model_helpers import is_overridden
-from lightning.pytorch.utilities.rank_zero import WarningCache
-
-if TYPE_CHECKING:
-    import deepspeed
-
-warning_cache = WarningCache()
 
 
 class DeepSpeedPrecision(Precision):
@@ -53,41 +43,29 @@ class DeepSpeedPrecision(Precision):
     """
 
     def __init__(self, precision: _PRECISION_INPUT) -> None:
-        supported_precision = get_args(_PRECISION_INPUT)
-        if precision not in supported_precision:
-            raise ValueError(
-                f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported."
-                f" `precision` must be one of: {supported_precision}."
-            )
-        self.precision = precision
-        precision_to_type = {
-            "bf16-mixed": torch.bfloat16,
-            "16-mixed": torch.float16,
-            "bf16-true": torch.bfloat16,
-            "16-true": torch.float16,
-            "32-true": torch.float32,
-        }
-        self._desired_dtype = precision_to_type[self.precision]
+        super().__init__(precision)
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.plugins.precision.deepspeed import (
+            DeepSpeedPrecisionTrainer as EnterpriseDeepSpeedPrecision,
+        )
+
+        self.deepspeed_precision_impl = EnterpriseDeepSpeedPrecision(outer_object=self, precision=precision)
 
     @override
     def convert_module(self, module: Module) -> Module:
-        if "true" in self.precision:
-            return module.to(dtype=self._desired_dtype)
-        return module
+        return self.deepspeed_precision_impl.convert_module(module=module)
 
     @override
     def convert_input(self, data: Any) -> Any:
-        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)
+        return self.deepspeed_precision_impl.convert_input(data=data)
 
     @override
     def tensor_init_context(self) -> AbstractContextManager:
-        if "true" not in self.precision:
-            return nullcontext()
-        return _DtypeContextManager(self._desired_dtype)
+        return self.deepspeed_precision_impl.tensor_init_context()
 
     @override
     def module_init_context(self) -> AbstractContextManager:
-        return self.tensor_init_context()
+        return self.deepspeed_precision_impl.module_init_context()
 
     @override
     def backward(  # type: ignore[override]
@@ -98,7 +76,7 @@ def backward(  # type: ignore[override]
         *args: Any,
         **kwargs: Any,
     ) -> None:
-        r"""Performs back-propagation using DeepSpeed's engine.
+        r"""Performs back-propagation.
 
         Args:
             tensor: the loss tensor
@@ -108,13 +86,7 @@ def backward(  # type: ignore[override]
             \**kwargs: additional keyword arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call
 
         """
-        if is_overridden("backward", model):
-            warning_cache.warn(
-                "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
-                " the backward logic internally."
-            )
-        deepspeed_engine: deepspeed.DeepSpeedEngine = model.trainer.model
-        deepspeed_engine.backward(tensor, *args, **kwargs)
+        return self.deepspeed_precision_impl.backward(tensor=tensor, model=model, optimizer=optimizer, *args, **kwargs)
 
     @override
     def optimizer_step(  # type: ignore[override]
@@ -124,19 +96,7 @@ def optimizer_step(  # type: ignore[override]
        closure: Callable[[], Any],
         **kwargs: Any,
     ) -> Any:
-        if isinstance(optimizer, LBFGS):
-            raise MisconfigurationException("DeepSpeed and the LBFGS optimizer are not compatible.")
-        closure_result = closure()
-        self._after_closure(model, optimizer)
-        skipped_backward = closure_result is None
-        # in manual optimization, the closure does not return a value
-        if model.automatic_optimization and skipped_backward:
-            raise MisconfigurationException(
-                "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
-            )
-        # DeepSpeed handles the optimizer step internally
-        deepspeed_engine: deepspeed.DeepSpeedEngine = model.trainer.model
-        return deepspeed_engine.step(**kwargs)
+        return self.deepspeed_precision_impl.optimizer_step(optimizer=optimizer, model=model, closure=closure, **kwargs)
 
     @override
     def clip_gradients(
@@ -145,4 +105,6 @@ def clip_gradients(
         clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
-        """DeepSpeed handles gradient clipping internally."""
+        return self.deepspeed_precision_impl.clip_gradients(
+            optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm
+        )
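
With this change, `DeepSpeedPrecision` keeps its public interface but forwards all precision handling to the optional `pytorch_lightning_enterprise` package: the constructor calls `_raise_enterprise_not_available()` and then wraps `DeepSpeedPrecisionTrainer`. A minimal usage sketch follows, assuming that package is installed; the accelerator, device count, and strategy alias are illustrative assumptions, not part of this commit. The valid precision strings are the ones listed in the removed `precision_to_type` mapping: "bf16-mixed", "16-mixed", "bf16-true", "16-true", "32-true".

from lightning.pytorch import Trainer
from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision

# Build the plugin explicitly; after this commit the constructor raises if the
# optional pytorch_lightning_enterprise package is not available, and all
# precision behavior is delegated to its DeepSpeedPrecisionTrainer.
precision_plugin = DeepSpeedPrecision(precision="16-mixed")

# Illustrative Trainer setup (assumed, not introduced by this commit).
trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy="deepspeed_stage_2",
    plugins=[precision_plugin],
)

The removed validation message referenced `Trainer(strategy='deepspeed', precision=...)`, so the more common path of passing `precision` directly to the Trainer and letting it construct this plugin should be unaffected.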

0 commit comments
