@@ -43,10 +43,10 @@ def sharded_state_dict(
4343
4444
4545class SimpleMLP (Layer ):
def __init__(self, in_features=1024, out_features=1024):
    """Build the single ``ColumnParallelLinear`` sublayer of this model.

    Args:
        in_features: First size argument forwarded to
            ``ColumnParallelLinear``. Defaults to 1024.
        out_features: Second size argument forwarded to
            ``ColumnParallelLinear``. Defaults to 1024. It is passed
            through as-is and is not derived from ``in_features``.
    """
    super().__init__()
    self.linear = ColumnParallelLinear(
        in_features, out_features, has_bias=True
    )
5151
5252 def forward (self , x ):
@@ -55,10 +55,10 @@ def forward(self, x):
5555
5656
5757class SimpleMLPTransWeight (Layer ):
def __init__(self, in_features=1024, out_features=1024):
    """Build the single ``ColumnParallelLinearTransWeight`` sublayer.

    Args:
        in_features: First size argument forwarded to
            ``ColumnParallelLinearTransWeight``. Defaults to 1024.
        out_features: Second size argument forwarded to
            ``ColumnParallelLinearTransWeight``. Defaults to 1024. It is
            passed through as-is and is not derived from ``in_features``.
    """
    super().__init__()
    self.linear = ColumnParallelLinearTransWeight(
        in_features, out_features, has_bias=True
    )
6363
6464 def forward (self , x ):
@@ -70,6 +70,8 @@ class TestLoadStateDictTransposeLogic:
def __init__(self):
    """Initialise the test's configuration attributes.

    Sets ``aoa_config`` to a dict whose ``"aoa_statements"`` key holds a
    one-element list containing the value of the ``aoa_statements``
    environment variable (``None`` when the variable is unset), picks a
    checkpoint path, and records the feature sizes 1024 and 2048.
    """
    self.aoa_config = {"aoa_statements": [os.getenv("aoa_statements")]}
    # NOTE(review): only the path string is stored; the TemporaryDirectory
    # object itself is not kept. Per the tempfile docs the directory is
    # removed when that object is destroyed, so this path presumably does
    # not exist by the time it is used -- confirm the save step creates it.
    self.ckpt_path = tempfile.TemporaryDirectory().name
    # Feature sizes for this test case; where they are consumed is not
    # visible in this chunk.
    self.in_features = 1024
    self.out_features = 2048
7375
7476 def run_test (self ):
7577 self .run_save_state_dict ()
@@ -99,5 +101,13 @@ def run_save_state_dict(self):
99101 dist .save_state_dict (sharded_state_dict , self .ckpt_path )
100102
101103
class TestLoadStateDictTransposeLogic2(TestLoadStateDictTransposeLogic):
    """Variant of the base test that uses equal feature sizes (1024, 1024)."""

    def __init__(self):
        """Run the base initialisation, then set both feature sizes to 1024."""
        super().__init__()
        # in_features is restated so the size pair reads as one unit here.
        self.in_features = 1024
        self.out_features = 1024
109+
110+
if __name__ == '__main__':
    # Run both test cases in sequence when the file is executed as a script.
    TestLoadStateDictTransposeLogic().run_test()
    TestLoadStateDictTransposeLogic2().run_test()
0 commit comments