Skip to content

Commit 9abb0b8

Browse files
authored
Added support for pytree types that inherit from tuple and typing.NamedTuple (#2845)
1 parent 50d3914 commit 9abb0b8

File tree

6 files changed

+196
-3
lines changed

6 files changed

+196
-3
lines changed

python/mlx/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,11 @@ def tree_map(
4444
return fn(tree, *rest)
4545
elif isinstance(tree, (list, tuple)):
4646
TreeType = type(tree)
47-
return TreeType(
47+
subtrees = (
4848
tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
4949
for i, child in enumerate(tree)
5050
)
51+
return TreeType(*subtrees) if hasattr(tree, "_fields") else TreeType(subtrees)
5152
elif isinstance(tree, dict):
5253
return {
5354
k: tree_map(fn, child, *(r[k] for r in rest), is_leaf=is_leaf)

python/src/trees.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ nb::object tree_map(
4141
int len = nb::cast<nb::tuple>(subtrees[0]).size();
4242
nb::list l;
4343
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
44+
auto type = subtrees[0].type();
4445
for (int i = 0; i < len; ++i) {
4546
for (int j = 0; j < subtrees.size(); ++j) {
4647
if (nb::isinstance<nb::tuple>(subtrees[j])) {
@@ -51,7 +52,10 @@ nb::object tree_map(
5152
}
5253
l.append(recurse(items));
5354
}
54-
return nb::cast<nb::object>(nb::tuple(l));
55+
if (PyTuple_CheckExact(subtrees[0].ptr())) {
56+
return nb::cast<nb::object>(nb::tuple(l));
57+
}
58+
return nb::hasattr(type, "_fields") ? type(*l) : type(l);
5559
} else if (nb::isinstance<nb::dict>(subtrees[0])) {
5660
std::vector<nb::object> items(subtrees.size());
5761
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
@@ -178,11 +182,15 @@ void tree_visit_update(
178182
}
179183
return nb::cast<nb::object>(l);
180184
} else if (nb::isinstance<nb::tuple>(subtree)) {
185+
auto type = subtree.type();
181186
nb::list l(subtree);
182187
for (int i = 0; i < l.size(); ++i) {
183188
l[i] = recurse(l[i]);
184189
}
185-
return nb::cast<nb::object>(nb::tuple(l));
190+
if (PyTuple_CheckExact(subtree.ptr())) {
191+
return nb::cast<nb::object>(nb::tuple(l));
192+
}
193+
return nb::hasattr(type, "_fields") ? type(*l) : type(l);
186194
} else if (nb::isinstance<nb::dict>(subtree)) {
187195
auto d = nb::cast<nb::dict>(subtree);
188196
for (auto item : d) {

python/tests/test_autograd.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,55 @@ def loss_fn(model):
798798
grad_fn(model)
799799
self.assertEqual(model[1].item(), 2.0)
800800

801+
def test_autograd_types(self):
    """Check that ``mx.grad`` returns gradients in the input's container type.

    Three containers are exercised: a ``typing.NamedTuple`` (``State``), a
    plain ``tuple`` and a bare ``tuple`` subclass (``Vector``). For each,
    the gradient of ``sum(x0 + 10) + sum(x1 * 10)`` must be ones for the
    first element and tens for the second.
    """
    from typing import NamedTuple

    class Vector(tuple):
        pass

    class State(NamedTuple):
        a: mx.array
        b: mx.array

    def transform(x: State):
        return State(x.a + 10, x.b * 10)

    def transform_tuple(t):
        return (t[0] + 10, t[1] * 10)

    def transform_vector(t):
        return Vector([t[0] + 10, t[1] * 10])

    def loss_fn(x):
        out = transform(x)
        return out.a.sum() + out.b.sum()

    def loss_fn_tuple(x):
        out = transform_tuple(x)
        return out[0].sum() + out[1].sum()

    def loss_fn_vector(x):
        out = transform_vector(x)
        return out[0].sum() + out[1].sum()

    # NamedTuple input: gradients come back as the same NamedTuple type.
    x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
    grads = mx.grad(loss_fn)(x_batch)
    # assertIsInstance reports the offending type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(grads, State)
    self.assertTrue(mx.array_equal(grads.a, mx.ones(3)))
    self.assertTrue(mx.array_equal(grads.b, mx.ones(3) * 10))

    # Plain tuple input.
    x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
    grads = mx.grad(loss_fn_tuple)(x_batch_tuple)
    self.assertIsInstance(grads, tuple)
    self.assertTrue(mx.array_equal(grads[0], mx.ones(3)))
    self.assertTrue(mx.array_equal(grads[1], mx.ones(3) * 10))

    # Tuple-subclass input: the subclass must be preserved.
    x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
    grads = mx.grad(loss_fn_vector)(x_batch_vector)
    self.assertIsInstance(grads, Vector)
    self.assertTrue(mx.array_equal(grads[0], mx.ones(3)))
    self.assertTrue(mx.array_equal(grads[1], mx.ones(3) * 10))
849+
801850
def test_reduce_jvp(self):
802851
a = mx.arange(4)
803852
b = mx.array([3, 2, 1, 0])

python/tests/test_compile.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,6 +1179,50 @@ def fun(do_compile):
11791179
expected = fun(False)
11801180
self.assertTrue(mx.allclose(out, expected))
11811181

1182+
def test_compile_types(self):
    """Check that ``mx.compile`` preserves tuple-like container types.

    Compiled functions are called with a plain ``tuple``, a
    ``typing.NamedTuple`` (``State``) and a bare ``tuple`` subclass
    (``Vector``); each result must be an instance of the type the wrapped
    function returns and hold ``x0 + 10`` and ``x1 * 10``.
    """
    from typing import NamedTuple

    class Vector(tuple):
        pass

    class State(NamedTuple):
        a: mx.array
        b: mx.array

    def transform(x: State):
        return State(x.a + 10, x.b * 10)

    def transform_tuple(t):
        return (t[0] + 10, t[1] * 10)

    def transform_vector(t):
        return Vector([t[0] + 10, t[1] * 10])

    compiled_transform = mx.compile(transform)
    compiled_transform_tuple = mx.compile(transform_tuple)
    compiled_transform_vector = mx.compile(transform_vector)

    # Plain tuple.
    x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
    out1 = compiled_transform_tuple(x_batch_tuple)
    self.assertIsInstance(out1, tuple)
    self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13])))
    self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60])))

    # NamedTuple.
    x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
    out2 = compiled_transform(x_batch)
    self.assertIsInstance(out2, State)
    self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13])))
    self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60])))

    # Tuple subclass.
    x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
    out3 = compiled_transform_vector(x_batch_vector)
    self.assertIsInstance(out3, Vector)
    self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
    self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
1225+
11821226

11831227
if __name__ == "__main__":
11841228
mlx_tests.MLXTestRunner()

python/tests/test_tree.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,51 @@ def test_merge(self):
4646
self.assertEqual(k1, k2)
4747
self.assertTrue(mx.array_equal(v1, v2))
4848

49+
def test_supported_trees(self):
    """Check that ``mlx.utils.tree_map`` preserves each container type.

    Maps ``x + 1`` over a ``list``, a ``tuple``, a bare ``tuple`` subclass
    (``Vector``), a ``typing.NamedTuple`` (``Params``) and a ``dict``, then
    verifies both the container type and the mapped leaf values.
    """
    from typing import NamedTuple

    class Vector(tuple):
        pass

    class Params(NamedTuple):
        m: mx.array
        b: mx.array

    def add_one(x):
        return x + 1

    list1 = [mx.array([0, 1]), mx.array(2)]
    tuple1 = (mx.array([0, 1]), mx.array(2))
    vector1 = Vector([mx.array([0, 1]), mx.array(2)])
    params1 = Params(m=mx.array([0, 1]), b=mx.array(2))
    dict1 = {"m": mx.array([0, 1]), "b": mx.array(2)}

    list2 = mlx.utils.tree_map(add_one, list1)
    tuple2 = mlx.utils.tree_map(add_one, tuple1)
    vector2 = mlx.utils.tree_map(add_one, vector1)
    params2 = mlx.utils.tree_map(add_one, params1)
    dict2 = mlx.utils.tree_map(add_one, dict1)

    self.assertIsInstance(list2, list)
    self.assertTrue(mx.array_equal(list2[0], mx.array([1, 2])))
    self.assertTrue(mx.array_equal(list2[1], mx.array(3)))

    self.assertIsInstance(tuple2, tuple)
    self.assertTrue(mx.array_equal(tuple2[0], mx.array([1, 2])))
    self.assertTrue(mx.array_equal(tuple2[1], mx.array(3)))

    # A tuple subclass must come back as that subclass, not a plain tuple.
    self.assertIsInstance(vector2, Vector)
    self.assertTrue(mx.array_equal(vector2[0], mx.array([1, 2])))
    self.assertTrue(mx.array_equal(vector2[1], mx.array(3)))

    self.assertIsInstance(dict2, dict)
    self.assertTrue(mx.array_equal(dict2["m"], mx.array([1, 2])))
    self.assertTrue(mx.array_equal(dict2["b"], mx.array(3)))

    # NamedTuple fields stay addressable by name after mapping.
    self.assertIsInstance(params2, Params)
    self.assertTrue(mx.array_equal(params2.m, mx.array([1, 2])))
    self.assertTrue(mx.array_equal(params2.b, mx.array(3)))
93+
4994

5095
if __name__ == "__main__":
5196
mlx_tests.MLXTestRunner()

python/tests/test_vmap.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,52 @@ def gconv(x, w):
723723
out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
724724
self.assertTrue(mx.allclose(expected, out))
725725

726+
def test_vmap_types(self):
    """Check that ``mx.vmap`` preserves tuple-like container types.

    Vectorized functions are called with a plain ``tuple``, a
    ``typing.NamedTuple`` (``State``) and a bare ``tuple`` subclass
    (``Vector``); each result must be an instance of the type the wrapped
    function returns and hold ``x0 + 10`` and ``x1 * 10``.
    """
    from typing import NamedTuple

    class Vector(tuple):
        pass

    class State(NamedTuple):
        a: mx.array
        b: mx.array

    def transform(x: State):
        return State(x.a + 10, x.b * 10)

    def transform_tuple(t):
        return (t[0] + 10, t[1] * 10)

    def transform_vector(t):
        return Vector([t[0] + 10, t[1] * 10])

    vmap_transform = mx.vmap(transform)
    vmap_transform_tuple = mx.vmap(transform_tuple)
    vmap_transform_vector = mx.vmap(transform_vector)

    # Plain tuple.
    x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
    out1 = vmap_transform_tuple(x_batch_tuple)
    self.assertIsInstance(out1, tuple)
    self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13])))
    self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60])))

    # NamedTuple.
    x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
    out2 = vmap_transform(x_batch)
    self.assertIsInstance(out2, State)
    self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13])))
    self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60])))

    # Tuple subclass.
    x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
    out3 = vmap_transform_vector(x_batch_vector)
    self.assertIsInstance(out3, Vector)
    self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
    self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
771+
726772
def test_vmap_masked_scatter(self):
727773
def scatter_fn(x, m, src):
728774
x[m] = src

0 commit comments

Comments
 (0)