|
1 | 1 | import math
|
2 | 2 |
|
3 | 3 | import mlx.core as mx
|
| 4 | +import numpy as np |
4 | 5 |
|
5 | 6 | from keras.src.backend.mlx.core import convert_to_tensor
|
6 | 7 |
|
7 | 8 |
|
| 9 | +def _segment_reduction_fn( |
| 10 | + data, segment_ids, reduction_method, num_segments, sorted |
| 11 | +): |
| 12 | + data = convert_to_tensor(data) |
| 13 | + segment_ids = convert_to_tensor(segment_ids) |
| 14 | + |
| 15 | + if data.dtype == mx.int64: |
| 16 | + # GPU scatter does not yet support int64 for the input or updates. |
| 17 | + data = data.astype(mx.int32) |
| 18 | + |
| 19 | + if num_segments is None: |
| 20 | + num_segments = mx.max(segment_ids) + 1 |
| 21 | + |
| 22 | + valid_indices = segment_ids >= 0 |
| 23 | + valid_data = mx.array( |
| 24 | + np.array(data)[valid_indices] # MLX does not support boolean indices |
| 25 | + ) |
| 26 | + valid_segment_ids = mx.array(np.array(segment_ids)[valid_indices]) |
| 27 | + |
| 28 | + data_shape = list(valid_data.shape) |
| 29 | + data_shape[0] = num_segments |
| 30 | + |
| 31 | + if not sorted: |
| 32 | + sort_indices = mx.argsort(valid_segment_ids) |
| 33 | + valid_segment_ids = valid_segment_ids[sort_indices] |
| 34 | + valid_data = valid_data[sort_indices] |
| 35 | + |
| 36 | + if reduction_method == "max": |
| 37 | + result = mx.ones(data_shape, dtype=valid_data.dtype) * -mx.inf |
| 38 | + result = result.at[valid_segment_ids].maximum(valid_data) |
| 39 | + else: # sum |
| 40 | + result = mx.zeros(data_shape, dtype=valid_data.dtype) |
| 41 | + result = result.at[valid_segment_ids].add(valid_data) |
| 42 | + |
| 43 | + return result |
| 44 | + |
| 45 | + |
8 | 46 | def segment_sum(data, segment_ids, num_segments=None, sorted=False):
|
9 |
| - raise NotImplementedError("segment_sum is not implemented for mlx") |
| 47 | + return _segment_reduction_fn(data, segment_ids, "sum", num_segments, sorted) |
10 | 48 |
|
11 | 49 |
|
12 | 50 | def segment_max(data, segment_ids, num_segments=None, sorted=False):
|
13 |
| - raise NotImplementedError("segment_max is not implemented for mlx") |
| 51 | + return _segment_reduction_fn(data, segment_ids, "max", num_segments, sorted) |
14 | 52 |
|
15 | 53 |
|
16 | 54 | def top_k(x, k, sorted=True):
|
|
0 commit comments