@@ -719,3 +719,46 @@ def add_axis(self, x, new_position):
719719
720720 def einsum (self , pattern , * x ):
721721 return self .pt .einsum (pattern , * x )
722+
723+
724+ class MLXBackend (AbstractBackend ):
725+ framework_name = "mlx"
726+
727+ def __init__ (self ):
728+ import mlx .core as mx
729+ import numpy as np
730+
731+ self .mx = mx
732+ self .np = np
733+
734+ def is_appropriate_type (self , tensor ):
735+ return isinstance (tensor , self .mx .array )
736+
737+ def from_numpy (self , x ):
738+ return self .mx .array (x )
739+
740+ def to_numpy (self , x ):
741+ if x .dtype == self .mx .bfloat16 :
742+ x = x .astype (self .mx .float32 )
743+ return self .np .array (x )
744+
745+ def arange (self , start , stop ):
746+ return self .mx .arange (start , stop )
747+
748+ def stack_on_zeroth_dimension (self , tensors : list ):
749+ return self .mx .stack (tensors )
750+
751+ def add_axes (self , x , new_position ):
752+ return self .mx .expand_dims (x , new_position )
753+
754+ def tile (self , x , repeats ):
755+ return self .mx .tile (x , repeats )
756+
757+ def concat (self , tensors , axis : int ):
758+ return self .mx .concatenate (tensors , axis = axis )
759+
760+ def is_float_type (self , x ):
761+ return self .mx .issubdtype (x .dtype , self .mx .floating )
762+
763+ def einsum (self , pattern , * x ):
764+ return self .mx .einsum (pattern , * x )
0 commit comments