I just found a way to compute quantiles faster than the jnp.quantile method (at least on my NVIDIA GPU, with large arrays, and when only one quantile per array is needed). It is just a simple binary search. It is a few times faster even when doing the precise calculation with interpolation, as JAX and NumPy do, and if perfect accuracy is not needed, the speed can be increased further by selecting method='fast' and setting step to the permissible error in the returned quantile (a short illustration of step follows the code). Here is the code:

from time import time
from datetime import datetime as dt
from jax import jit, numpy as jnp, random, lax, config
from functools import partial

def test():
    config.update('jax_enable_x64', True)
    # config.update('jax_platform_name', 'cpu')
    x = random.normal(random.key(int(time() * 1e3)), 10000000)
    binQuantile(x, .5)  # warm-up call so compilation time is excluded below
    qs = jnp.arange(.1, 1., .001)

    t = dt.now()
    res = lax.scan(lambda r, q: (r + jnp.quantile(x, q, method='linear'), None), 0., qs)[0]
    print(f'jnp.quantile result: {res:.7f}')
    print(f'time: {dt.now() - t}')

    t = dt.now()
    res = lax.scan(lambda r, q: (r + binQuantile(x, q, method='linear'), None), 0., qs)[0]
    print(f'binary search quantile result: {res:.7f}')
    print(f'time: {dt.now() - t}')

    t = dt.now()
    res = lax.scan(lambda r, q: (r + binQuantile(x, q, method='fast', step=.0001), None), 0., qs)[0]
    print(f'binary search quantile result (fast): {res:.7f}')
    print(f'time: {dt.now() - t}')

@partial(jit, static_argnums=2)
def binQuantile(arr, q, method='linear', step=0.):
    n = arr.shape[0] - 1
    targetCases = n * (1. - q)  # expected number of elements above the quantile
    cond = lambda a: a[0] <= a[1]

    def it(args):
        low, high, mid = args
        mid = (high + low) * .5
        cases = (arr > mid).sum()
        finished = (cases == targetCases) | (mid == low) | (mid == high)
        low = jnp.select([cases > targetCases], [mid + step], low)
        # setting high to -inf makes cond fail, terminating the while_loop
        high = jnp.select([finished, cases < targetCases], [-jnp.inf, mid - step], high)
        return low, high, mid

    res = lax.while_loop(cond, it, (arr.min(), arr.max(), 0.))[2]

    def interpolate():
        # find the data points bracketing the search result;
        # jnp.where is used for masking because adding `bool * inf`
        # produces nan wherever the boolean is False (0 * inf = nan)
        diff = arr - res
        less = arr < res
        lowerDiff = jnp.where(less, diff, -jnp.inf).max()
        higherDiff = jnp.where(less, jnp.inf, diff).min()
        lower = res + lowerDiff
        higher = res + higherDiff
        index = ['linear', 'lower', 'higher', 'midpoint', 'nearest', 'fast'].index(method)
        return lax.switch(index, [
            lambda: lower + (higher - lower) * ((q * n) % 1.),
            lambda: lower,
            lambda: higher,
            lambda: (lower + higher) * .5,
            lambda: jnp.select([higherDiff >= -lowerDiff], [lower], higher),
            lambda: res,
        ])

    return lax.cond(method == 'fast', lambda: res, interpolate)
test()

And this is the output (with GPU):
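Separately, as a quick illustration of what these methods return (a sketch assuming the definitions above; the array size, quantile, and 1e-4 tolerance are made up): method='linear' should agree with jnp.quantile, while method='fast' stops the search early and should land within roughly step of the exact value.

x = random.normal(random.key(0), 1000000)
print(jnp.quantile(x, .9), binQuantile(x, .9))  # 'linear': should agree with jnp.quantile
fast = binQuantile(x, .9, method='fast', step=1e-4)
print(fast, abs(fast - jnp.quantile(x, .9)))    # difference roughly on the order of step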
Interesting! Before digging in further, I'd suggest taking a look at FAQ: Benchmarking JAX Code and making sure you're following the advice there (particularly with regard to asynchronous dispatch), so that you're measuring the actual runtime.

I do see that binQuantile is a few times faster than jnp.quantile for very large arrays, although the current method seems to outperform the proposed method for more moderately-sized arrays:

x = random.normal(random.key(int(time() * 1e3)), 1000)
print(jnp.quantile(x, 0.5))
print(binQuantile(x, 0.5))
%timeit jnp.quantile(x, 0.5).block_until_ready()
%timeit binQuantile(x, 0.5).block_until_ready()
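For reference, outside IPython the same measurement (a sketch, reusing the x above) can be written with the standard timeit module, still blocking on the result so asynchronous dispatch doesn't hide the work:

import timeit
print(timeit.timeit(lambda: jnp.quantile(x, 0.5).block_until_ready(), number=100))
print(timeit.timeit(lambda: binQuantile(x, 0.5).block_until_ready(), number=100))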
Another thing to consider with iterative methods like this: under batching (e.g. via vmap), a while_loop keeps running until every element of the batch has converged, so the per-element advantage can shrink.
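To make the batching concern concrete, here is a minimal sketch (the batch size and tolerance are made up): vmap turns the while_loop's stopping condition into "any row still running", so every row pays for the slowest one.

from jax import vmap

xs = random.normal(random.key(1), (64, 100000))  # batch of 64 arrays
qs = jnp.full(64, .5)

# each row searches independently, but iterations continue until all rows finish
batched = vmap(lambda a, q: binQuantile(a, q, 'fast', 1e-4))
print(batched(xs, qs).shape)  # (64,)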