Skip to content

Commit b9f7141

Browse files
authored
Implemented segment_sum and segment_max (#20856)
1 parent 6a957bc commit b9f7141

File tree

1 file changed

+40
-2
lines changed

1 file changed

+40
-2
lines changed

keras/src/backend/mlx/math.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,54 @@
11
import math
22

33
import mlx.core as mx
4+
import numpy as np
45

56
from keras.src.backend.mlx.core import convert_to_tensor
67

78

9+
def _segment_reduction_fn(
10+
data, segment_ids, reduction_method, num_segments, sorted
11+
):
12+
data = convert_to_tensor(data)
13+
segment_ids = convert_to_tensor(segment_ids)
14+
15+
if data.dtype == mx.int64:
16+
# GPU scatter does not yet support int64 for the input or updates.
17+
data = data.astype(mx.int32)
18+
19+
if num_segments is None:
20+
num_segments = mx.max(segment_ids) + 1
21+
22+
valid_indices = segment_ids >= 0
23+
valid_data = mx.array(
24+
np.array(data)[valid_indices] # MLX does not support boolean indices
25+
)
26+
valid_segment_ids = mx.array(np.array(segment_ids)[valid_indices])
27+
28+
data_shape = list(valid_data.shape)
29+
data_shape[0] = num_segments
30+
31+
if not sorted:
32+
sort_indices = mx.argsort(valid_segment_ids)
33+
valid_segment_ids = valid_segment_ids[sort_indices]
34+
valid_data = valid_data[sort_indices]
35+
36+
if reduction_method == "max":
37+
result = mx.ones(data_shape, dtype=valid_data.dtype) * -mx.inf
38+
result = result.at[valid_segment_ids].maximum(valid_data)
39+
else: # sum
40+
result = mx.zeros(data_shape, dtype=valid_data.dtype)
41+
result = result.at[valid_segment_ids].add(valid_data)
42+
43+
return result
44+
45+
846
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
9-
raise NotImplementedError("segment_sum is not implemented for mlx")
47+
return _segment_reduction_fn(data, segment_ids, "sum", num_segments, sorted)
1048

1149

1250
def segment_max(data, segment_ids, num_segments=None, sorted=False):
13-
raise NotImplementedError("segment_max is not implemented for mlx")
51+
return _segment_reduction_fn(data, segment_ids, "max", num_segments, sorted)
1452

1553

1654
def top_k(x, k, sorted=True):

0 commit comments

Comments
 (0)