Skip to content

Commit 3cdcbd5

Browse files
authored
add mlx backend (#376)
1 parent 5197457 commit 3cdcbd5

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed

einops/_backends.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,3 +719,46 @@ def add_axis(self, x, new_position):
719719

720720
def einsum(self, pattern, *x):
721721
return self.pt.einsum(pattern, *x)
722+
723+
724+
class MLXBackend(AbstractBackend):
725+
framework_name = "mlx"
726+
727+
def __init__(self):
728+
import mlx.core as mx
729+
import numpy as np
730+
731+
self.mx = mx
732+
self.np = np
733+
734+
def is_appropriate_type(self, tensor):
735+
return isinstance(tensor, self.mx.array)
736+
737+
def from_numpy(self, x):
738+
return self.mx.array(x)
739+
740+
def to_numpy(self, x):
741+
if x.dtype == self.mx.bfloat16:
742+
x = x.astype(self.mx.float32)
743+
return self.np.array(x)
744+
745+
def arange(self, start, stop):
746+
return self.mx.arange(start, stop)
747+
748+
def stack_on_zeroth_dimension(self, tensors: list):
749+
return self.mx.stack(tensors)
750+
751+
def add_axes(self, x, new_position):
752+
return self.mx.expand_dims(x, new_position)
753+
754+
def tile(self, x, repeats):
755+
return self.mx.tile(x, repeats)
756+
757+
def concat(self, tensors, axis: int):
758+
return self.mx.concatenate(tensors, axis=axis)
759+
760+
def is_float_type(self, x):
761+
return self.mx.issubdtype(x.dtype, self.mx.floating)
762+
763+
def einsum(self, pattern, *x):
764+
return self.mx.einsum(pattern, *x)

einops/tests/run_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def main():
3434
"paddle": ["paddlepaddle"],
3535
"oneflow": ["oneflow==0.9.0"],
3636
"pytensor": ["pytensor"],
37+
"mlx": ["mlx"],
3738
}
3839

3940
usage = f"""

einops/tests/test_einsum.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def test_layer():
191191
"tensorflow.keras",
192192
"paddle",
193193
"pytensor",
194+
"mlx",
194195
]
195196

196197

0 commit comments

Comments
 (0)