Slow array programming on GPU compared to CPU #21617
-
Hi all, I am writing a PDE solver in JAX and getting unsatisfactory performance on GPU. It turns out that one of the problems boils down to the example below: the GPU was 30 times slower than the CPU for the element-wise multiplication of two vectors. Is this because of data transfer between the host (CPU) and the device (GPU)? Let me know if any of you know a better way to implement it.
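The original code block did not survive here, so the snippet below is only a hypothetical reconstruction of the kind of benchmark described: an element-wise product of two small vectors, timed after JIT compilation. The vector length N = 101 is taken from the reply below; the variable names and the timing loop are assumptions.

# Hypothetical reconstruction of the benchmark described above (the original
# code block is missing); N = 101 is taken from the reply below.
import time

import jax
import jax.numpy as jnp

N = 101
a = jnp.linspace(0.0, 1.0, N)
b = jnp.linspace(1.0, 2.0, N)

@jax.jit
def multiply(x, y):
    # Element-wise product of two vectors.
    return x * y

# Warm up once so compilation time is not included in the measurement.
multiply(a, b).block_until_ready()

start = time.perf_counter()
for _ in range(1000):
    multiply(a, b).block_until_ready()
print(f"mean time per call: {(time.perf_counter() - start) / 1000:.2e} s")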
-
I think you have the wrong mental model of a GPU. You shouldn't think of a GPU as a faster CPU. You should think of a GPU as a whole bunch of really slow CPUs with fast communication and shared memory, which can work together to do large computations in parallel, thereby beating a typical CPU that doesn't have access to such parallelism.
When you do a small computation (like your 100 element-wise multiplications), your problem is not really in a regime where you can benefit from the inherent parallelism of the GPU, and so the CPU will outperform it. On larger problems, you should find that the GPU will outperform the CPU: for example, if I change your code from N = 101 to N = 10000001, I…
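To make the scaling point concrete, here is a minimal sketch of that experiment, assuming the same element-wise product as in the reconstruction above: run it once at N = 101 and once at N = 10000001 and compare the per-call time on whatever device JAX selects by default. The actual timings in the original reply are truncated, so no numbers are reproduced here.

# Minimal sketch of the scaling experiment from the reply: time the same
# element-wise product at a small and a large N on the default JAX device.
import time

import jax
import jax.numpy as jnp

@jax.jit
def multiply(x, y):
    return x * y

for N in (101, 10_000_001):
    a = jnp.ones(N, dtype=jnp.float32)
    b = jnp.full(N, 2.0, dtype=jnp.float32)
    multiply(a, b).block_until_ready()  # warm-up compile for this shape
    start = time.perf_counter()
    for _ in range(100):
        multiply(a, b).block_until_ready()
    per_call = (time.perf_counter() - start) / 100
    print(f"N = {N:>10d}: {per_call:.2e} s per call "
          f"on {jax.devices()[0].platform}")

At N = 101 the kernel launch and synchronization overhead dominates, so the CPU backend typically wins; at N = 10000001 the arithmetic is large enough for the GPU's parallelism to pay off.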