55# LICENSE file in the root directory of this source tree.
66
77import unittest
8+ from dataclasses import dataclass
89
910import torch
1011import torch .nn as nn
1112
12- from torchtitan .protocols .module import Module
13+ from torchtitan .models .common .linear import Linear
14+ from torchtitan .protocols .module import Module , ModuleDict , ModuleList , Sequential
1315
1416
1517class TestModuleInitWeights (unittest .TestCase ):
@@ -36,7 +38,9 @@ def test_init_weights_implemented(self):
3638 class GoodModule (Module ):
3739 def __init__ (self ):
3840 super ().__init__ ()
39- self .linear = nn .Linear (4 , 4 )
41+ self .linear = Linear .Config (bias = True ).build (
42+ in_features = 4 , out_features = 4
43+ )
4044
4145 def init_weights (self , ** kwargs ):
4246 nn .init .zeros_ (self .linear .weight )
@@ -110,11 +114,13 @@ def __init__(self, num_embeddings, embedding_dim):
110114 def test_module_hierarchy_is_flat (self ):
111115 """Diamond embedding adds no extra layer to the module tree."""
112116
113- class Model (nn . Module ):
117+ class Model (Module ):
114118 def __init__ (self ):
115119 super ().__init__ ()
116120 self .embed = TestDiamondInheritance .TestEmbedding (100 , 32 )
117- self .linear = nn .Linear (32 , 16 )
121+ self .linear = Linear .Config (bias = True ).build (
122+ in_features = 32 , out_features = 16
123+ )
118124
119125 model = Model ()
120126 param_names = {name for name , _ in model .named_parameters ()}
@@ -138,5 +144,172 @@ def counting_init(self, *args, **kwargs):
138144 nn .Module .__init__ = orig_init
139145
140146
class TestFromNnModule(unittest.TestCase):
    """Tests for the Module.from_nn_module utility."""

    def test_is_subclass(self):
        """Created class is a subclass of both the original and Module."""
        Conv2d = Module.from_nn_module(nn.Conv2d)
        self.assertTrue(issubclass(Conv2d, nn.Conv2d))
        self.assertTrue(issubclass(Conv2d, Module))

    def test_isinstance(self):
        """Instance satisfies isinstance checks for both the original and Module."""
        Conv2d = Module.from_nn_module(nn.Conv2d)
        m = Conv2d(3, 16, 3)
        self.assertIsInstance(m, nn.Conv2d)
        self.assertIsInstance(m, Module)

    def test_init_weights_calls_reset_parameters(self):
        """For classes with reset_parameters, init_weights delegates to it."""
        LayerNorm = Module.from_nn_module(nn.LayerNorm)
        m = LayerNorm(32)
        # Manually zero the weight; init_weights should restore the default init.
        nn.init.zeros_(m.weight)
        m.init_weights()
        # LayerNorm's reset_parameters re-initializes the weight to ones.
        self.assertTrue(torch.allclose(m.weight, torch.ones(32)))

    def test_init_weights_noop_for_parameterless(self):
        """For classes without reset_parameters, init_weights is a no-op."""
        GELU = Module.from_nn_module(nn.GELU)
        m = GELU()
        m.init_weights()  # should not raise

    def test_cache(self):
        """Repeated calls return the same class object."""
        cls1 = Module.from_nn_module(nn.Conv2d)
        cls2 = Module.from_nn_module(nn.Conv2d)
        self.assertIs(cls1, cls2)

    def test_forward_unchanged(self):
        """Forward output is identical to the original class."""
        LayerNorm = Module.from_nn_module(nn.LayerNorm)
        torch.manual_seed(42)
        orig = nn.LayerNorm(16)
        wrapped = LayerNorm(16)
        # Copy weights so both modules compute with identical parameters.
        wrapped.load_state_dict(orig.state_dict())
        x = torch.randn(2, 16)
        torch.testing.assert_close(orig(x), wrapped(x))

    def test_state_dict_unchanged(self):
        """state_dict keys and values match the original class."""
        Conv2d = Module.from_nn_module(nn.Conv2d)
        orig = nn.Conv2d(3, 16, 3)
        wrapped = Conv2d(3, 16, 3)
        wrapped.load_state_dict(orig.state_dict())
        # Snapshot both state dicts once: state_dict() rebuilds its dict on
        # every call, so hoisting it out of the loop avoids redundant work.
        orig_sd = orig.state_dict()
        wrapped_sd = wrapped.state_dict()
        for key in orig_sd:
            self.assertIn(key, wrapped_sd)
            torch.testing.assert_close(orig_sd[key], wrapped_sd[key])
208+
class TestContainerInitWeights(unittest.TestCase):
    """Tests for ModuleList, ModuleDict, Sequential init_weights."""

    def test_module_list_init_weights(self):
        """ModuleList.init_weights calls init_weights on each child."""
        LayerNorm = Module.from_nn_module(nn.LayerNorm)
        norms = ModuleList([LayerNorm(8) for _ in range(3)])
        # Zero every child's weight, then let the container re-initialize.
        for norm in norms:
            nn.init.zeros_(norm.weight)
        norms.init_weights()
        for norm in norms:
            # LayerNorm's default init restores the weight to ones.
            self.assertTrue(torch.allclose(norm.weight, torch.ones(8)))

    def test_module_dict_init_weights(self):
        """ModuleDict.init_weights calls init_weights on each child."""
        LayerNorm = Module.from_nn_module(nn.LayerNorm)
        norms = ModuleDict({"a": LayerNorm(8), "b": LayerNorm(8)})
        # Zero every child's weight, then let the container re-initialize.
        for norm in norms.values():
            nn.init.zeros_(norm.weight)
        norms.init_weights()
        for norm in norms.values():
            self.assertTrue(torch.allclose(norm.weight, torch.ones(8)))

    def test_sequential_init_weights(self):
        """Sequential.init_weights calls init_weights on each child."""
        linear = Linear.Config(bias=False).build(in_features=4, out_features=4)
        GELU = Module.from_nn_module(nn.GELU)
        seq = Sequential(linear, GELU())
        seq.init_weights()  # should not raise

    def test_containers_are_module(self):
        """Container instances satisfy the Module protocol."""
        self.assertIsInstance(ModuleList(), Module)
        self.assertIsInstance(ModuleDict(), Module)
        self.assertIsInstance(Sequential(), Module)
245+
class TestVerifyModuleProtocol(unittest.TestCase):
    """Tests for BaseModel.verify_module_protocol."""

    def test_passes_for_all_module(self):
        """No error when all submodules are Module instances."""
        from torchtitan.protocols.model import BaseModel

        class GoodModel(BaseModel):
            @dataclass(kw_only=True, slots=True)
            class Config(BaseModel.Config):
                def update_from_config(self, *, trainer_config, **kwargs):
                    pass

            def get_nparams_and_flops(self, model, seq_len):
                return (0, 0)

            def __init__(self):
                super().__init__()
                # Built via the project's Linear config, so the child is a Module.
                self.linear = Linear.Config().build(in_features=4, out_features=4)

        GoodModel().verify_module_protocol()  # should not raise

    def test_default_raises_for_plain_nn_module(self):
        """Default verify_module_protocol raises when a plain nn.Module child exists."""
        from torchtitan.protocols.model import BaseModel

        class BadModel(BaseModel):
            @dataclass(kw_only=True, slots=True)
            class Config(BaseModel.Config):
                def update_from_config(self, *, trainer_config, **kwargs):
                    pass

            def get_nparams_and_flops(self, model, seq_len):
                return (0, 0)

            def __init__(self):
                super().__init__()
                # Plain nn.Linear does not satisfy the Module protocol.
                self.plain = nn.Linear(4, 4)

        with self.assertRaises(RuntimeError):
            BadModel().verify_module_protocol()

    def test_override_skips_verification(self):
        """Subclass can override verify_module_protocol to skip verification."""
        from torchtitan.protocols.model import BaseModel

        class ThirdPartyModel(BaseModel):
            @dataclass(kw_only=True, slots=True)
            class Config(BaseModel.Config):
                def update_from_config(self, *, trainer_config, **kwargs):
                    pass

            def get_nparams_and_flops(self, model, seq_len):
                return (0, 0)

            def __init__(self):
                super().__init__()
                self.plain = nn.Linear(4, 4)  # third-party module

            def verify_module_protocol(self) -> None:
                pass  # skip for third-party internals

        ThirdPartyModel().verify_module_protocol()  # should not raise
313+
if __name__ == "__main__":
    # Run the test suite when executed directly.
    unittest.main()