Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/mlx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ def tree_map(
return fn(tree, *rest)
elif isinstance(tree, (list, tuple)):
TreeType = type(tree)
return TreeType(
subtrees = [
tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
for i, child in enumerate(tree)
)
]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be a generator, no need to actually expand a new list here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, sent the changes.

return TreeType(*subtrees) if hasattr(tree, "_fields") else TreeType(subtrees)
elif isinstance(tree, dict):
return {
k: tree_map(fn, child, *(r[k] for r in rest), is_leaf=is_leaf)
Expand Down
12 changes: 10 additions & 2 deletions python/src/trees.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ nb::object tree_map(
int len = nb::cast<nb::tuple>(subtrees[0]).size();
nb::list l;
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
auto type = subtrees[0].type();
for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (nb::isinstance<nb::tuple>(subtrees[j])) {
Expand All @@ -51,7 +52,10 @@ nb::object tree_map(
}
l.append(recurse(items));
}
return nb::cast<nb::object>(nb::tuple(l));
if (PyTuple_CheckExact(subtrees[0].ptr())) {
return nb::cast<nb::object>(nb::tuple(l));
}
return nb::hasattr(type, "_fields") ? type(*l) : type(l);
} else if (nb::isinstance<nb::dict>(subtrees[0])) {
std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
Expand Down Expand Up @@ -178,11 +182,15 @@ void tree_visit_update(
}
return nb::cast<nb::object>(l);
} else if (nb::isinstance<nb::tuple>(subtree)) {
auto type = subtree.type();
nb::list l(subtree);
for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]);
}
return nb::cast<nb::object>(nb::tuple(l));
if (PyTuple_CheckExact(subtree.ptr())) {
return nb::cast<nb::object>(nb::tuple(l));
}
return nb::hasattr(type, "_fields") ? type(*l) : type(l);
} else if (nb::isinstance<nb::dict>(subtree)) {
auto d = nb::cast<nb::dict>(subtree);
for (auto item : d) {
Expand Down
49 changes: 49 additions & 0 deletions python/tests/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,55 @@ def loss_fn(model):
grad_fn(model)
self.assertEqual(model[1].item(), 2.0)

def test_autograd_types(self):
from typing import NamedTuple

class Vector(tuple):
pass

class State(NamedTuple):
a: mx.array
b: mx.array

def transform(x: State):
return State(x.a + 10, x.b * 10)

def transform_tuple(t):
return (t[0] + 10, t[1] * 10)

def transform_vector(t):
return Vector([t[0] + 10, t[1] * 10])

def loss_fn(x):
out = transform(x)
return out.a.sum() + out.b.sum()

def loss_fn_tuple(x):
out = transform_tuple(x)
return out[0].sum() + out[1].sum()

def loss_fn_vector(x):
out = transform_vector(x)
return out[0].sum() + out[1].sum()

x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
grads = mx.grad(loss_fn)(x_batch)
self.assertTrue(isinstance(grads, State))
self.assertTrue(mx.array_equal(grads.a, mx.ones(3)))
self.assertTrue(mx.array_equal(grads.b, mx.ones(3) * 10))

x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
grads = mx.grad(loss_fn_tuple)(x_batch_tuple)
self.assertTrue(isinstance(grads, tuple))
self.assertTrue(mx.array_equal(grads[0], mx.ones(3)))
self.assertTrue(mx.array_equal(grads[1], mx.ones(3) * 10))

x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
grads = mx.grad(loss_fn_vector)(x_batch_vector)
self.assertTrue(isinstance(grads, Vector))
self.assertTrue(mx.array_equal(grads[0], mx.ones(3)))
self.assertTrue(mx.array_equal(grads[1], mx.ones(3) * 10))


if __name__ == "__main__":
mlx_tests.MLXTestRunner()
44 changes: 44 additions & 0 deletions python/tests/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,50 @@ def fun(do_compile):
expected = fun(False)
self.assertTrue(mx.allclose(out, expected))

def test_compile_types(self):
from typing import NamedTuple

class Vector(tuple):
pass

class State(NamedTuple):
a: mx.array
b: mx.array

def transform(x: State):
return State(x.a + 10, x.b * 10)

def transform_tuple(t):
return (t[0] + 10, t[1] * 10)

def transform_vector(t):
return Vector([t[0] + 10, t[1] * 10])

x = State(mx.array(1), mx.array(2))

compiled_transform = mx.compile(transform)
compiled_transform_tuple = mx.compile(transform_tuple)
compiled_transform_vector = mx.compile(transform_vector)

x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
out1 = compiled_transform_tuple(x_batch_tuple)

self.assertTrue(isinstance(out1, tuple))
self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60])))

x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
out2 = compiled_transform(x_batch)
self.assertTrue(isinstance(out2, State))
self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60])))

x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
out3 = compiled_transform_vector(x_batch_vector)
self.assertTrue(isinstance(out3, Vector))
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))


if __name__ == "__main__":
mlx_tests.MLXTestRunner()
45 changes: 45 additions & 0 deletions python/tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,51 @@ def test_merge(self):
self.assertEqual(k1, k2)
self.assertTrue(mx.array_equal(v1, v2))

def test_supported_trees(self):

from typing import NamedTuple

class Vector(tuple):
pass

class Params(NamedTuple):
m: mx.array
b: mx.array

list1 = [mx.array([0, 1]), mx.array(2)]
tuple1 = (mx.array([0, 1]), mx.array(2))
vector1 = Vector([mx.array([0, 1]), mx.array(2)])
params1 = Params(m=mx.array([0, 1]), b=mx.array(2))
dict1 = {"m": mx.array([0, 1]), "b": mx.array(2)}

add_one = lambda x: x + 1

list2 = mlx.utils.tree_map(add_one, list1)
tuple2 = mlx.utils.tree_map(add_one, tuple1)
vector2 = mlx.utils.tree_map(add_one, vector1)
params2 = mlx.utils.tree_map(add_one, params1)
dict2 = mlx.utils.tree_map(add_one, dict1)

self.assertTrue(isinstance(list2, list))
self.assertTrue(mx.array_equal(list2[0], mx.array([1, 2])))
self.assertTrue(mx.array_equal(list2[1], mx.array(3)))

self.assertTrue(isinstance(tuple2, tuple))
self.assertTrue(mx.array_equal(tuple2[0], mx.array([1, 2])))
self.assertTrue(mx.array_equal(tuple2[1], mx.array(3)))

self.assertTrue(isinstance(vector2, Vector))
self.assertTrue(mx.array_equal(vector2[0], mx.array([1, 2])))
self.assertTrue(mx.array_equal(vector2[1], mx.array(3)))

self.assertTrue(isinstance(dict2, dict))
self.assertTrue(mx.array_equal(dict2["m"], mx.array([1, 2])))
self.assertTrue(mx.array_equal(dict2["b"], mx.array(3)))

self.assertTrue(isinstance(params2, Params))
self.assertTrue(mx.array_equal(params2.m, mx.array([1, 2])))
self.assertTrue(mx.array_equal(params2.b, mx.array(3)))


if __name__ == "__main__":
mlx_tests.MLXTestRunner()
46 changes: 46 additions & 0 deletions python/tests/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,52 @@ def gconv(x, w):
out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
self.assertTrue(mx.allclose(expected, out))

def test_vmap_types(self):

from typing import NamedTuple

class Vector(tuple):
pass

class State(NamedTuple):
a: mx.array
b: mx.array

def transform(x: State):
return State(x.a + 10, x.b * 10)

def transform_tuple(t):
return (t[0] + 10, t[1] * 10)

def transform_vector(t):
return Vector([t[0] + 10, t[1] * 10])

x = State(mx.array(1), mx.array(2))
print(f"{transform(x)=}")

vmap_transform = mx.vmap(transform)
vmap_transform_tuple = mx.vmap(transform_tuple)
vmap_transform_vector = mx.vmap(transform_vector)

x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))
out1 = vmap_transform_tuple(x_batch_tuple)

self.assertTrue(isinstance(out1, tuple))
self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60])))

x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))
out2 = vmap_transform(x_batch)
self.assertTrue(isinstance(out2, State))
self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60])))

x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])
out3 = vmap_transform_vector(x_batch_vector)
self.assertTrue(isinstance(out3, Vector))
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))

def test_vmap_masked_scatter(self):
def scatter_fn(x, m, src):
x[m] = src
Expand Down