@@ -115,3 +115,35 @@ class TELayerNormMock(Mock): ...
115115 assert isinstance (model .l1 , TELinearMock )
116116 assert isinstance (model .l2 , TELayerNormMock )
117117 assert isinstance (model .l3 .l , TELinearMock )
118+
119+
def test_convert_module_handles_linear_without_bias(monkeypatch):
    """``convert_module`` must not fail on ``nn.Linear`` layers built with ``bias=False``."""
    # Pretend transformer_engine is importable so the precision plugin takes the TE path.
    te_module = lightning.fabric.plugins.precision.transformer_engine
    monkeypatch.setattr(te_module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)

    # Install a mocked transformer_engine package into sys.modules.
    te_mock = Mock()
    monkeypatch.setitem(sys.modules, "transformer_engine", te_mock)
    monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", te_mock.pytorch)
    monkeypatch.setitem(sys.modules, "transformer_engine.common.recipe", te_mock.recipe)

    class TELinearMock(torch.nn.Linear):
        # Stand-in for transformer_engine.pytorch.Linear.
        def __init__(self, in_features, out_features, bias=True):
            super().__init__(in_features, out_features, bias)

    te_mock.pytorch.Linear = TELinearMock
    te_mock.pytorch.LayerNorm = torch.nn.LayerNorm
    te_mock.recipe.DelayedScaling.return_value = None

    class BiaslessModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Regression case: a linear layer created without a bias term.
            self.linear = torch.nn.Linear(16, 32, bias=False)

    model = BiaslessModel()
    precision = TransformerEnginePrecision(weights_dtype=torch.float16)
    precision.replace_layers = True

    # Previously raised AttributeError when the replaced layer had bias=None.
    precision.convert_module(model)

    assert isinstance(model.linear, TELinearMock)
    assert model.linear.bias is None
0 commit comments