11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | -import logging |
15 | 14 | from collections.abc import Mapping |
16 | | -from contextlib import AbstractContextManager, ExitStack |
| 15 | +from contextlib import AbstractContextManager |
17 | 16 | from typing import TYPE_CHECKING, Any, Literal, Optional, Union |
18 | 17 |
19 | 18 | import torch |
20 | | -from lightning_utilities import apply_to_collection |
21 | | -from lightning_utilities.core.imports import RequirementCache |
22 | | -from torch import Tensor |
23 | 19 | from typing_extensions import override |
24 | 20 |
25 | 21 | from lightning.fabric.plugins.precision.precision import Precision |
26 | | -from lightning.fabric.plugins.precision.utils import ( |
27 | | - _ClassReplacementContextManager, |
28 | | - _convert_fp_tensor, |
29 | | - _DtypeContextManager, |
30 | | -) |
31 | | -from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn |
| 22 | +from lightning.fabric.utilities.imports import _raise_enterprise_not_available |
32 | 23 |
33 | 24 | if TYPE_CHECKING: |
34 | 25 | from transformer_engine.common.recipe import DelayedScaling |
35 | 26 |
36 | | -_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0") |
37 | | -log = logging.getLogger(__name__) |
38 | | - |
39 | 27 |
40 | 28 | class TransformerEnginePrecision(Precision): |
41 | 29 | """Plugin for training with fp8 precision via nvidia's |
@@ -72,111 +60,71 @@ def __init__( |
72 | 60 | replace_layers: Optional[bool] = None, |
73 | 61 | fallback_compute_dtype: Optional[torch.dtype] = None, |
74 | 62 | ) -> None: |
75 | | - if not _TRANSFORMER_ENGINE_AVAILABLE: |
76 | | - raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE)) |
77 | | - from transformer_engine.common.recipe import DelayedScaling |
78 | | - |
79 | | - if recipe is None: |
80 | | - recipe = DelayedScaling() |
81 | | - elif isinstance(recipe, Mapping): |
82 | | - recipe = dict(recipe) # copy |
83 | | - if "fp8_format" in recipe: |
84 | | - from transformer_engine.common.recipe import Format |
85 | | - |
86 | | - recipe["fp8_format"] = getattr(Format, recipe["fp8_format"]) |
87 | | - recipe = DelayedScaling(**recipe) |
88 | | - |
89 | | - self.weights_dtype = weights_dtype |
90 | | - self.recipe = recipe |
91 | | - self.replace_layers = replace_layers |
92 | | - self.fallback_compute_dtype = fallback_compute_dtype or weights_dtype |
| 63 | + super().__init__() |
| 64 | + _raise_enterprise_not_available() |
| 65 | + from pytorch_lightning_enterprise.fabric.plugins.precision.transformer_engine import ( |
| 66 | + TransformerEnginePrecision as EnterpriseTransformerEnginePrecision, |
| 67 | + ) |
| 68 | + |
| 69 | + self.transformer_engine_impl = EnterpriseTransformerEnginePrecision( |
| 70 | + weights_dtype=weights_dtype, |
| 71 | + recipe=recipe, |
| 72 | + replace_layers=replace_layers, |
| 73 | + fallback_compute_dtype=fallback_compute_dtype, |
| 74 | + ) |
| 75 | + |
| 76 | + @property |
| 77 | + def weights_dtype(self) -> torch.dtype: |
| 78 | + return self.transformer_engine_impl.weights_dtype |
| 79 | + |
| 80 | + @weights_dtype.setter |
| 81 | + def weights_dtype(self, value: torch.dtype) -> None: |
| 82 | + self.transformer_engine_impl.weights_dtype = value |
| 83 | + |
| 84 | + @property |
| 85 | + def recipe(self) -> Union[Mapping[str, Any], "DelayedScaling"]: |
| 86 | + return self.transformer_engine_impl.recipe |
| 87 | + |
| 88 | + @recipe.setter |
| 89 | + def recipe(self, value: Union[Mapping[str, Any], "DelayedScaling"]) -> None: |
| 90 | + self.transformer_engine_impl.recipe = value |
| 91 | + |
| 92 | + @property |
| 93 | + def replace_layers(self) -> bool: |
| 94 | + return self.transformer_engine_impl.replace_layers |
| 95 | + |
| 96 | + @replace_layers.setter |
| 97 | + def replace_layers(self, value: bool) -> None: |
| 98 | + self.transformer_engine_impl.replace_layers = value |
| 99 | + |
| 100 | + @property |
| 101 | + def fallback_compute_dtype(self) -> torch.dtype: |
| 102 | + return self.transformer_engine_impl.fallback_compute_dtype |
| 103 | + |
| 104 | + @fallback_compute_dtype.setter |
| 105 | + def fallback_compute_dtype(self, value: torch.dtype) -> None: |
| 106 | + self.transformer_engine_impl.fallback_compute_dtype = value |
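The property block above keeps the plugin's public attributes intact while the state now lives on the enterprise implementation. A minimal sketch of that forwarding behavior (hypothetical snippet, not part of the changed file; constructing the plugin requires the enterprise package to be installed):

    import torch
    from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision

    plugin = TransformerEnginePrecision(weights_dtype=torch.bfloat16)
    plugin.weights_dtype = torch.float16  # setter forwards to the wrapped enterprise impl
    assert plugin.transformer_engine_impl.weights_dtype is torch.float16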
93 | 107 |
94 | 108 | @override |
95 | 109 | def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: |
96 | | - # avoid converting if any is found. assume the user took care of it |
97 | | - if any("transformer_engine.pytorch" in m.__module__ for m in module.modules()): |
98 | | - if self.replace_layers is True: |
99 | | - # info level because this is expected with `init_module` |
100 | | - rank_zero_info( |
101 | | - "`TransformerEnginePrecision(replace_layers=True)` is set but the model already contains" |
102 | | - " TransformerEngine layers. Skipping" |
103 | | - ) |
104 | | - elif self.replace_layers in (None, True): |
105 | | - _convert_layers(module) |
106 | | - module = module.to(dtype=self.weights_dtype) |
107 | | - return module |
| 110 | + return self.transformer_engine_impl.convert_module(module) |
108 | 111 |
109 | 112 | @override |
110 | 113 | def tensor_init_context(self) -> AbstractContextManager: |
111 | | - return _DtypeContextManager(self.weights_dtype) |
| 114 | + return self.transformer_engine_impl.tensor_init_context() |
112 | 115 |
113 | 116 | @override |
114 | 117 | def module_init_context(self) -> AbstractContextManager: |
115 | | - dtype_ctx = self.tensor_init_context() |
116 | | - stack = ExitStack() |
117 | | - if self.replace_layers: |
118 | | - import transformer_engine.pytorch as te |
119 | | - |
120 | | - context_manager = _ClassReplacementContextManager({ |
121 | | - "torch.nn.Linear": te.Linear, |
122 | | - "torch.nn.LayerNorm": te.LayerNorm, |
123 | | - }) |
124 | | - stack.enter_context(context_manager) |
125 | | - stack.enter_context(dtype_ctx) |
126 | | - return stack |
| 118 | + return self.transformer_engine_impl.module_init_context() |
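A usage sketch for the init contexts, assuming they are consumed the same way as before the refactor (e.g. by Fabric's init_module); the model and its dimensions are placeholders chosen to satisfy the FP8 shape constraints mentioned further down:

    # inside this context, layer construction can pick up the configured weights dtype
    # and, when replace_layers is enabled, Transformer Engine layer replacements
    with plugin.module_init_context():
        model = torch.nn.Sequential(torch.nn.Linear(1024, 4096), torch.nn.LayerNorm(4096))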
127 | 119 |
128 | 120 | @override |
129 | 121 | def forward_context(self) -> AbstractContextManager: |
130 | | - dtype_ctx = _DtypeContextManager(self.weights_dtype) |
131 | | - fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype) |
132 | | - import transformer_engine.pytorch as te |
133 | | - |
134 | | - autocast_ctx = te.fp8_autocast(enabled=True, fp8_recipe=self.recipe) |
135 | | - stack = ExitStack() |
136 | | - stack.enter_context(dtype_ctx) |
137 | | - # enable an outer fallback autocast for operations that do not support fp8 |
138 | | - stack.enter_context(fallback_autocast_ctx) |
139 | | - stack.enter_context(autocast_ctx) |
140 | | - return stack |
| 122 | + return self.transformer_engine_impl.forward_context() |
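For context, forward_context() is still used as a context manager around the forward pass, per the pre-existing Precision contract; the enterprise implementation is expected to enable fp8 autocast inside it, as the removed code above did. A hedged sketch with placeholder names:

    with plugin.forward_context():
        output = model(batch)  # fp8 regions run under fp8_autocast; unsupported ops fall back to the fallback compute dtype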
141 | 123 |
142 | 124 | @override |
143 | 125 | def convert_input(self, data: Any) -> Any: |
144 | | - return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.weights_dtype) |
| 126 | + return self.transformer_engine_impl.convert_input(data) |
145 | 127 |
146 | 128 | @override |
147 | 129 | def convert_output(self, data: Any) -> Any: |
148 | | - return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()) |
149 | | - |
150 | | - |
151 | | -def _convert_layers(module: torch.nn.Module) -> None: |
152 | | - import transformer_engine.pytorch as te |
153 | | - |
154 | | - for name, child in module.named_children(): |
155 | | - if isinstance(child, torch.nn.Linear): |
156 | | - if child.in_features % 8 != 0 or child.out_features % 16 != 0: |
157 | | - # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting |
158 | | - rank_zero_warn( |
159 | | - "Support for FP8 in the linear layers with this plugin is currently limited to" |
160 | | - " tensors with shapes where the dimensions are divisible by 8 and 16 respectively." |
161 | | - f" The layer {name!r} does not fit this criteria. You might want to add padding to your inputs." |
162 | | - ) |
163 | | - continue |
164 | | - has_bias = child.bias is not None |
165 | | - replacement = te.Linear(child.in_features, child.out_features, bias=has_bias) |
166 | | - replacement.weight.data = child.weight.data.clone() |
167 | | - if has_bias: |
168 | | - replacement.bias.data = child.bias.data.clone() |
169 | | - log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent") |
170 | | - module.__setattr__(name, replacement) |
171 | | - elif isinstance(child, torch.nn.LayerNorm): |
172 | | - replacement = te.LayerNorm(child.normalized_shape[0], eps=child.eps) |
173 | | - replacement.weight.data = child.weight.data.clone() |
174 | | - # Check if bias exists before attempting to clone its data |
175 | | - if child.bias is not None and replacement.bias is not None: |
176 | | - replacement.bias.data = child.bias.data.clone() |
177 | | - log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent") |
178 | | - module.__setattr__(name, replacement) |
179 | | - else: |
180 | | - # there are other transformer engine layers that we could convert but require fusion. full list at: |
181 | | - # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html |
182 | | - _convert_layers(child) |
| 130 | + return self.transformer_engine_impl.convert_output(data) |
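A minimal end-to-end usage sketch of the refactored plugin (assumed setup, not part of this diff; the enterprise package providing the backend must be installed, otherwise construction is expected to raise via _raise_enterprise_not_available):

    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision

    precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, replace_layers=True)
    fabric = Fabric(accelerator="cuda", devices=1, plugins=precision)
    fabric.launch()
    with fabric.init_module():
        # dimensions divisible by 8 / 16 so the Transformer Engine replacement applies
        model = torch.nn.Linear(1024, 4096)
    model = fabric.setup(model)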