@@ -1356,26 +1356,53 @@ cdef _roll_min_max(ndarray[numeric] input, int64_t win, int64_t minp,
1356
1356
# print("output: {0}".format(output))
1357
1357
return output
1358
1358
1359
+ def _get_interpolation_id (str interpolation ):
1360
+ """
1361
+ Converts string to interpolation id
1362
+
1363
+ Parameters
1364
+ ----------
1365
+ interpolation: 'linear', 'lower', 'higher', 'nearest', 'midpoint'
1366
+ """
1367
+ if interpolation == ' linear' :
1368
+ return 0
1369
+ elif interpolation == ' lower' :
1370
+ return 1
1371
+ elif interpolation == ' higher' :
1372
+ return 2
1373
+ elif interpolation == ' nearest' :
1374
+ return 3
1375
+ elif interpolation == ' midpoint' :
1376
+ return 4
1377
+ else :
1378
+ raise ValueError (" Interpolation {} is not supported"
1379
+ .format(interpolation))
1380
+
1359
1381
1360
1382
def roll_quantile (ndarray[float64_t , cast = True ] input , int64_t win ,
1361
1383
int64_t minp , object index , object closed ,
1362
- double quantile ):
1384
+ double quantile , str interpolation ):
1363
1385
"""
1364
1386
O(N log(window)) implementation using skip list
1365
1387
"""
1366
1388
cdef:
1367
- double val, prev, midpoint
1389
+ double val, prev, midpoint, idx_with_fraction
1368
1390
IndexableSkiplist skiplist
1369
1391
int64_t nobs = 0 , i, j, s, e, N
1370
1392
Py_ssize_t idx
1371
1393
bint is_variable
1372
1394
ndarray[int64_t] start, end
1373
1395
ndarray[double_t] output
1374
1396
double vlow, vhigh
1397
+ int interpolation_id
1375
1398
1376
1399
if quantile <= 0.0 or quantile >= 1.0 :
1377
1400
raise ValueError (" quantile value {0} not in [0, 1]" .format(quantile))
1378
1401
1402
+ # interpolation_id is needed to avoid string comparisons inside the loop
1403
+ # I tried to use callback but it resulted in worse performance
1404
+ interpolation_id = _get_interpolation_id(interpolation)
1405
+
1379
1406
# we use the Fixed/Variable Indexer here as the
1380
1407
# actual skiplist ops outweigh any window computation costs
1381
1408
start, end, N, win, minp, is_variable = get_window_indexer(
@@ -1414,18 +1441,31 @@ def roll_quantile(ndarray[float64_t, cast=True] input, int64_t win,
1414
1441
skiplist.insert(val)
1415
1442
1416
1443
if nobs >= minp:
1417
- idx = int (quantile * < double > (nobs - 1 ))
1418
-
1419
- # Single value in skip list
1420
1444
if nobs == 1 :
1445
+ # Single value in skip list
1421
1446
output[i] = skiplist.get(0 )
1422
-
1423
- # Interpolated quantile
1424
1447
else :
1425
- vlow = skiplist.get(idx)
1426
- vhigh = skiplist.get(idx + 1 )
1427
- output[i] = ((vlow + (vhigh - vlow) *
1428
- (quantile * (nobs - 1 ) - idx)))
1448
+ idx_with_fraction = quantile * < double > (nobs - 1 )
1449
+ idx = int (idx_with_fraction)
1450
+
1451
+ if interpolation_id == 0 : # linear
1452
+ vlow = skiplist.get(idx)
1453
+ vhigh = skiplist.get(idx + 1 )
1454
+ output[i] = ((vlow + (vhigh - vlow) *
1455
+ (idx_with_fraction - idx)))
1456
+ elif interpolation_id == 1 : # lower
1457
+ output[i] = skiplist.get(idx)
1458
+ elif interpolation_id == 2 : # higher
1459
+ output[i] = skiplist.get(idx + 1 )
1460
+ elif interpolation_id == 3 : # nearest
1461
+ if idx_with_fraction - idx < 0.5 :
1462
+ output[i] = skiplist.get(idx)
1463
+ else :
1464
+ output[i] = skiplist.get(idx + 1 )
1465
+ elif interpolation_id == 4 : # midpoint
1466
+ vlow = skiplist.get(idx)
1467
+ vhigh = skiplist.get(idx + 1 )
1468
+ output[i] = < double > (vlow + vhigh) / 2
1429
1469
else :
1430
1470
output[i] = NaN
1431
1471
0 commit comments