Skip to content

Commit 47c032d

Browse files
authored
implement batch_normalization (#19543)
1 parent fb6244b commit 47c032d

File tree

5 files changed

+29
-7
lines changed

5 files changed

+29
-7
lines changed

keras/backend/exports.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
elif backend.backend() == "jax":
88
BackendVariable = backend.jax.core.Variable
99
backend_name_scope = backend.common.name_scope.name_scope
10+
elif backend.backend() == "mlx":
11+
BackendVariable = backend.mlx.core.Variable
12+
backend_name_scope = backend.common.name_scope.name_scope
1013
elif backend.backend() == "torch":
1114
BackendVariable = backend.torch.core.Variable
1215
backend_name_scope = backend.common.name_scope.name_scope

keras/backend/mlx/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import mlx.core as mx
22
import numpy as np
3-
import tree
3+
from keras.utils import tree
44

55
from keras.backend.common import KerasVariable
66
from keras.backend.common import standardize_dtype
77
from keras.backend.common.keras_tensor import KerasTensor
88
from keras.backend.common.stateless_scope import StatelessScope
9-
from keras.utils.nest import pack_sequence_as
9+
from keras.utils.tree import pack_sequence_as
1010

1111
SUPPORTS_SPARSE_TENSORS = False
1212

keras/backend/mlx/nn.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,22 @@ def moments(x, axes, keepdims=False, synchronized=False):
304304
def batch_normalization(
    x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3
):
    """Normalize `x` with the given moments along `axis`.

    Computes `scale * (x - mean) / sqrt(variance + epsilon) + offset`,
    folded into a single multiply-add: `x * inv_std + shift`, where
    `inv_std` absorbs `scale` and `shift` absorbs `-mean * inv_std`
    and `offset`.

    Args:
        x: Input tensor.
        mean: 1-D mean tensor whose length matches `x.shape[axis]`.
        variance: 1-D variance tensor, same length as `mean`.
        axis: Axis of `x` that `mean`/`variance` correspond to.
        offset: Optional 1-D shift (beta); `None` means no shift.
        scale: Optional 1-D scale (gamma); `None` means no scaling.
        epsilon: Small constant added to `variance` for stability.

    Returns:
        The normalized tensor, same shape as `x`.
    """
    # Broadcast shape: 1 on every axis except the normalization axis.
    broadcast_shape = [1] * len(x.shape)
    broadcast_shape[axis] = mean.shape[0]

    mean = mx.reshape(mean, broadcast_shape)
    variance = mx.reshape(variance, broadcast_shape)

    # 1 / sqrt(var + eps), with gamma folded in when present.
    inv_std = mx.rsqrt(variance + epsilon)
    if scale is not None:
        inv_std = inv_std * mx.reshape(scale, broadcast_shape)

    # Constant additive term: -mean * inv_std (+ beta when present).
    shift = -mean * inv_std
    if offset is not None:
        shift = shift + mx.reshape(offset, broadcast_shape)

    return mx.add(x * inv_std, shift)
310323

311324

312325
def ctc_loss(

keras/backend/mlx/numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,12 @@ def divide(x1, x2):
900900
return mx.divide(x1, x2)
901901

902902

903+
def divide_no_nan(x1, x2):
    """Element-wise `x1 / x2`, returning 0 wherever `x2 == 0`.

    Args:
        x1: Numerator (tensor or convertible value).
        x2: Denominator (tensor or convertible value).

    Returns:
        The quotient, with positions where the denominator is zero
        replaced by 0 instead of inf/nan.
    """
    x1 = convert_to_tensor(x1)
    x2 = convert_to_tensor(x2)
    quotient = mx.divide(x1, x2)
    # Mask out zero-denominator positions rather than propagating
    # the inf/nan produced by the raw division.
    return mx.where(x2 == 0, 0, quotient)
907+
908+
903909
def true_divide(x1, x2):
    """Alias for `divide` (NumPy-compatible true division)."""
    return divide(x1, x2)
905911

keras/backend/mlx/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import mlx.core as mx
22
import numpy as np
3-
import tree
3+
from keras.utils import tree
44

55
from keras import backend
66
from keras import callbacks as callbacks_module
@@ -141,7 +141,7 @@ def compute_loss_and_updates(
141141
# Note that this is needed for the regularization loss, which need
142142
# the latest value of train/non-trainable variables.
143143
loss = self.compute_loss(
144-
x, y, y_pred, sample_weight, allow_empty=True
144+
x, y, y_pred, sample_weight
145145
)
146146
if losses:
147147
loss += ops.sum(losses)

0 commit comments

Comments (0)