src/lightning/fabric/plugins/precision/transformer_engine.py
@@ -171,7 +171,9 @@ def _convert_layers(module: torch.nn.Module) -> None:
         elif isinstance(child, torch.nn.LayerNorm):
             replacement = te.LayerNorm(child.normalized_shape[0], eps=child.eps)
             replacement.weight.data = child.weight.data.clone()
-            replacement.bias.data = child.bias.data.clone()
+            # Check if bias exists before attempting to clone its data
+            if child.bias is not None and replacement.bias is not None:
+                replacement.bias.data = child.bias.data.clone()
             log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
             module.__setattr__(name, replacement)
         else:
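For context, a minimal sketch of the failure mode this guard fixes (plain PyTorch, not part of the PR itself): a layer constructed without a bias stores None in its bias attribute, so the previous unconditional child.bias.data.clone() raised AttributeError.

import torch

# A Linear layer created with bias=False keeps None in .bias
linear = torch.nn.Linear(16, 32, bias=False)
print(linear.bias)  # None

# Since PyTorch 2.1, LayerNorm accepts bias=False and behaves the same way
norm = torch.nn.LayerNorm(32, bias=False)
print(norm.bias)  # None

# linear.bias.data  # AttributeError: 'NoneType' object has no attribute 'data'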
tests/tests_fabric/plugins/precision/test_transformer_engine.py (32 additions, 0 deletions)
@@ -115,3 +115,35 @@ class TELayerNormMock(Mock): ...
     assert isinstance(model.l1, TELinearMock)
     assert isinstance(model.l2, TELayerNormMock)
     assert isinstance(model.l3.l, TELinearMock)
+
+
+def test_convert_module_handles_linear_without_bias(monkeypatch):
+    module = lightning.fabric.plugins.precision.transformer_engine  # plugin module under test; its transformer_engine import is mocked below
+    monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
+
+    transformer_engine_mock = Mock()
+    monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock)
+    monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", transformer_engine_mock.pytorch)
+    monkeypatch.setitem(sys.modules, "transformer_engine.common.recipe", transformer_engine_mock.recipe)
+
+    class TELinearMock(torch.nn.Linear):  # mock of the Transformer Engine Linear replacement class
+        def __init__(self, in_features, out_features, bias=True):
+            super().__init__(in_features, out_features, bias)
+
+    transformer_engine_mock.pytorch.Linear = TELinearMock
+    transformer_engine_mock.pytorch.LayerNorm = torch.nn.LayerNorm
+    transformer_engine_mock.recipe.DelayedScaling.return_value = None
+
+    class BiaslessModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(16, 32, bias=False)  # bias=False is the case that triggered the bug
+
+    model = BiaslessModel()
+    precision = TransformerEnginePrecision(weights_dtype=torch.float16)
+    precision.replace_layers = True
+
+    precision.convert_module(model)  # should no longer raise AttributeError
+
+    assert isinstance(model.linear, TELinearMock)
+    assert model.linear.bias is None
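A note on the mocking pattern used in this test: inserting a Mock into sys.modules makes any subsequent import of transformer_engine resolve to that mock, so convert_module can run without the real package installed. A standalone sketch of the technique, with a hypothetical module name:

import sys
from unittest.mock import Mock

fake = Mock()
sys.modules["some_optional_dependency"] = fake  # hypothetical module name

import some_optional_dependency  # found in sys.modules, so no ImportError

assert some_optional_dependency is fake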