@@ -11,30 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import AbstractContextManager, nullcontext
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from contextlib import AbstractContextManager
+from typing import Any, Callable, Optional, Union

-import torch
-from lightning_utilities import apply_to_collection
 from torch import Tensor
 from torch.nn import Module
-from torch.optim import LBFGS, Optimizer
-from typing_extensions import get_args, override
+from torch.optim import Optimizer
+from typing_extensions import override

 import lightning.pytorch as pl
 from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT
-from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
 from lightning.fabric.utilities.types import Steppable
 from lightning.pytorch.plugins.precision.precision import Precision
 from lightning.pytorch.utilities import GradClipAlgorithmType
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.model_helpers import is_overridden
-from lightning.pytorch.utilities.rank_zero import WarningCache
-
-if TYPE_CHECKING:
-    import deepspeed
-
-warning_cache = WarningCache()


 class DeepSpeedPrecision(Precision):
@@ -53,41 +43,29 @@ class DeepSpeedPrecision(Precision):
     """

     def __init__(self, precision: _PRECISION_INPUT) -> None:
-        supported_precision = get_args(_PRECISION_INPUT)
-        if precision not in supported_precision:
-            raise ValueError(
-                f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported."
-                f" `precision` must be one of: {supported_precision}."
-            )
-        self.precision = precision
-        precision_to_type = {
-            "bf16-mixed": torch.bfloat16,
-            "16-mixed": torch.float16,
-            "bf16-true": torch.bfloat16,
-            "16-true": torch.float16,
-            "32-true": torch.float32,
-        }
-        self._desired_dtype = precision_to_type[self.precision]
+        super().__init__(precision)
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.plugins.precision.deepspeed import (
+            DeepSpeedPrecisionTrainer as EnterpriseDeepSpeedPrecision,
+        )
+
+        self.deepspeed_precision_impl = EnterpriseDeepSpeedPrecision(outer_object=self, precision=precision)

     @override
     def convert_module(self, module: Module) -> Module:
-        if "true" in self.precision:
-            return module.to(dtype=self._desired_dtype)
-        return module
+        return self.deepspeed_precision_impl.convert_module(module=module)

     @override
     def convert_input(self, data: Any) -> Any:
-        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)
+        return self.deepspeed_precision_impl.convert_input(data=data)

     @override
     def tensor_init_context(self) -> AbstractContextManager:
-        if "true" not in self.precision:
-            return nullcontext()
-        return _DtypeContextManager(self._desired_dtype)
+        return self.deepspeed_precision_impl.tensor_init_context()

     @override
     def module_init_context(self) -> AbstractContextManager:
-        return self.tensor_init_context()
+        return self.deepspeed_precision_impl.module_init_context()

     @override
     def backward(  # type: ignore[override]
@@ -98,7 +76,7 @@ def backward(  # type: ignore[override]
         *args: Any,
         **kwargs: Any,
     ) -> None:
-        r"""Performs back-propagation using DeepSpeed's engine.
+        r"""Performs back-propagation.

         Args:
             tensor: the loss tensor
@@ -108,13 +86,7 @@ def backward(  # type: ignore[override]
             \**kwargs: additional keyword arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call

         """
-        if is_overridden("backward", model):
-            warning_cache.warn(
-                "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
-                " the backward logic internally."
-            )
-        deepspeed_engine: deepspeed.DeepSpeedEngine = model.trainer.model
-        deepspeed_engine.backward(tensor, *args, **kwargs)
+        return self.deepspeed_precision_impl.backward(tensor=tensor, model=model, optimizer=optimizer, *args, **kwargs)

     @override
     def optimizer_step(  # type: ignore[override]
@@ -124,19 +96,7 @@ def optimizer_step(  # type: ignore[override]
         closure: Callable[[], Any],
         **kwargs: Any,
    ) -> Any:
-        if isinstance(optimizer, LBFGS):
-            raise MisconfigurationException("DeepSpeed and the LBFGS optimizer are not compatible.")
-        closure_result = closure()
-        self._after_closure(model, optimizer)
-        skipped_backward = closure_result is None
-        # in manual optimization, the closure does not return a value
-        if model.automatic_optimization and skipped_backward:
-            raise MisconfigurationException(
-                "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
-            )
-        # DeepSpeed handles the optimizer step internally
-        deepspeed_engine: deepspeed.DeepSpeedEngine = model.trainer.model
-        return deepspeed_engine.step(**kwargs)
+        return self.deepspeed_precision_impl.optimizer_step(optimizer=optimizer, model=model, closure=closure, **kwargs)

     @override
     def clip_gradients(
@@ -145,4 +105,6 @@ def clip_gradients(
         clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
-        """DeepSpeed handles gradient clipping internally."""
+        return self.deepspeed_precision_impl.clip_gradients(
+            optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm
+        )
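
For orientation, a minimal usage sketch (not part of the diff above): the user-facing entry point is unchanged, and the precision values this plugin accepts are the same ones the removed dtype mapping covered ("32-true", "16-true", "bf16-true", "16-mixed", "bf16-mixed"). The snippet assumes `deepspeed` is installed on a CUDA machine and, with this change, that `pytorch_lightning_enterprise` is importable; `_raise_enterprise_not_available()` presumably raises otherwise.

```python
# Usage sketch, assuming `deepspeed` and (after this change) `pytorch_lightning_enterprise`
# are installed; otherwise `DeepSpeedPrecision.__init__` is expected to raise via
# `_raise_enterprise_not_available()` before the enterprise import runs.
import lightning.pytorch as pl

trainer = pl.Trainer(
    accelerator="cuda",
    strategy="deepspeed",
    precision="16-mixed",  # also: "32-true", "16-true", "bf16-true", "bf16-mixed"
)
# The DeepSpeed strategy constructs a `DeepSpeedPrecision` plugin internally; after this
# change every hook (`backward`, `optimizer_step`, `clip_gradients`, ...) is forwarded to
# the enterprise implementation stored in `self.deepspeed_precision_impl`.
```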