1
1
import builtins
2
+ import math
2
3
from copy import copy as builtin_copy
3
4
4
5
import mlx .core as mx
@@ -950,7 +951,6 @@ def quantile(x, q, axis=None, method="linear", keepdims=False):
950
951
else :
951
952
dtype = dtypes .result_type (x .dtype , float )
952
953
mlx_dtype = to_mlx_dtype (dtype )
953
- print ("mlx_dtype" , mlx_dtype )
954
954
955
955
# problem casting mlx bfloat16 array to numpy
956
956
if ori_dtype == "bfloat16" :
@@ -1374,8 +1374,43 @@ def wrapped(*args):
1374
1374
return wrapped
1375
1375
1376
1376
1377
- def histogram (x , bins , range ):
1378
- raise NotImplementedError ("histogram not yet implemented in mlx." )
1377
+ def histogram_bin_edges (a , bins = 10 , range = None ):
1378
+ # Ref: jax.numpy.histogram
1379
+ # infer range if None
1380
+ if range is None :
1381
+ range = (mx .min (a ).item (), mx .max (a ).item ())
1382
+
1383
+ if range [0 ] == range [1 ]:
1384
+ range = (range [0 ] - 0.5 , range [1 ] + 0.5 )
1385
+
1386
+ bin_edges = mx .linspace (range [0 ], range [1 ], bins + 1 , dtype = mx .float32 )
1387
+ # due to the way mlx currently handles linspace
1388
+ # with fp32 precision it is not always right edge inclusive
1389
+ # manually set the right edge for now
1390
+ bin_edges [- 1 ] = range [- 1 ]
1391
+ return bin_edges
1392
+
1393
+
1394
+ def histogram (x , bins = 10 , range = None ):
1395
+ # Ref: jax.numpy.histogram
1396
+ x = convert_to_tensor (x )
1397
+ if range is not None :
1398
+ if not isinstance (range , tuple ) or len (range ) != 2 :
1399
+ raise ValueError (
1400
+ "Invalid value for argument `range`. Only `None` or "
1401
+ "a tuple of the lower and upper range of bins is supported. "
1402
+ f"Received: range={ range } "
1403
+ )
1404
+
1405
+ bin_edges = histogram_bin_edges (x , bins , range )
1406
+
1407
+ bin_idx = searchsorted (bin_edges , x , side = "right" )
1408
+ bin_idx = mx .where (x == bin_edges [- 1 ], len (bin_edges ) - 1 , bin_idx )
1409
+
1410
+ counts = mx .zeros (len (bin_edges ))
1411
+ counts = counts .at [bin_idx ].add (mx .ones_like (x ))
1412
+
1413
+ return counts [1 :], bin_edges
1379
1414
1380
1415
1381
1416
def unravel_index (x , shape ):
@@ -1384,7 +1419,7 @@ def unravel_index(x, shape):
1384
1419
1385
1420
if None in shape :
1386
1421
raise ValueError (
1387
- "`shape` argument cannot contain `None`. Received: shape={shape}"
1422
+ f "`shape` argument cannot contain `None`. Received: shape={ shape } "
1388
1423
)
1389
1424
1390
1425
if x .ndim == 1 :
@@ -1403,8 +1438,73 @@ def unravel_index(x, shape):
1403
1438
return tuple (reversed (coords ))
1404
1439
1405
1440
1441
+ def searchsorted_binary (a , b , side = "left" ):
1442
+ original_shape = b .shape
1443
+ b_flat = b .reshape (- 1 )
1444
+
1445
+ size = a .shape [0 ]
1446
+ steps = math .ceil (math .log2 (size ))
1447
+ indices = mx .full (b_flat .shape , vals = size // 2 , dtype = mx .uint32 )
1448
+
1449
+ comparison = lambda x , y : x <= y if side == "left" else lambda x , y : x < y
1450
+
1451
+ upper = size
1452
+ lower = 0
1453
+ for _ in range (steps ):
1454
+ comp = comparison (b_flat , a [indices ])
1455
+ new_indices = mx .where (
1456
+ comp , (lower + indices ) // 2 , (indices + upper ) // 2
1457
+ )
1458
+ lower = mx .where (comp , lower , indices )
1459
+ upper = mx .where (comp , indices , upper )
1460
+ indices = new_indices
1461
+
1462
+ result = mx .where (comparison (b_flat , a [indices ]), indices , indices + 1 )
1463
+ return result .reshape (original_shape )
1464
+
1465
+
1466
+ def searchsorted_linear (a , b , side = "left" ):
1467
+ original_shape = b .shape
1468
+ b_flat = b .reshape (- 1 )
1469
+ b_flat_broadcast = b_flat .reshape (- 1 , 1 )
1470
+ if side == "left" :
1471
+ result = (a [None , :] < b_flat_broadcast ).sum (axis = 1 )
1472
+ else :
1473
+ result = (a [None , :] <= b_flat_broadcast ).sum (axis = 1 )
1474
+
1475
+ return result .reshape (original_shape )
1476
+
1477
+
1406
1478
def searchsorted (sorted_sequence , values , side = "left" ):
1407
- raise NotImplementedError ("searchsorted not yet implemented in mlx." )
1479
+ if side not in ("left" , "right" ):
1480
+ raise ValueError (f"Invalid side `{ side } `, must be `left` or `right`." )
1481
+ sorted_sequence = convert_to_tensor (sorted_sequence )
1482
+ values = convert_to_tensor (values )
1483
+ if sorted_sequence .ndim != 1 :
1484
+ raise ValueError (
1485
+ "Invalid sorted_sequence, should be 1-dimensional. "
1486
+ f"Recieved sorted_sequence.shape={ sorted_sequence .shape } "
1487
+ )
1488
+ if values .ndim == 0 :
1489
+ raise ValueError (
1490
+ "Invalid values, should be N-dimensional. Recieved "
1491
+ f"scalar array values.shape={ values .shape } "
1492
+ )
1493
+
1494
+ sorted_size = sorted_sequence .size
1495
+ search_size = values .size
1496
+
1497
+ # TODO: swap to mlx implementation if exists in the future
1498
+ # current implementation and search choice based on discussion:
1499
+ # https://github.com/ml-explore/mlx/issues/1255
1500
+ use_linear = sorted_size <= 1024 or (
1501
+ sorted_size <= 16384 and search_size <= 256
1502
+ )
1503
+
1504
+ if use_linear :
1505
+ return searchsorted_linear (sorted_sequence , values , side = side )
1506
+ else :
1507
+ return searchsorted_binary (sorted_sequence , values , side = side )
1408
1508
1409
1509
1410
1510
def diagflat (x , k = 0 ):
0 commit comments