Skip to content

Commit 168f13d

Browse files
authored
mlx - patch for pytest to run (#20825)
* patch for pytest to run with mlx * formatting
1 parent 179ebeb commit 168f13d

File tree

5 files changed

+21
-0
lines changed

5 files changed

+21
-0
lines changed

keras/src/backend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
distribution_lib = None
5858
elif backend() == "mlx":
5959
from keras.src.backend.mlx import * # noqa: F403
60+
from keras.src.backend.mlx.core import Variable as BackendVariable
6061

6162
distribution_lib = None
6263
else:

keras/src/backend/mlx/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""MLX backend APIs."""
22

3+
from keras.src.backend.common.name_scope import name_scope
34
from keras.src.backend.mlx import core
45
from keras.src.backend.mlx import image
56
from keras.src.backend.mlx import linalg

keras/src/backend/mlx/export.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class MlxExportArchive:
2+
def track(self, resource):
3+
raise NotImplementedError(
4+
"`track` is not implemented in the mlx backend."
5+
)
6+
7+
def add_endpoint(self, name, fn, input_signature=None, **kwargs):
8+
raise NotImplementedError(
9+
"`add_endpoint` is not implemented in the mlx backend."
10+
)

keras/src/backend/mlx/nn.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,11 @@ def gelu_tanh_approx(x):
115115
return f(x)
116116

117117

118+
def celu(x, alpha=1.0):
119+
x = convert_to_tensor(x)
120+
return nn.celu(x, alpha=alpha)
121+
122+
118123
def softmax(x, axis=-1):
119124
x = convert_to_tensor(x)
120125
return mx.softmax(x, axis=axis)

keras/src/export/saved_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
from keras.src.backend.openvino.export import (
3030
OpenvinoExportArchive as BackendExportArchive,
3131
)
32+
elif backend.backend() == "mlx":
33+
from keras.src.backend.mlx.export import (
34+
MlxExportArchive as BackendExportArchive,
35+
)
3236
else:
3337
raise RuntimeError(
3438
f"Backend '{backend.backend()}' must implement a layer mixin class."

0 commit comments

Comments
 (0)