
Commit aad52f1

mlx - numpy.searchsorted and numpy.histogram implemented (#20927)

* numpy.searchsorted and numpy.histogram
* clean comments
1 parent 45e5771 commit aad52f1

File tree

.github/workflows/actions.yml
integration_tests/mlx_custom_fit_test.py
keras/src/backend/common/dtypes_test.py
keras/src/backend/mlx/core.py
keras/src/backend/mlx/numpy.py
keras/src/backend/mlx/trainer.py

6 files changed: +224 -8 lines changed

.github/workflows/actions.yml

Lines changed: 4 additions & 0 deletions
@@ -78,6 +78,10 @@ jobs:
       if: ${{ matrix.backend == 'jax'}}
       run: |
         python integration_tests/jax_custom_fit_test.py
+    - name: Test MLX-specific integrations
+      if: ${{ matrix.backend == 'mlx'}}
+      run: |
+        python integration_tests/mlx_custom_fit_test.py
     - name: Test TF-specific integrations
       if: ${{ matrix.backend == 'tensorflow'}}
       run: |
integration_tests/mlx_custom_fit_test.py

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
+import mlx.core as mx
+import numpy as np
+
+import keras
+
+
+def test_custom_fit():
+    class CustomModel(keras.Model):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.loss_tracker = keras.metrics.Mean(name="loss")
+            self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
+            self.loss_fn = keras.losses.MeanSquaredError()
+
+        def compute_loss_and_updates(
+            self,
+            trainable_variables,
+            non_trainable_variables,
+            x,
+            y,
+            training=False,
+        ):
+            y_pred, non_trainable_variables = self.stateless_call(
+                trainable_variables,
+                non_trainable_variables,
+                x,
+                training=training,
+            )
+            loss = self.loss_fn(y, y_pred)
+            return loss, (y_pred, non_trainable_variables)
+
+        def train_step(self, state, data):
+            (
+                trainable_variables,
+                non_trainable_variables,
+                optimizer_variables,
+                metrics_variables,
+            ) = state
+            x, y = data
+            grad_fn = mx.value_and_grad(self.compute_loss_and_updates)
+            (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
+                trainable_variables,
+                non_trainable_variables,
+                x,
+                y,
+                training=True,
+            )
+            (
+                trainable_variables,
+                optimizer_variables,
+            ) = self.optimizer.stateless_apply(
+                optimizer_variables, grads, trainable_variables
+            )
+            loss_tracker_vars = metrics_variables[
+                : len(self.loss_tracker.variables)
+            ]
+            mae_metric_vars = metrics_variables[
+                len(self.loss_tracker.variables) :
+            ]
+            loss_tracker_vars = self.loss_tracker.stateless_update_state(
+                loss_tracker_vars, loss
+            )
+            mae_metric_vars = self.mae_metric.stateless_update_state(
+                mae_metric_vars, y, y_pred
+            )
+            logs = {}
+            logs[self.loss_tracker.name] = self.loss_tracker.stateless_result(
+                loss_tracker_vars
+            )
+            logs[self.mae_metric.name] = self.mae_metric.stateless_result(
+                mae_metric_vars
+            )
+            new_metrics_vars = loss_tracker_vars + mae_metric_vars
+            state = (
+                trainable_variables,
+                non_trainable_variables,
+                optimizer_variables,
+                new_metrics_vars,
+            )
+            return logs, state
+
+        @property
+        def metrics(self):
+            return [self.loss_tracker, self.mae_metric]
+
+    inputs = keras.Input(shape=(32,))
+    outputs = keras.layers.Dense(1)(inputs)
+    model = CustomModel(inputs, outputs)
+    model.compile(optimizer="adam")
+    x = np.random.random((64, 32))
+    y = np.random.random((64, 1))
+    history = model.fit(x, y, epochs=1)
+
+    assert "loss" in history.history
+    assert "mae" in history.history
+
+    print("History:")
+    print(history.history)
+
+
+if __name__ == "__main__":
+    test_custom_fit()
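A note on running this test locally (not part of the diff): the backend must be selected before keras is first imported. A minimal sketch, assuming a machine with mlx installed:

import os

os.environ["KERAS_BACKEND"] = "mlx"  # must be set before importing keras

import keras

assert keras.backend.backend() == "mlx"
# Then run the same command as the CI step above:
#   python integration_tests/mlx_custom_fit_test.py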

keras/src/backend/common/dtypes_test.py

Lines changed: 6 additions & 0 deletions
@@ -29,6 +29,12 @@ class DtypesTest(test_case.TestCase):
             for x in dtypes.ALLOWED_DTYPES
             if x not in ["string", "complex64", "complex128"]
         ] + [None]
+    elif backend.backend() == "mlx":
+        ALL_DTYPES = [
+            x
+            for x in dtypes.ALLOWED_DTYPES
+            if x not in ["string", "complex128"]
+        ] + [None]
     else:
         ALL_DTYPES = [x for x in dtypes.ALLOWED_DTYPES if x != "string"] + [
             None

keras/src/backend/mlx/core.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
 MLX_DTYPES = {
     "float16": mx.float16,
     "float32": mx.float32,
-    "float64": None,  # mlx does not support float64
+    "float64": None,  # mlx only supports float64 on cpu
     "uint8": mx.uint8,
     "uint16": mx.uint16,
     "uint32": mx.uint32,

keras/src/backend/mlx/numpy.py

Lines changed: 105 additions & 5 deletions
@@ -1,4 +1,5 @@
 import builtins
+import math
 from copy import copy as builtin_copy

 import mlx.core as mx

@@ -950,7 +951,6 @@ def quantile(x, q, axis=None, method="linear", keepdims=False):
     else:
         dtype = dtypes.result_type(x.dtype, float)
     mlx_dtype = to_mlx_dtype(dtype)
-    print("mlx_dtype", mlx_dtype)

     # problem casting mlx bfloat16 array to numpy
     if ori_dtype == "bfloat16":
@@ -1374,8 +1374,43 @@ def wrapped(*args):
     return wrapped


-def histogram(x, bins, range):
-    raise NotImplementedError("histogram not yet implemented in mlx.")
+def histogram_bin_edges(a, bins=10, range=None):
+    # Ref: jax.numpy.histogram
+    # infer range if None
+    if range is None:
+        range = (mx.min(a).item(), mx.max(a).item())
+
+    if range[0] == range[1]:
+        range = (range[0] - 0.5, range[1] + 0.5)
+
+    bin_edges = mx.linspace(range[0], range[1], bins + 1, dtype=mx.float32)
+    # due to the way mlx currently handles linspace
+    # with fp32 precision it is not always right edge inclusive
+    # manually set the right edge for now
+    bin_edges[-1] = range[-1]
+    return bin_edges
+
+
+def histogram(x, bins=10, range=None):
+    # Ref: jax.numpy.histogram
+    x = convert_to_tensor(x)
+    if range is not None:
+        if not isinstance(range, tuple) or len(range) != 2:
+            raise ValueError(
+                "Invalid value for argument `range`. Only `None` or "
+                "a tuple of the lower and upper range of bins is supported. "
+                f"Received: range={range}"
+            )
+
+    bin_edges = histogram_bin_edges(x, bins, range)
+
+    bin_idx = searchsorted(bin_edges, x, side="right")
+    bin_idx = mx.where(x == bin_edges[-1], len(bin_edges) - 1, bin_idx)
+
+    counts = mx.zeros(len(bin_edges))
+    counts = counts.at[bin_idx].add(mx.ones_like(x))
+
+    return counts[1:], bin_edges


 def unravel_index(x, shape):
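As a hedged sanity check (not part of the commit), the new histogram should agree with NumPy on data that falls inside the requested range, since counts[1:] drops the synthetic left-of-first-edge bucket; this assumes the mlx backend is installed and selected:

import os

os.environ["KERAS_BACKEND"] = "mlx"

import numpy as np

import keras

data = np.random.uniform(0.0, 10.0, size=(1000,)).astype("float32")
counts, edges = keras.ops.histogram(data, bins=10, range=(0.0, 10.0))
ref_counts, ref_edges = np.histogram(data, bins=10, range=(0.0, 10.0))

np.testing.assert_allclose(np.array(counts), ref_counts)
np.testing.assert_allclose(np.array(edges), ref_edges, atol=1e-5)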
@@ -1384,7 +1419,7 @@ def unravel_index(x, shape):

     if None in shape:
         raise ValueError(
-            "`shape` argument cannot contain `None`. Received: shape={shape}"
+            f"`shape` argument cannot contain `None`. Received: shape={shape}"
         )

     if x.ndim == 1:
@@ -1403,8 +1438,73 @@ def unravel_index(x, shape):
     return tuple(reversed(coords))


+def searchsorted_binary(a, b, side="left"):
+    original_shape = b.shape
+    b_flat = b.reshape(-1)
+
+    size = a.shape[0]
+    steps = math.ceil(math.log2(size))
+    indices = mx.full(b_flat.shape, vals=size // 2, dtype=mx.uint32)
+
+    comparison = (
+        (lambda x, y: x <= y) if side == "left" else (lambda x, y: x < y)
+    )
+
+    upper = size
+    lower = 0
+    for _ in range(steps):
+        comp = comparison(b_flat, a[indices])
+        new_indices = mx.where(
+            comp, (lower + indices) // 2, (indices + upper) // 2
+        )
+        lower = mx.where(comp, lower, indices)
+        upper = mx.where(comp, indices, upper)
+        indices = new_indices
+
+    result = mx.where(comparison(b_flat, a[indices]), indices, indices + 1)
+    return result.reshape(original_shape)
+
+
+def searchsorted_linear(a, b, side="left"):
+    original_shape = b.shape
+    b_flat = b.reshape(-1)
+    b_flat_broadcast = b_flat.reshape(-1, 1)
+    if side == "left":
+        result = (a[None, :] < b_flat_broadcast).sum(axis=1)
+    else:
+        result = (a[None, :] <= b_flat_broadcast).sum(axis=1)
+
+    return result.reshape(original_shape)
+
+
 def searchsorted(sorted_sequence, values, side="left"):
-    raise NotImplementedError("searchsorted not yet implemented in mlx.")
+    if side not in ("left", "right"):
+        raise ValueError(f"Invalid side `{side}`, must be `left` or `right`.")
+    sorted_sequence = convert_to_tensor(sorted_sequence)
+    values = convert_to_tensor(values)
+    if sorted_sequence.ndim != 1:
+        raise ValueError(
+            "Invalid sorted_sequence, should be 1-dimensional. "
+            f"Received sorted_sequence.shape={sorted_sequence.shape}"
+        )
+    if values.ndim == 0:
+        raise ValueError(
+            "Invalid values, should be N-dimensional. Received "
+            f"scalar array values.shape={values.shape}"
+        )
+
+    sorted_size = sorted_sequence.size
+    search_size = values.size
+
+    # TODO: swap to an mlx implementation if one exists in the future
+    # current implementation and search choice based on discussion:
+    # https://github.com/ml-explore/mlx/issues/1255
+    use_linear = sorted_size <= 1024 or (
+        sorted_size <= 16384 and search_size <= 256
+    )
+
+    if use_linear:
+        return searchsorted_linear(sorted_sequence, values, side=side)
+    else:
+        return searchsorted_binary(sorted_sequence, values, side=side)


 def diagflat(x, k=0):
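The side semantics match numpy.searchsorted: "left" returns the insertion point before any equal entries, "right" after them. A small hedged check through the public op (again assuming the mlx backend is selected):

import os

os.environ["KERAS_BACKEND"] = "mlx"

import numpy as np

import keras

edges = np.array([0.0, 1.0, 2.0, 3.0], dtype="float32")
vals = np.array([1.0, 2.5], dtype="float32")

left = keras.ops.searchsorted(edges, vals, side="left")  # [1, 3]
right = keras.ops.searchsorted(edges, vals, side="right")  # [2, 3]

np.testing.assert_array_equal(
    np.array(left), np.searchsorted(edges, vals, side="left")
)
np.testing.assert_array_equal(
    np.array(right), np.searchsorted(edges, vals, side="right")
)

With a sorted size of only 4, this dispatches to searchsorted_linear; the binary path only kicks in past the 1024/16384 thresholds from the linked mlx discussion.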

keras/src/backend/mlx/trainer.py

Lines changed: 6 additions & 2 deletions
@@ -260,7 +260,9 @@ def train_step(self, state, data):
                 for ref_v, v in zip(self.metrics_variables, metrics_variables)
             ]
         ) as scope:
-            self._loss_tracker.update_state(unscaled_loss)
+            self._loss_tracker.update_state(
+                unscaled_loss, sample_weight=tree.flatten(x)[0].shape[0]
+            )
             logs = self.compute_metrics(x, y, y_pred, sample_weight)

             new_metrics_variables = []

@@ -553,6 +555,7 @@ def fit(

         self.stop_training = False
         self.make_train_function()
+        training_logs = {}
         callbacks.on_train_begin()
         initial_epoch = self._initial_epoch or initial_epoch
         for epoch in range(initial_epoch, epochs):

@@ -648,6 +651,7 @@ def fit(
         # If _eval_epoch_iterator exists, delete it after all epochs are done.
         if getattr(self, "_eval_epoch_iterator", None) is not None:
             del self._eval_epoch_iterator
+
         callbacks.on_train_end(logs=training_logs)
         self._mlx_state = None
         return self.history

@@ -706,7 +710,7 @@ def evaluate(
         self.make_test_function()
         self.stop_evaluating = False
         callbacks.on_test_begin()
-        logs = None
+        logs = {}
         self.reset_metrics()

         trainable_variables = [v.value for v in self.trainable_variables]
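The sample_weight change weights each batch's contribution to the running Mean by its batch size, so the reported loss is a per-sample average even when the final batch is smaller. A standalone sketch of the effect (plain Keras, any backend):

import keras

tracker = keras.metrics.Mean(name="loss")
# Two batches of different sizes with different mean losses:
tracker.update_state(1.0, sample_weight=64)  # batch of 64, mean loss 1.0
tracker.update_state(4.0, sample_weight=16)  # batch of 16, mean loss 4.0
# The weighted result is the true per-sample mean:
print(float(tracker.result()))  # (64 * 1.0 + 16 * 4.0) / 80 = 1.6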
