Skip to content

Commit d2702a4

Browse files
gmin7 and Michelle DiMarco authored
Fix non-strict module update with extra weights (#3214)
Co-authored-by: Michelle DiMarco <m_dimarco@apple.com>
1 parent 6ac5280 commit d2702a4

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

python/mlx/nn/layers/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,13 @@ def apply(dst, parameters):
340340
raise ValueError(f'Module does not have parameter named "{k}".')
341341
elif isinstance(parameters, list):
342342
for i in range(len(parameters)):
343+
if i >= len(dst):
344+
if strict:
345+
raise ValueError(
346+
f"List index {i} is out of bounds for "
347+
f"destination of length {len(dst)}."
348+
)
349+
continue
343350
current_value = dst[i]
344351
new_value = parameters[i]
345352
if isinstance(current_value, mx.array):

python/tests/test_nn.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,23 @@ def test_load_from_weights(self):
170170
# Empty weights is ok if strict is false
171171
m.load_weights([], strict=False)
172172

173+
# Extra weights for non-existent layers are filtered when strict
174+
# is false. Flat keys like "extra.weight" are silently dropped by
175+
# Module.update, but nested indexed keys like "layers.1.weight"
176+
# cause an IndexError in tree_unflatten/update without filtering.
177+
m = nn.Sequential(nn.Linear(2, 2))
178+
m.load_weights(
179+
[
180+
("layers.0.weight", mx.ones((2, 2))),
181+
("layers.0.bias", mx.ones((2,))),
182+
("layers.1.weight", mx.ones((2, 2))),
183+
("layers.1.bias", mx.ones((2,))),
184+
],
185+
strict=False,
186+
)
187+
self.assertTrue(mx.array_equal(m.layers[0].weight, mx.ones((2, 2))))
188+
self.assertEqual(len(m.layers), 1)
189+
173190
def test_module_state(self):
174191
m = nn.Linear(10, 1)
175192
m.state["hello"] = "world"

0 commit comments

Comments
 (0)