diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py
index c3ef84a453e73..bf1e51ea6b2b0 100644
--- a/src/lightning/fabric/plugins/precision/transformer_engine.py
+++ b/src/lightning/fabric/plugins/precision/transformer_engine.py
@@ -171,7 +171,9 @@ def _convert_layers(module: torch.nn.Module) -> None:
         elif isinstance(child, torch.nn.LayerNorm):
             replacement = te.LayerNorm(child.normalized_shape[0], eps=child.eps)
             replacement.weight.data = child.weight.data.clone()
-            replacement.bias.data = child.bias.data.clone()
+            # Check if bias exists before attempting to clone its data
+            if child.bias is not None and replacement.bias is not None:
+                replacement.bias.data = child.bias.data.clone()
             log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
             module.__setattr__(name, replacement)
         else:
diff --git a/tests/tests_fabric/plugins/precision/test_transformer_engine.py b/tests/tests_fabric/plugins/precision/test_transformer_engine.py
index 033484aca9c90..ed7c984b1ae64 100644
--- a/tests/tests_fabric/plugins/precision/test_transformer_engine.py
+++ b/tests/tests_fabric/plugins/precision/test_transformer_engine.py
@@ -115,3 +115,35 @@ class TELayerNormMock(Mock): ...
     assert isinstance(model.l1, TELinearMock)
     assert isinstance(model.l2, TELayerNormMock)
     assert isinstance(model.l3.l, TELinearMock)
+
+
+def test_convert_module_handles_linear_without_bias(monkeypatch):
+    module = lightning.fabric.plugins.precision.transformer_engine  # Set up mock transformer_engine
+    monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
+
+    transformer_engine_mock = Mock()
+    monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock)
+    monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", transformer_engine_mock.pytorch)
+    monkeypatch.setitem(sys.modules, "transformer_engine.common.recipe", transformer_engine_mock.recipe)
+
+    class TELinearMock(torch.nn.Linear):  # Mock the Linear replacement class
+        def __init__(self, in_features, out_features, bias=True):
+            super().__init__(in_features, out_features, bias)
+
+    transformer_engine_mock.pytorch.Linear = TELinearMock
+    transformer_engine_mock.pytorch.LayerNorm = torch.nn.LayerNorm
+    transformer_engine_mock.recipe.DelayedScaling.return_value = None
+
+    class BiaslessModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(16, 32, bias=False)  # This was causing the bug
+
+    model = BiaslessModel()
+    precision = TransformerEnginePrecision(weights_dtype=torch.float16)
+    precision.replace_layers = True
+
+    precision.convert_module(model)  # This should no longer raise an AttributeError
+
+    assert isinstance(model.linear, TELinearMock)
+    assert model.linear.bias is None
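
For context, here is a minimal standalone sketch of the failure mode this patch guards against, using plain torch.nn modules so no Transformer Engine install is needed; the src/dst names are illustrative only and do not appear in the patch:

import torch

src = torch.nn.Linear(16, 32, bias=False)  # bias=False means src.bias is None
dst = torch.nn.Linear(16, 32, bias=True)

try:
    # Pre-patch behavior: unconditionally cloning the bias blows up when it is None
    dst.bias.data = src.bias.data.clone()
except AttributeError as err:
    print(err)  # 'NoneType' object has no attribute 'data'

# Post-patch behavior: copy the bias only when both modules actually have one
if src.bias is not None and dst.bias is not None:
    dst.bias.data = src.bias.data.clone()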