@@ -723,6 +723,52 @@ def gconv(x, w):
723723 out = mx .vmap (gconv , in_axes = (0 , 0 ))(x , w )
724724 self .assertTrue (mx .allclose (expected , out ))
725725
726+ def test_vmap_types (self ):
727+
728+ from typing import NamedTuple
729+
730+ class Vector (tuple ):
731+ pass
732+
733+ class State (NamedTuple ):
734+ a : mx .array
735+ b : mx .array
736+
737+ def transform (x : State ):
738+ return State (x .a + 10 , x .b * 10 )
739+
740+ def transform_tuple (t ):
741+ return (t [0 ] + 10 , t [1 ] * 10 )
742+
743+ def transform_vector (t ):
744+ return Vector ([t [0 ] + 10 , t [1 ] * 10 ])
745+
746+ x = State (mx .array (1 ), mx .array (2 ))
747+ print (f"{ transform (x )= } " )
748+
749+ vmap_transform = mx .vmap (transform )
750+ vmap_transform_tuple = mx .vmap (transform_tuple )
751+ vmap_transform_vector = mx .vmap (transform_vector )
752+
753+ x_batch_tuple = (mx .array ([1 , 2 , 3 ]), mx .array ([4 , 5 , 6 ]))
754+ out1 = vmap_transform_tuple (x_batch_tuple )
755+
756+ self .assertTrue (isinstance (out1 , tuple ))
757+ self .assertTrue (mx .array_equal (out1 [0 ], mx .array ([11 , 12 , 13 ])))
758+ self .assertTrue (mx .array_equal (out1 [1 ], mx .array ([40 , 50 , 60 ])))
759+
760+ x_batch = State (mx .array ([1 , 2 , 3 ]), mx .array ([4 , 5 , 6 ]))
761+ out2 = vmap_transform (x_batch )
762+ self .assertTrue (isinstance (out2 , State ))
763+ self .assertTrue (mx .array_equal (out2 .a , mx .array ([11 , 12 , 13 ])))
764+ self .assertTrue (mx .array_equal (out2 .b , mx .array ([40 , 50 , 60 ])))
765+
766+ x_batch_vector = Vector ([mx .array ([1 , 2 , 3 ]), mx .array ([4 , 5 , 6 ])])
767+ out3 = vmap_transform_vector (x_batch_vector )
768+ self .assertTrue (isinstance (out3 , Vector ))
769+ self .assertTrue (mx .array_equal (out3 [0 ], mx .array ([11 , 12 , 13 ])))
770+ self .assertTrue (mx .array_equal (out3 [1 ], mx .array ([40 , 50 , 60 ])))
771+
726772 def test_vmap_masked_scatter (self ):
727773 def scatter_fn (x , m , src ):
728774 x [m ] = src
0 commit comments