1
1
import math
2
- import operator
3
2
4
3
import mlx .core as mx
5
- import numpy as np
6
4
7
5
from keras .src .backend import standardize_dtype
6
+ from keras .src .backend .common .backend_utils import canonicalize_axis
8
7
from keras .src .backend .mlx .core import convert_to_tensor
9
8
from keras .src .backend .mlx .linalg import det
10
9
from keras .src .utils .module_utils import scipy
@@ -23,26 +22,31 @@ def _segment_reduction_fn(
23
22
if num_segments is None :
24
23
num_segments = mx .max (segment_ids ) + 1
25
24
26
- valid_indices = segment_ids >= 0
27
- valid_data = mx .array (
28
- np .array (data )[valid_indices ] # MLX does not support boolean indices
29
- )
30
- valid_segment_ids = mx .array (np .array (segment_ids )[valid_indices ])
31
-
32
- data_shape = list (valid_data .shape )
33
- data_shape [0 ] = num_segments
25
+ mask = segment_ids >= 0
26
+ # pack segment_ids < 0 into index 0 and then handle below
27
+ safe_segment_ids = mx .where (mask , segment_ids , 0 )
34
28
35
29
if not sorted :
36
- sort_indices = mx .argsort (valid_segment_ids )
37
- valid_segment_ids = valid_segment_ids [sort_indices ]
38
- valid_data = valid_data [sort_indices ]
30
+ sort_indices = mx .argsort (safe_segment_ids )
31
+ safe_segment_ids = mx .take (safe_segment_ids , sort_indices )
32
+ data = mx .take (data , sort_indices , axis = 0 )
33
+ mask = mx .take (mask , sort_indices )
34
+
35
+ # expand mask dimensions to match data dimensions
36
+ for i in range (1 , len (data .shape )):
37
+ mask = mx .expand_dims (mask , axis = i )
38
+
39
+ data_shape = list (data .shape )
40
+ data_shape [0 ] = num_segments
39
41
40
42
if reduction_method == "max" :
41
- result = mx .ones (data_shape , dtype = valid_data .dtype ) * - mx .inf
42
- result = result .at [valid_segment_ids ].maximum (valid_data )
43
+ masked_data = mx .where (mask , data , - mx .inf )
44
+ result = mx .ones (data_shape , dtype = data .dtype ) * - mx .inf
45
+ result = result .at [safe_segment_ids ].maximum (masked_data )
43
46
else : # sum
44
- result = mx .zeros (data_shape , dtype = valid_data .dtype )
45
- result = result .at [valid_segment_ids ].add (valid_data )
47
+ masked_data = mx .where (mask , data , 0 )
48
+ result = mx .zeros (data_shape , dtype = data .dtype )
49
+ result = result .at [safe_segment_ids ].add (masked_data )
46
50
47
51
return result
48
52
@@ -154,19 +158,6 @@ def irfft(x, fft_length=None):
154
158
return real_output
155
159
156
160
157
- def _canonicalize_axis (axis , num_dims ):
158
- # Ref: jax.scipy.signal.stft
159
- """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
160
- axis = operator .index (axis )
161
- if not - num_dims <= axis < num_dims :
162
- raise ValueError (
163
- f"axis { axis } is out of bounds for array of dimension { num_dims } "
164
- )
165
- if axis < 0 :
166
- axis = axis + num_dims
167
- return axis
168
-
169
-
170
161
def _create_sliding_windows (x , window_size , step ):
171
162
batch_size , signal_length , _ = x .shape
172
163
num_windows = (signal_length - window_size ) // step + 1
@@ -187,7 +178,7 @@ def _create_sliding_windows(x, window_size, step):
187
178
188
179
def _stft (x , window , nperseg , noverlap , nfft , axis = - 1 ):
189
180
# Ref: jax.scipy.signal.stft
190
- axis = _canonicalize_axis (axis , x .ndim )
181
+ axis = canonicalize_axis (axis , x .ndim )
191
182
result_dtype = mx .complex64
192
183
193
184
if x .size == 0 :
@@ -364,8 +355,8 @@ def _istft(
364
355
# Ref: jax.scipy.signal.istft
365
356
if Zxx .ndim < 2 :
366
357
raise ValueError ("Input stft must be at least 2d!" )
367
- freq_axis = _canonicalize_axis (freq_axis , Zxx .ndim )
368
- time_axis = _canonicalize_axis (time_axis , Zxx .ndim )
358
+ freq_axis = canonicalize_axis (freq_axis , Zxx .ndim )
359
+ time_axis = canonicalize_axis (time_axis , Zxx .ndim )
369
360
370
361
if freq_axis == time_axis :
371
362
raise ValueError ("Must specify differing time and frequency axes!" )
0 commit comments