
Commit 61f01c6

forward transformer engine
1 parent 7dbab8e

File tree

1 file changed: +52 -104 lines changed

src/lightning/fabric/plugins/precision/transformer_engine.py

Lines changed: 52 additions & 104 deletions
@@ -11,31 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 from collections.abc import Mapping
-from contextlib import AbstractContextManager, ExitStack
+from contextlib import AbstractContextManager
 from typing import TYPE_CHECKING, Any, Literal, Optional, Union

 import torch
-from lightning_utilities import apply_to_collection
-from lightning_utilities.core.imports import RequirementCache
-from torch import Tensor
 from typing_extensions import override

 from lightning.fabric.plugins.precision.precision import Precision
-from lightning.fabric.plugins.precision.utils import (
-    _ClassReplacementContextManager,
-    _convert_fp_tensor,
-    _DtypeContextManager,
-)
-from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available

 if TYPE_CHECKING:
     from transformer_engine.common.recipe import DelayedScaling

-_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0")
-log = logging.getLogger(__name__)
-

 class TransformerEnginePrecision(Precision):
     """Plugin for training with fp8 precision via nvidia's
@@ -72,111 +60,71 @@ def __init__(
         replace_layers: Optional[bool] = None,
         fallback_compute_dtype: Optional[torch.dtype] = None,
     ) -> None:
-        if not _TRANSFORMER_ENGINE_AVAILABLE:
-            raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE))
-        from transformer_engine.common.recipe import DelayedScaling
-
-        if recipe is None:
-            recipe = DelayedScaling()
-        elif isinstance(recipe, Mapping):
-            recipe = dict(recipe)  # copy
-            if "fp8_format" in recipe:
-                from transformer_engine.common.recipe import Format
-
-                recipe["fp8_format"] = getattr(Format, recipe["fp8_format"])
-            recipe = DelayedScaling(**recipe)
-
-        self.weights_dtype = weights_dtype
-        self.recipe = recipe
-        self.replace_layers = replace_layers
-        self.fallback_compute_dtype = fallback_compute_dtype or weights_dtype
+        super().__init__()
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.fabric.plugins.precision.transformer_engine import (
+            TransformerEnginePrecision as EnterpriseTransformerEnginePrecision,
+        )
+
+        self.transformer_engine_impl = EnterpriseTransformerEnginePrecision(
+            weights_dtype=weights_dtype,
+            recipe=recipe,
+            replace_layers=replace_layers,
+            fallback_compute_dtype=fallback_compute_dtype,
+        )
+
+    @property
+    def weights_dtype(self) -> torch.dtype:
+        return self.transformer_engine_impl.weights_dtype
+
+    @weights_dtype.setter
+    def weights_dtype(self, value: torch.dtype) -> None:
+        self.transformer_engine_impl.weights_dtype = value
+
+    @property
+    def recipe(self) -> Union[Mapping[str, Any], "DelayedScaling"]:
+        return self.transformer_engine_impl.recipe
+
+    @recipe.setter
+    def recipe(self, value: Union[Mapping[str, Any], "DelayedScaling"]) -> None:
+        self.transformer_engine_impl.recipe = value
+
+    @property
+    def replace_layers(self) -> bool:
+        return self.transformer_engine_impl.replace_layers
+
+    @replace_layers.setter
+    def replace_layers(self, value: bool) -> None:
+        self.transformer_engine_impl.replace_layers = value
+
+    @property
+    def fallback_compute_dtype(self) -> torch.dtype:
+        return self.transformer_engine_impl.fallback_compute_dtype
+
+    @fallback_compute_dtype.setter
+    def fallback_compute_dtype(self, value: torch.dtype) -> None:
+        self.transformer_engine_impl.fallback_compute_dtype = value

     @override
     def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
-        # avoid converting if any is found. assume the user took care of it
-        if any("transformer_engine.pytorch" in m.__module__ for m in module.modules()):
-            if self.replace_layers is True:
-                # info level because this is expected with `init_module`
-                rank_zero_info(
-                    "`TransformerEnginePrecision(replace_layers=True)` is set but the model already contains"
-                    " TransformerEngine layers. Skipping"
-                )
-        elif self.replace_layers in (None, True):
-            _convert_layers(module)
-        module = module.to(dtype=self.weights_dtype)
-        return module
+        return self.transformer_engine_impl.convert_module(module)

     @override
     def tensor_init_context(self) -> AbstractContextManager:
-        return _DtypeContextManager(self.weights_dtype)
+        return self.transformer_engine_impl.tensor_init_context()

     @override
     def module_init_context(self) -> AbstractContextManager:
-        dtype_ctx = self.tensor_init_context()
-        stack = ExitStack()
-        if self.replace_layers:
-            import transformer_engine.pytorch as te
-
-            context_manager = _ClassReplacementContextManager({
-                "torch.nn.Linear": te.Linear,
-                "torch.nn.LayerNorm": te.LayerNorm,
-            })
-            stack.enter_context(context_manager)
-        stack.enter_context(dtype_ctx)
-        return stack
+        return self.transformer_engine_impl.module_init_context()

     @override
     def forward_context(self) -> AbstractContextManager:
-        dtype_ctx = _DtypeContextManager(self.weights_dtype)
-        fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype)
-        import transformer_engine.pytorch as te
-
-        autocast_ctx = te.fp8_autocast(enabled=True, fp8_recipe=self.recipe)
-        stack = ExitStack()
-        stack.enter_context(dtype_ctx)
-        # enable an outer fallback autocast for operations that do not support fp8
-        stack.enter_context(fallback_autocast_ctx)
-        stack.enter_context(autocast_ctx)
-        return stack
+        return self.transformer_engine_impl.forward_context()

     @override
     def convert_input(self, data: Any) -> Any:
-        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.weights_dtype)
+        return self.transformer_engine_impl.convert_input(data)

     @override
     def convert_output(self, data: Any) -> Any:
-        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
-
-
-def _convert_layers(module: torch.nn.Module) -> None:
-    import transformer_engine.pytorch as te
-
-    for name, child in module.named_children():
-        if isinstance(child, torch.nn.Linear):
-            if child.in_features % 8 != 0 or child.out_features % 16 != 0:
-                # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting
-                rank_zero_warn(
-                    "Support for FP8 in the linear layers with this plugin is currently limited to"
-                    " tensors with shapes where the dimensions are divisible by 8 and 16 respectively."
-                    f" The layer {name!r} does not fit this criteria. You might want to add padding to your inputs."
-                )
-                continue
-            has_bias = child.bias is not None
-            replacement = te.Linear(child.in_features, child.out_features, bias=has_bias)
-            replacement.weight.data = child.weight.data.clone()
-            if has_bias:
-                replacement.bias.data = child.bias.data.clone()
-            log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
-            module.__setattr__(name, replacement)
-        elif isinstance(child, torch.nn.LayerNorm):
-            replacement = te.LayerNorm(child.normalized_shape[0], eps=child.eps)
-            replacement.weight.data = child.weight.data.clone()
-            # Check if bias exists before attempting to clone its data
-            if child.bias is not None and replacement.bias is not None:
-                replacement.bias.data = child.bias.data.clone()
-            log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
-            module.__setattr__(name, replacement)
-        else:
-            # there are other transformer engine layers that we could convert but require fusion. full list at:
-            # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html
-            _convert_layers(child)
+        return self.transformer_engine_impl.convert_output(data)
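
A minimal usage sketch of the plugin after this change, for reference only. It assumes the optional pytorch_lightning_enterprise package is installed (otherwise _raise_enterprise_not_available() raises at construction time); the model, shapes, and dtype below are illustrative placeholders, not part of the commit.

import torch
from lightning.fabric import Fabric
from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision

# Construction is unchanged for users; the arguments are forwarded to the
# enterprise implementation stored in `transformer_engine_impl`.
precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16)

# Attribute access is forwarded through the new properties and setters.
precision.replace_layers = True
assert precision.transformer_engine_impl.replace_layers is True

fabric = Fabric(accelerator="cuda", devices=1, plugins=precision)
fabric.launch()

with fabric.init_module():  # delegates to module_init_context()
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.LayerNorm(64))

model = fabric.setup(model)  # convert_module() is applied during setup

with precision.forward_context():  # forwarded to the enterprise fp8 autocast
    out = model(torch.randn(8, 64, device=fabric.device))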

0 commit comments
