
Commit 97469c6

carmocca and awaelchli authored
TransformerEngine fallback compute dtype (#19082)
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent d8b6bbd commit 97469c6

File tree

10 files changed: +95 -29 lines changed


docs/source-fabric/fundamentals/precision.rst

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ the model and inputs can be kept in true full or half precision.
     from lightning.fabric.plugins import TransformerEnginePrecision

     recipe = {"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"}
-    precision = TransformerEnginePrecision(dtype=torch.bfloat16, recipe=recipe)
+    precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, recipe=recipe)
     fabric = Fabric(plugins=precision)

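For orientation, a minimal sketch of the updated Fabric usage implied by this change: `weights_dtype` is now a required keyword-only argument, and `fallback_compute_dtype` (added in this commit) is optional. The recipe values are simply the ones from the docs snippet above.

    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.plugins import TransformerEnginePrecision

    # Recipe values taken from the docs snippet above.
    recipe = {"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"}

    # fp8 compute with bf16 weights; operations without fp8 support also fall back to
    # bf16 unless `fallback_compute_dtype` says otherwise.
    precision = TransformerEnginePrecision(
        weights_dtype=torch.bfloat16,
        recipe=recipe,
        fallback_compute_dtype=torch.bfloat16,
    )
    fabric = Fabric(plugins=precision)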

docs/source-pytorch/common/precision_intermediate.rst

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ the model and inputs can be kept in true full or half precision.
     from lightning.trainer.plugins import TransformerEnginePrecision

     recipe = {"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"}
-    precision = TransformerEnginePrecision(dtype=torch.bfloat16, recipe=recipe)
+    precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, recipe=recipe)
     trainer = Trainer(plugins=precision)


src/lightning/fabric/CHANGELOG.md

Lines changed: 9 additions & 0 deletions
@@ -15,11 +15,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `lightning.fabric.utilities.AttributeDict` for convenient dict-attribute access to represent state in script ([#18943](https://github.com/Lightning-AI/lightning/pull/18943))


+- Added `TransformerEnginePrecision(fallback_compute_dtype=)` to control the dtype of operations that don't support fp8 ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
+
+
 ### Changed

 - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))


+- Changed the `TransformerEnginePrecision(dtype=)` argument to `weights_dtype` and made it required ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
+
+
 ### Deprecated

 -
@@ -38,6 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed broadcast at initialization in `MPIEnvironment` ([#19074](https://github.com/Lightning-AI/lightning/pull/19074))


+- Fixed issue where the `precision="transformer-engine"` argument would not replace layers by default ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
+
+

 ## [2.1.2] - 2023-11-15


src/lightning/fabric/connector.py

Lines changed: 2 additions & 2 deletions
@@ -469,9 +469,9 @@ def _check_and_init_precision(self) -> Precision:
         if self._precision_input == "64-true":
             return DoublePrecision()
         if self._precision_input == "transformer-engine":
-            return TransformerEnginePrecision(dtype=torch.bfloat16)
+            return TransformerEnginePrecision(weights_dtype=torch.bfloat16)
         if self._precision_input == "transformer-engine-float16":
-            return TransformerEnginePrecision(dtype=torch.float16)
+            return TransformerEnginePrecision(weights_dtype=torch.float16)

         if self._precision_input == "16-mixed" and self._accelerator_flag == "cpu":
             rank_zero_warn(
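In other words, the string shortcuts handled by the connector are shorthand for constructing the plugin with a default weights dtype. A rough sketch of the equivalence, not verbatim from the repository:

    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.plugins import TransformerEnginePrecision

    # "transformer-engine" builds the plugin with bf16 weights...
    fabric_bf16 = Fabric(precision="transformer-engine")
    fabric_bf16_explicit = Fabric(plugins=TransformerEnginePrecision(weights_dtype=torch.bfloat16))

    # ...and "transformer-engine-float16" builds it with fp16 weights.
    fabric_fp16 = Fabric(precision="transformer-engine-float16")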

src/lightning/fabric/plugins/precision/transformer_engine.py

Lines changed: 24 additions & 11 deletions
@@ -26,7 +26,7 @@
     _convert_fp_tensor,
     _DtypeContextManager,
 )
-from lightning.fabric.utilities.rank_zero import rank_zero_warn
+from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn

 if TYPE_CHECKING:
     from transformer_engine.common.recipe import DelayedScaling
@@ -42,13 +42,15 @@ class TransformerEnginePrecision(Precision):
     .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

     Args:
-        dtype: The weights dtype to use.
+        weights_dtype: The weights dtype to use.
         recipe: Recipe for the DelayedScaling
             `configuration <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#transformer_engine.common.recipe.DelayedScaling>`__.
             In dict format or the dataclass format.
         replace_layers: Whether to replace ``Linear`` and ``LayerNorm`` layers automatically with their Transformer
             Engine alternatives. Note that they don't subclass the torch equivalents so checks like
             ``isinstance(l, torch.nn.Linear)`` will not pass.
+        fallback_compute_dtype: The compute dtype to use for operations that don't support fp8 autocast. Defaults to the
+            same as ``weights_dtype``.

     .. note::

@@ -62,9 +64,11 @@ class TransformerEnginePrecision(Precision):

     def __init__(
         self,
-        dtype: Optional[torch.dtype] = None,
+        *,
+        weights_dtype: torch.dtype,
         recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None,
         replace_layers: Optional[bool] = None,
+        fallback_compute_dtype: Optional[torch.dtype] = None,
     ) -> None:
         if not _TRANSFORMER_ENGINE_AVAILABLE:
             raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE))
@@ -80,21 +84,27 @@ def __init__(
             recipe["fp8_format"] = getattr(Format, recipe["fp8_format"])
             recipe = DelayedScaling(**recipe)

-        if dtype is None:
-            dtype = torch.get_default_dtype()
-        self.dtype = dtype
+        self.weights_dtype = weights_dtype
         self.recipe = recipe
         self.replace_layers = replace_layers
+        self.fallback_compute_dtype = fallback_compute_dtype or weights_dtype

     def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
         # avoid converting if any is found. assume the user took care of it
-        if self.replace_layers and not any("transformer_engine.pytorch" in m.__module__ for m in module.modules()):
+        if any("transformer_engine.pytorch" in m.__module__ for m in module.modules()):
+            if self.replace_layers is True:
+                # info level because this is expected with `init_module`
+                rank_zero_info(
+                    "`TransformerEnginePrecision(replace_layers=True)` is set but the model already contains"
+                    " TransformerEngine layers. Skipping"
+                )
+        elif self.replace_layers in (None, True):
             _convert_layers(module)
-        module = module.to(dtype=self.dtype)
+        module = module.to(dtype=self.weights_dtype)
         return module

     def tensor_init_context(self) -> ContextManager:
-        return _DtypeContextManager(self.dtype)
+        return _DtypeContextManager(self.weights_dtype)

     def module_init_context(self) -> ContextManager:
         dtype_ctx = self.tensor_init_context()
@@ -113,17 +123,20 @@ def module_init_context(self) -> ContextManager:
         return stack

     def forward_context(self) -> ContextManager:
-        dtype_ctx = _DtypeContextManager(self.dtype)
+        dtype_ctx = _DtypeContextManager(self.weights_dtype)
+        fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype)
         import transformer_engine.pytorch as te

         autocast_ctx = te.fp8_autocast(enabled=True, fp8_recipe=self.recipe)
         stack = ExitStack()
         stack.enter_context(dtype_ctx)
+        # enable an outer fallback autocast for operations that do not support fp8
+        stack.enter_context(fallback_autocast_ctx)
         stack.enter_context(autocast_ctx)
         return stack

     def convert_input(self, data: Any) -> Any:
-        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.dtype)
+        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.weights_dtype)

     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
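To make the nesting in the new `forward_context` concrete, the following is a rough, standalone approximation (assuming `transformer_engine` is installed): the default dtype is set to the weights dtype, an outer `torch.autocast` covers operations without fp8 support, and TransformerEngine's `fp8_autocast` handles the fp8-capable layers. `_DtypeContextManager` is a Lightning internal, so it is emulated here with `torch.set_default_dtype`.

    from contextlib import ExitStack

    import torch
    import transformer_engine.pytorch as te  # assumes transformer_engine is installed


    def forward_context_sketch(weights_dtype, fallback_compute_dtype, recipe):
        """Rough approximation of TransformerEnginePrecision.forward_context() after this commit."""
        stack = ExitStack()
        # Emulate `_DtypeContextManager`: new tensors default to the weights dtype.
        previous = torch.get_default_dtype()
        torch.set_default_dtype(weights_dtype)
        stack.callback(torch.set_default_dtype, previous)
        # Outer fallback autocast: ops that don't support fp8 run in this dtype.
        stack.enter_context(torch.autocast(device_type="cuda", dtype=fallback_compute_dtype))
        # Inner fp8 autocast from TransformerEngine for the fp8-capable layers.
        stack.enter_context(te.fp8_autocast(enabled=True, fp8_recipe=recipe))
        return stack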

src/lightning/pytorch/CHANGELOG.md

Lines changed: 9 additions & 0 deletions
@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The Trainer now restores the training mode set through `.train()` or `.eval()` on a submodule-level when switching from validation to training ([#18951](https://github.com/Lightning-AI/lightning/pull/18951))


+- Added `TransformerEnginePrecision(fallback_compute_dtype=)` to control the dtype of operations that don't support fp8 ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
+
+
 ### Changed

 - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
@@ -32,6 +35,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The `LightningModule.load_from_checkpoint()` function now calls `.configure_model()` on the model if it is overridden, to ensure all layers can be loaded from the checkpoint ([#19036](https://github.com/Lightning-AI/lightning/pull/19036))


+- Changed the `TransformerEnginePrecision(dtype=)` argument to `weights_dtype` and made it required ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
+
+
 ### Deprecated

 - Deprecated all precision plugin classes under `lightning.pytorch.plugins` with the suffix `Plugin` in the name ([#18840](https://github.com/Lightning-AI/lightning/pull/18840))
@@ -65,6 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed broadcast at initialization in `MPIEnvironment` ([#19074](https://github.com/Lightning-AI/lightning/pull/19074))


+- Fixed issue where the `precision="transformer-engine"` argument would not replace layers by default ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
+
+
 ## [2.1.2] - 2023-11-15

 ### Fixed

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

Lines changed: 2 additions & 2 deletions
@@ -549,9 +549,9 @@ def _check_and_init_precision(self) -> Precision:
         if self._precision_flag == "64-true":
             return DoublePrecision()
         if self._precision_flag == "transformer-engine":
-            return TransformerEnginePrecision(dtype=torch.bfloat16)
+            return TransformerEnginePrecision(weights_dtype=torch.bfloat16)
         if self._precision_flag == "transformer-engine-float16":
-            return TransformerEnginePrecision(dtype=torch.float16)
+            return TransformerEnginePrecision(weights_dtype=torch.float16)

         if self._precision_flag == "16-mixed" and self._accelerator_flag == "cpu":
             rank_zero_warn(
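The Trainer shortcuts shown above always use the weights dtype as the fallback; to pick a different fallback compute dtype, the plugin has to be constructed explicitly. A small sketch, with the dtype choices here being illustrative only:

    import torch
    from lightning.pytorch import Trainer
    from lightning.pytorch.plugins import TransformerEnginePrecision

    # fp16 weights, but run operations without fp8 support in bf16 for extra range.
    precision = TransformerEnginePrecision(
        weights_dtype=torch.float16,
        fallback_compute_dtype=torch.bfloat16,
    )
    trainer = Trainer(plugins=precision)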

tests/tests_fabric/plugins/precision/test_transformer_engine.py

Lines changed: 5 additions & 5 deletions
@@ -35,20 +35,20 @@ def test_transformer_engine_plugin(monkeypatch):

     connector = _Connector(precision="transformer-engine")
     assert isinstance(connector.precision, TransformerEnginePrecision)
-    assert connector.precision.dtype is torch.bfloat16
+    assert connector.precision.weights_dtype is torch.bfloat16
     connector = _Connector(precision="transformer-engine-float16")
-    assert connector.precision.dtype is torch.float16
+    assert connector.precision.weights_dtype is torch.float16

     recipe_mock.reset_mock()
-    precision = TransformerEnginePrecision()
+    precision = TransformerEnginePrecision(weights_dtype=torch.float32)
     connector = _Connector(plugins=precision)
     assert connector.precision is precision
-    assert precision.dtype == torch.float32
+    assert precision.weights_dtype == torch.float32
     recipe_mock.DelayedScaling.assert_called_once_with()

     recipe_mock.reset_mock()
     recipe = {"foo": 0, "fp8_format": "HYBRID"}
-    precision = TransformerEnginePrecision(dtype=torch.float16, recipe=recipe)
+    precision = TransformerEnginePrecision(weights_dtype=torch.float16, recipe=recipe)
     connector = _Connector(plugins=precision)
     assert connector.precision is precision
     recipe_mock.DelayedScaling.assert_called_once_with(foo=0, fp8_format=recipe_mock.Format.HYBRID)

tests/tests_pytorch/deprecated_api/test_no_removal_version.py

Lines changed: 1 addition & 1 deletion
@@ -125,7 +125,7 @@ def test_transformer_engine_precision_plugin(monkeypatch):
     from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecisionPlugin

     with pytest.deprecated_call(match=r"The `TransformerEnginePrecisionPlugin` is deprecated"):
-        TransformerEnginePrecisionPlugin()
+        TransformerEnginePrecisionPlugin(weights_dtype=torch.float32)


 def test_xla_precision_plugin(xla_available):

tests/tests_pytorch/plugins/precision/test_transformer_engine.py

Lines changed: 41 additions & 6 deletions
@@ -12,17 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 import sys
-from unittest.mock import Mock
+from contextlib import nullcontext
+from unittest.mock import ANY, Mock

+import lightning.fabric
 import pytest
 import torch
+from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.plugins import TransformerEnginePrecision
 from lightning.pytorch.trainer.connectors.accelerator_connector import _AcceleratorConnector


 def test_transformer_engine_precision_plugin(monkeypatch):
-    import lightning.fabric  # avoid breakage with standalone package
-
     module = lightning.fabric.plugins.precision.transformer_engine
     if module._TRANSFORMER_ENGINE_AVAILABLE:
         pytest.skip("Assumes transformer_engine is unavailable")
@@ -32,10 +33,44 @@ def test_transformer_engine_precision_plugin(monkeypatch):

     connector = _AcceleratorConnector(precision="transformer-engine")
     assert isinstance(connector.precision_plugin, TransformerEnginePrecision)
-    assert connector.precision_plugin.dtype is torch.bfloat16
+    assert connector.precision_plugin.weights_dtype is torch.bfloat16
     connector = _AcceleratorConnector(precision="transformer-engine-float16")
-    assert connector.precision_plugin.dtype is torch.float16
+    assert connector.precision_plugin.weights_dtype is torch.float16

-    precision = TransformerEnginePrecision()
+    precision = TransformerEnginePrecision(weights_dtype=torch.float32)
     connector = _AcceleratorConnector(plugins=precision)
     assert connector.precision_plugin is precision
+
+
+def test_configure_model(monkeypatch):
+    module = lightning.fabric.plugins.precision.transformer_engine
+    if module._TRANSFORMER_ENGINE_AVAILABLE:
+        pytest.skip("Assumes transformer_engine is unavailable")
+    monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
+    te_mock = Mock()
+    te_mock.pytorch.fp8_autocast.return_value = nullcontext()
+
+    class ModuleMock(torch.nn.Linear):
+        def __init__(self, in_features, out_features, bias=True, *_, **__):
+            super().__init__(in_features, out_features, bias)
+
+    te_mock.pytorch.Linear = ModuleMock
+    monkeypatch.setitem(sys.modules, "transformer_engine", te_mock)
+    monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", te_mock)
+    monkeypatch.setitem(sys.modules, "transformer_engine.common.recipe", te_mock)
+
+    class MyModel(LightningModule):
+        def configure_model(self):
+            self.l = torch.nn.Linear(8, 16)
+            assert self.l.weight.dtype == torch.float16
+
+        def test_step(self, *_):
+            ...
+
+    model = MyModel()
+    trainer = Trainer(barebones=True, precision="transformer-engine-float16")
+    trainer.test(model, [0])
+    te_mock.pytorch.fp8_autocast.assert_called_once_with(enabled=True, fp8_recipe=ANY)
+    # TODO: invert condition once this gets fixed
+    assert not isinstance(model.l, ModuleMock)
+    assert model.l.weight.dtype == torch.float16
