File tree Expand file tree Collapse file tree 2 files changed +24
-0
lines changed
Expand file tree Collapse file tree 2 files changed +24
-0
lines changed Original file line number Diff line number Diff line change @@ -340,6 +340,13 @@ def apply(dst, parameters):
340340 raise ValueError (f'Module does not have parameter named "{ k } ".' )
341341 elif isinstance (parameters , list ):
342342 for i in range (len (parameters )):
343+ if i >= len (dst ):
344+ if strict :
345+ raise ValueError (
346+ f"List index { i } is out of bounds for "
347+ f"destination of length { len (dst )} ."
348+ )
349+ continue
343350 current_value = dst [i ]
344351 new_value = parameters [i ]
345352 if isinstance (current_value , mx .array ):
Original file line number Diff line number Diff line change @@ -170,6 +170,23 @@ def test_load_from_weights(self):
170170 # Empty weights is ok if strict is false
171171 m .load_weights ([], strict = False )
172172
173+ # Extra weights for non-existent layers are filtered when strict
174+ # is false. Flat keys like "extra.weight" are silently dropped by
175+ # Module.update, but nested indexed keys like "layers.1.weight"
176+ # cause an IndexError in tree_unflatten/update without filtering.
177+ m = nn .Sequential (nn .Linear (2 , 2 ))
178+ m .load_weights (
179+ [
180+ ("layers.0.weight" , mx .ones ((2 , 2 ))),
181+ ("layers.0.bias" , mx .ones ((2 ,))),
182+ ("layers.1.weight" , mx .ones ((2 , 2 ))),
183+ ("layers.1.bias" , mx .ones ((2 ,))),
184+ ],
185+ strict = False ,
186+ )
187+ self .assertTrue (mx .array_equal (m .layers [0 ].weight , mx .ones ((2 , 2 ))))
188+ self .assertEqual (len (m .layers ), 1 )
189+
173190 def test_module_state (self ):
174191 m = nn .Linear (10 , 1 )
175192 m .state ["hello" ] = "world"
You can’t perform that action at this time.
0 commit comments