Why is JAX so fast? #11078
-
I was writing a notebook to illustrate the performance trade-offs of array-oriented versus imperative programming; it calculates the Mandelbrot set in a variety of different ways:
I chose the Mandelbrot set because each pixel can be calculated independently, but the algorithm that determines the value of each pixel must "iterate until converged," which can be a different number of times for different pixels. Then it occurred to me to add JAX, because it's an independent basis vector in this space:
The details are all in the linked notebook (above), but here's the bottom line: I had a pretty good story going until JAX was added to the mix. That story was:
So we basically have 3 levels: the best a GPU can do, the best a CPU can do, and Python. CuPy and NumPy with temporary arrays are somewhat worse than the best a GPU or a CPU can do, respectively. Nice story.

But then I added JAX, and its final CPU speed is almost 5× better than all of the compiled CPU variants (which are pretty close to each other), and its final GPU speed is about 6× better than CuPy and Numba-CUDA. How can that be?

At first, I noticed that the JAX CPU version was using all my cores, whereas the other tests were single-threaded. There isn't a good way to turn that off, so I forced the whole JupyterLab process to have affinity with CPU 0: `taskset -c 0 jupyter lab`. I also put an `assert len(jax.devices("cpu")) == 1` in the code to make sure it's not accidentally run on all cores in the future. Then I noticed that JAX asynchronously dispatches its tasks, so I put a `block_until_ready()` on the result before stopping the timer. (A minimal sketch of this measurement setup is at the end of this post.)

It used to be the case that the maximum number of iterations in the algorithm was variable, but JAX won't compile with that (since it's a tracer). So I made all the implementations use a compile-time constant maximum number of iterations, just to be fair. Thinking that JAX was taking advantage of that to unroll the loop over iterations, I expected the compiled C++ variants to benefit in the same way, but the gap remained.

What gives? Is my test still unfair because of something I didn't think of? Or if JAX really is much faster for this kind of algorithm, why? What is XLA doing differently from gcc, LLVM, nvrtc, and NVVM that makes it so much faster?

Below are some details, in case the answer is hidden in them.

- `/proc/cpuinfo` for CPU 0 (the one I pinned the whole JupyterLab process to)
- `sudo lshw -C display` for the GPU (and also the built-in graphics)
- `nvidia-smi -a` for the GPU
- LLVM IR for `numba_inner_loop`, which has a low-level function and also a much larger unboxing/boxing function
- x86 assembly for `numba_inner_loop`, which has a low-level function and also a much larger unboxing/boxing function
- LLVM IR for `one_pixel_numba_cuda`, which is just the low-level function
- PTX for `one_pixel_numba_cuda`, which is just the low-level function
On the plus side, if there's a good reason why JAX is much faster for this type of algorithm, my Mandelbrot study becomes a great advertisement for JAX, XLA, or both! (Especially since the example was not hand-selected for it.)
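For reference, here is a minimal sketch of the measurement setup described above; the `timed` helper is an illustrative placeholder, not code from the notebook:

```python
import time
import jax

# Guard against accidentally running on more than one CPU device again
# (the whole JupyterLab process is already pinned with `taskset -c 0`).
assert len(jax.devices("cpu")) == 1

def timed(f, *args, repeats=10):
    # JAX dispatches asynchronously, so block on the result to make sure
    # the computation (not just the dispatch) is inside the timed region.
    f(*args).block_until_ready()              # warm-up / compilation
    start = time.perf_counter()
    for _ in range(repeats):
        f(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats
```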
-
To force JAX to use a single CPU device you can try setting this flag: `XLA_FLAGS="--xla_force_host_platform_device_count=1"`.
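A sketch of where that setting has to go; the key point is that the environment variable must be set before JAX initializes its backends (simplest: before the import):

```python
import os

# Set XLA_FLAGS before jax initializes its backends (simplest: before importing jax).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=1"

import jax
print(jax.devices("cpu"))  # should report a single CPU device
```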
-
IIUC, JAX tracers have static shapes, so the number of loop iterations has to be a compile-time constant.
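As an illustration of the static-trip-count point (not code from the thread): a Python `int` loop bound is a compile-time constant under `jit` and the loop gets unrolled at trace time, while a traced bound cannot even be passed to `range`:

```python
import jax
import jax.numpy as jnp

@jax.jit
def iterate_static(z, c):
    for _ in range(20):      # 20 is a Python int: the loop is unrolled at trace time
        z = z * z + c
    return z

@jax.jit
def iterate_traced(z, c, n):
    for _ in range(n):       # n is a tracer here: range() needs a concrete integer
        z = z * z + c
    return z

z = c = jnp.zeros((4, 4), dtype=jnp.complex64)
iterate_static(z, c)         # fine
# iterate_traced(z, c, 20)   # raises a tracer-concretization error
```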
-
BTW, for "CuPy with a custom kernel", it seems that warp divergence also prevent global memory access coalescing? |
-
One important thing: JAX uses complex64 (float32 for each part) instead of complex128 by default. Did you take that into account?
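A sketch (not from the thread) of the two ways to make the comparison like for like: either switch JAX to 64-bit types, or switch the other implementations to float32/complex64, as the Julia comparison further down does:

```python
import jax
import jax.numpy as jnp

print(jnp.array(1.0 + 1.0j).dtype)   # complex64: JAX's default precision

# Option 1: make JAX match double-precision implementations.
jax.config.update("jax_enable_x64", True)
print(jnp.array(1.0 + 1.0j).dtype)   # now complex128

# Option 2 (used later in this thread): keep JAX at complex64 and make the
# C++/Numba/CuPy/Julia implementations use float32/complex64 instead.
```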
-
Summarizing the above, my current to-do list of things to check:
If these things explain the difference, it still comes out as a win for the JAX/XLA approach, because these optimizations are hard to do manually. You'll have won a convert, and I'm trying to see how my project might fit into a world in which array bounds are all known at compile time. (For my project in full generality, that would be hard.)
-
Well, to exclude the time of the device-to-host copy, I tried:

```python
import cupy as cp
from cupyx.profiler import benchmark

h, w = 2048, 4096
fractal = cp.empty((h, w), dtype=cp.int32)

cupy_custom_kernel = cp.RawKernel(r'''
#include <cupy/complex.cuh>
extern "C" __global__
void cupy_custom_kernel(int height, int width, int* fractal) {
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;
    complex<float> j(0.0, 1.0);
    complex<float> z, c;
    z = c = complex<float>(-1.5 + y*1.0/(height + 1)) - j + complex<float>(x*1.5)*j/complex<float>(width + 1);
    fractal[y + x * width] = 20;
    for (int i = 0; i < 20; i++) {
        z = z * z + c;
        if (z.real() * z.real() + z.imag() * z.imag() > 4) {
            fractal[y + x * width] = i;
            break;
        }
    }
}
''', "cupy_custom_kernel", options=("--use_fast_math",))
cupy_custom_kernel.compile()

def run_cupy_custom_kernel(height, width, fractal):
    griddim = (height // 32, width // 32)
    blockdim = (32, 32)
    cupy_custom_kernel(griddim, blockdim, (height, width, fractal))
    return fractal

print(benchmark(run_cupy_custom_kernel, (h, w, fractal), n_repeat=20))
```
And with x and y swapped, so that adjacent threads write adjacent elements of `fractal` (coalesced writes):

```python
cupy_custom_kernel = cp.RawKernel(r'''
#include <cupy/complex.cuh>
extern "C" __global__
void cupy_custom_kernel(int height, int width, int* fractal) {
    int y = blockIdx.x * blockDim.x + threadIdx.x;
    int x = blockIdx.y * blockDim.y + threadIdx.y;
    complex<float> j(0.0, 1.0);
    complex<float> z, c;
    z = c = complex<float>(-1.5 + y * 1.0/(height + 1)) - j + complex<float>(x * 1.5) * j / complex<float>(width + 1);
    int r = 20;
    for (int i = 0; i < 20; i++) {
        z = z * z + c;
        if (z.real() * z.real() + z.imag() * z.imag() > 4) {
            r = i;
            break;
        }
    }
    fractal[x * width + y] = r;
}
''', "cupy_custom_kernel", options=("--use_fast_math",))
cupy_custom_kernel.compile()

def run_cupy_custom_kernel(height, width, fractal):
    griddim = (height // 32, width // 32)[::-1]
    blockdim = (32, 32)
    cupy_custom_kernel(griddim, blockdim, (height, width, fractal))
    return fractal

print(benchmark(run_cupy_custom_kernel, (h, w, fractal), n_repeat=20))
```

EDIT: fix.
-
It seems that JAX is slower:

```python
import jax

def run_jax_kernel(fractal):
    h, w = fractal.shape
    y, x = jax.numpy.ogrid[-1:0:h*1j, -1.5:0:w*1j]
    z = c = x + y * 1j
    for i in range(20):
        z = z * z + c
        diverged = z.real * z.real + z.imag * z.imag > 4  # EDIT: fixed from z.real * z.imag > 4
        diverging_now = diverged & (fractal == 20)
        fractal = jax.numpy.where(diverging_now, i, fractal)
    return fractal

run_jax_gpu_kernel = jax.jit(run_jax_kernel)

def run_jax_gpu(fractal):
    run_jax_gpu_kernel(fractal).block_until_ready()

fractal = jax.numpy.full((2048, 4096), 20, dtype=jax.numpy.int32)
run_jax_gpu(fractal)  # warm-up / compilation

from time import perf_counter_ns

t = perf_counter_ns()
for _ in range(20):
    run_jax_gpu(fractal)
print((perf_counter_ns() - t) / 20 / 1e3)  # ~410 us  (EDIT)
```
-
It's resolved; thanks for all your help, @YouJiacheng!

The main thing that I was missing in my CPU implementations was that I was computing in double precision (complex128), whereas JAX defaults to complex64. I'm still missing something in the CPU implementation, because all of my compiled versions (pybind11, Cython, Numba) are consistently 40% worse than JAX's.

The main thing that I was missing in my GPU implementations was the use of fast math (`--use_fast_math`).

The final plot for my computer (specs as given above): you can see that the GPU implementations are all about the same as one another, the JAX CPU implementation is about 40% better than the other compiled CPU implementations, and NumPy and CuPy without controlling for intermediate arrays are much worse than careful compilation, but much better than nothing. Python is on the far right, representing "nothing."

The notebook runs as-is on Google Colab because that environment happens to have only one CPU (at my usage tier, anyway). The plot there looks pretty similar, though pybind11 and Cython are worse. Maybe the C++ compiler is different, or maybe external calls to binaries are different for some reason?

Anyway, I didn't enter into this project expecting to be as impressed with JAX as I am. I've also updated the gist.
-
Chiming in here (this is a fascinating thread) to comment that JAX/XLA uses some but not all fastmath optimisations by default. If you're using all of fastmath for the other approaches, then for fairness' sake JAX/XLA should be allowed to as well. No idea if that'll actually affect your end result, of course. I mean, JAX is winning anyway, but obviously I'm curious to know whether the margin can be improved.
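For comparison, this is the kind of switch already turned on on the non-JAX side in this thread: `--use_fast_math` in the CuPy kernels above, and Numba's `fastmath` option on the CPU. A hypothetical sketch of the latter (the `escape_time` helper is illustrative, not the notebook's code):

```python
import numba

# fastmath=True enables LLVM's fast-math flags for the compiled loop,
# analogous to nvcc's --use_fast_math in the CUDA kernels above.
@numba.njit(fastmath=True)
def escape_time(creal, cimag):
    zreal, zimag = creal, cimag
    for i in range(20):
        zreal, zimag = zreal * zreal - zimag * zimag + creal, 2.0 * zreal * zimag + cimag
        if zreal * zreal + zimag * zimag > 4.0:
            return i
    return 20
```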
-
**Julia benchmark vs. Jax vs. Numba**

tl;dr: following the discussion on the Julia discourse (thanks to @chriselrod), and after realizing that Jax unrolls the iteration loop, I unrolled the Julia loop as well (via `Base.Cartesian.@nexprs`):

```julia
# NOT using fastmath
function run_julia(height, width)
    y = range(-1.0f0, 0.0f0; length = height)  # need Float32 because Jax defaults to it
    x = range(-1.5f0, 0.0f0; length = width)
    c = x' .+ y*im
    fractal = fill(Int32(20), height, width)
    # this checks if indices are compatible between `c` and `fractal`
    @inbounds for idx in eachindex(c, fractal)
        _c = c[idx]
        z = _c
        m = true
        Base.Cartesian.@nexprs 20 i -> begin
            z = z^2 + _c
            az4 = abs2(z) > 4f0
            fractal[idx] = ifelse(m & az4, Int32(i), fractal[idx])  # 32-bit Int, same reason as above
            m &= (!az4)
        end
    end
    return fractal
end
```

**Jax**

```
In [8]: %%timeit -o
   ...: run_jax_cpu(2000, 3000)
97.9 ms ± 870 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Out[8]: <TimeitResult : 97.9 ms ± 870 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)>
```

**Julia (with unroll & without fastmath)**

```
julia> @benchmark run_julia(2000,3000)
BenchmarkTools.Trial: 93 samples with 1 evaluation.
 Range (min … max):  48.687 ms … 114.762 ms  ┊ GC (min … max): 0.31% … 54.80%
 Time  (median):     52.597 ms               ┊ GC (median):    0.89%
 Time  (mean ± σ):   53.969 ms ±   8.555 ms  ┊ GC (mean ± σ):  3.34% ±  7.73%
 Memory estimate: 68.66 MiB, allocs estimate: 4.
```

**Additional comparison with Numba**

```
In [6]: %%timeit -o
   ...: run_numba(2000, 3000)
143 ms ± 1.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Out[6]: <TimeitResult : 143 ms ± 1.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)>
```

**Julia (without unroll & with fastmath)**

```julia
@inline function fast2(x)  # pending PR to add this as the fastmath routine for complex^2
    r = real(x)
    i = imag(x)
    Complex(fma(r, r, -i * i), fma(r, i, i * r))
end

function run_julia(height, width)
    y = range(-1.0f0, 0.0f0; length = height)
    x = range(-1.5f0, 0.0f0; length = width)
    c = x' .+ y*im
    fractal = fill(Int32(20), height, width)
    @inbounds @fastmath for idx in eachindex(c)
        _c = c[idx]
        z = _c
        for i = 1:20
            z = fast2(z) + _c
            if abs2(z) > 4f0
                fractal[idx] = i
                break
            end
        end
    end
    return fractal
end
```

```
julia> @benchmark run_julia(2000,3000)
BenchmarkTools.Trial: 46 samples with 1 evaluation.
 Range (min … max):  105.716 ms … 112.512 ms  ┊ GC (min … max): 0.00% … 0.72%
 Time  (median):     110.110 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   109.905 ms ±   1.457 ms  ┊ GC (mean ± σ):  0.38% ± 0.52%
 Memory estimate: 68.66 MiB, allocs estimate: 4.
```

**Platform information and image output for debugging / sanity check**

```
julia> versioninfo(verbose=true)
Julia Version 1.8.0-beta3
Commit 3e092a2521 (2022-03-29 15:42 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  uname: Linux 5.18.6-arch1-1 #1 SMP PREEMPT_DYNAMIC Wed, 22 Jun 2022 18:10:56 +0000 x86_64 unknown
  CPU: 11th Gen Intel(R) Core(TM) i5-1135G7 @ 2.40GHz:
            speed         user         nice          sys         idle          irq
  #1     4200 MHz      47711 s          0 s      15165 s    1050914 s      18726 s
  #2      400 MHz      57317 s          0 s      15044 s      25843 s       1600 s
  #3     2800 MHz      58279 s          0 s      14516 s      26112 s       1707 s
  #4     1300 MHz      58318 s          0 s      14387 s      25701 s       1746 s
  #5     4200 MHz      54057 s          0 s      13005 s      26334 s       1868 s
  #6     1962 MHz      57398 s          1 s      15431 s      26493 s       1450 s
  #7      400 MHz      58745 s          0 s      14183 s      26216 s       1598 s
  #8      400 MHz      58657 s          1 s      14181 s      25788 s       1460 s
  Memory: 31.14313507080078 GB (17520.08984375 MB free)
  Uptime: 279087.59 sec
  Load Avg:  1.4  1.22  0.99
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, tigerlake)
Threads: 1 on 8 virtual cores

julia> using Plots

julia> heatmap(run_julia(2000, 3000))
```
-
Any chance we can get a Dex benchmark, @axch or @apaszke? Maybe with the new vectorization pass? (Or relying on LLVM.) The Levenshtein distance example is really cool.
-
**Julia GPU (CPU) kernel programming**

You can write a vendor-agnostic kernel for this algorithm with KernelAbstractions.jl and it would run on NVIDIA, AMD, and Intel GPUs alike (pending a oneAPI PR). I will demonstrate the point with a kernel instantiated on a CUDA device:

```julia
using KernelAbstractions, CUDAKernels, CUDA

@kernel function julia_kernel!(c, fractal)
    I = @index(Global)
    _c = c[I]
    z = _c
    @inbounds for i = 1:20
        z = z^2 + _c
        if abs2(z) > 4f0
            fractal[I] = Int32(i)
            break
        end
    end
end

function run_julia_gpu(height, width)
    y = CuArray(range(-1.0f0, 0.0f0; length = height))
    x = CuArray(range(-1.5f0, 0.0f0; length = width))
    c = x' .+ y*im
    fractal = CUDA.fill(Int32(20), height, width)
    # the 32^2 is comparable to gridsize in Numba-CUDA
    kernel! = julia_kernel!(CUDADevice(), 32^2)  # we instantiate a kernel with vendor info
    kernel!(c, fractal; ndrange=size(c))
    return Array(fractal)  # copy back to CPU, blocking
end
```

```
julia> @benchmark run_julia_gpu(2000,3000)
BenchmarkTools.Trial: 1279 samples with 1 evaluation.
 Range (min … max):  2.349 ms … 12.656 ms  ┊ GC (min … max): 0.00% … 11.64%
 Time  (median):     2.419 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.908 ms ±  3.135 ms  ┊ GC (mean ± σ):  7.66% ± 14.73%
 Memory estimate: 22.91 MiB, allocs estimate: 130.
```

**Use the CPU as a (bad) kernel executor**

A CPU is a very bad GPU (for this kind of work), but just to show that the kernel function is indeed vendor-agnostic, without changing the kernel function definition:

```julia
function run_julia_cpu_jaxstype(height, width)
    y = range(-1.0f0, 0.0f0; length = height)
    x = range(-1.5f0, 0.0f0; length = width)
    c = x' .+ y*im
    fractal = fill(Int32(20), height, width)
    kernel! = julia_kernel!(CPU(), length(c)÷Threads.nthreads())  # we're using 1 thread here
    event = kernel!(c, fractal; ndrange=length(c))
    wait(event)  # not copying back, need to block here
    return fractal
end
```

```
julia> @benchmark run_julia_cpu_jaxstype(2000,3000)
BenchmarkTools.Trial: 28 samples with 1 evaluation.
 Range (min … max):  176.828 ms … 185.904 ms  ┊ GC (min … max): 0.00% … 0.72%
 Time  (median):     178.417 ms               ┊ GC (median):    0.49%
 Time  (mean ± σ):   178.964 ms ±   2.117 ms  ┊ GC (mean ± σ):  0.41% ± 0.39%
 Memory estimate: 68.67 MiB, allocs estimate: 21.
```

CUDA Device Info
-
I actually have the opposite issue: why is JAX so slow compared to CuPy? For matrix multiplication, CuPy is 5× faster on my machine, while times are comparable for SVD.
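Timing methodology matters for both libraries here: CuPy launches are also asynchronous (hence `cupyx.profiler.benchmark`), JAX needs `block_until_ready()`, the first jitted call includes compilation, and the default dtypes differ (float32 in JAX, float64 in CuPy). A sketch of a like-for-like comparison, with arbitrary sizes:

```python
import time
import cupy as cp
import jax
import jax.numpy as jnp
from cupyx.profiler import benchmark

x_cp = cp.random.rand(4096, 4096, dtype=cp.float32)
x_jax = jnp.asarray(cp.asnumpy(x_cp))        # same values, float32 in both libraries

# CuPy: cupyx.profiler.benchmark synchronizes around each repetition.
print(benchmark(lambda a: a @ a, (x_cp,), n_repeat=20))

# JAX: jit once, warm up (compilation), then time with block_until_ready().
matmul = jax.jit(lambda a: a @ a)
matmul(x_jax).block_until_ready()
t0 = time.perf_counter()
for _ in range(20):
    matmul(x_jax).block_until_ready()
print((time.perf_counter() - t0) / 20)
```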
-
More information from @mjbaldwin: there's a compilation time/runtime trade-off in the number of Mandelbrot iterations, which is fixed at 20 in this example, but might be much larger in a similar problem.
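If the iteration count did grow much larger, one option (not used in the notebook) is to keep the loop rolled with `jax.lax.fori_loop`, which keeps compile time flat at the cost of giving up the unrolling that helped here. A sketch with hypothetical names; whether the rolled or unrolled form wins at runtime depends on the iteration count and backend:

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=2)
def run_rolled(z0, c, max_iter):
    # Same update and "first divergence wins" bookkeeping as the unrolled
    # version, but expressed as a single rolled loop body.
    def body(i, state):
        z, fractal = state
        z = z * z + c
        diverging_now = (jnp.abs(z) > 2.0) & (fractal == max_iter)
        return z, jnp.where(diverging_now, i, fractal)

    fractal0 = jnp.full(c.shape, max_iter, dtype=jnp.int32)
    _, fractal = jax.lax.fori_loop(0, max_iter, body, (z0, fractal0))
    return fractal

# usage (with c built as in the earlier examples, and z starting at c):
# fractal = run_rolled(c, c, 20)
```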