Questions about pmap and xla_force_host_platform_device_count #14762
Unanswered
dmmdanielg asked this question in Q&A
Hi,
For context, I am currently evaluating JAX for signal-processing/numerical-computing applications.
1. How does pmap interact with xla_force_host_platform_device_count on CPUs and GPUs?
Code (if you want to verify the results):
https://github.com/dmmdanielg/ImageProcessingColab/blob/main/ImageProcessingGitHub.ipynb
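For concreteness, this is roughly the pattern for forcing multiple CPU devices (a simplified sketch, not the exact notebook code; the flag has to be set before JAX is imported):

```python
import os

# Must be set before JAX is imported, otherwise the flag is ignored.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=10"

import jax

print(jax.local_device_count())  # 10 "virtual" CPU devices on a CPU-only machine
print(jax.devices())
```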
I don't know if this is intended behavior, but it seems like a bug or limitation that vmap does not use all available resources. The Performance tab of the Windows Task Manager shows almost 100% CPU usage across all threads with pmap, while vmap only shows about 30% overall, with one thread at 100% (for some reason I'm having trouble uploading the Task Manager screenshot, but I can share it later if needed).
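Roughly, the vmap and pmap variants follow this pattern (a simplified sketch; process_image here is a stand-in for the actual image-processing kernel in the notebook, and the batch shape is a placeholder):

```python
import jax
import jax.numpy as jnp

def process_image(img):
    # Stand-in for the real per-image processing in the notebook.
    return jnp.abs(jnp.fft.ifft2(jnp.fft.fft2(img)))

process_image_jax_vmap = jax.vmap(process_image)               # batch on one device
process_image_jax_vmap_jit = jax.jit(jax.vmap(process_image))  # same, but compiled

# pmap maps the leading axis over devices; its size must not exceed the number of
# local devices, so the batch is reshaped to (n_devices, per_device_batch, H, W).
n_dev = jax.local_device_count()
batch = jnp.ones((40, 256, 256))
sharded = batch.reshape(n_dev, -1, *batch.shape[1:])
process_image_jax_pmap = jax.pmap(jax.vmap(process_image))
out = process_image_jax_pmap(sharded)   # shape (n_dev, 40 // n_dev, 256, 256)
```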
Results (collected with Jupyter Notebook/Google Colab after a fresh "Restart and Run All")
process_image_batch_jax      # manual vectorization of the jitted function
process_image_batch          # manual vectorization with NumPy
process_image_jax_vmap_jit   # vmap with jit
process_image_jax_vmap       # vmap only
process_image_jax_pmap       # pmap
Google Colab Results (xla_force_host_platform_device_count=10)
957 ms ± 274 ms per loop (mean ± std. dev. of 5 runs, 7 loops each)
1.12 s ± 49.4 ms per loop (mean ± std. dev. of 3 runs, 7 loops each)
670 ms ± 41.7 ms per loop (mean ± std. dev. of 5 runs, 7 loops each)
707 ms ± 56.6 ms per loop (mean ± std. dev. of 5 runs, 7 loops each)
544 ms ± 103 ms per loop (mean ± std. dev. of 5 runs, 7 loops each)
Google Colab Results (xla_force_host_platform_device_count=2)
964 ms ± 350 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)
1.12 s ± 95.2 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
671 ms ± 34.9 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)
646 ms ± 47.6 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)
543 ms ± 51 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)
Local VM (xla_force_host_platform_device_count=10)
264 ms ± 12.2 ms per loop (mean ± std. dev. of 5 runs, 7 loops each)
430 ms ± 10.3 ms per loop (mean ± std. dev. of 3 runs, 7 loops each)
252 ms ± 16 ms per loop (mean ± std. dev. of 5 runs, 7 loops each)
250 ms ± 15.4 ms per loop (mean ± std. dev. of 5 runs, 7 loops each)
84.6 ms ± 13.3 ms per loop (mean ± std. dev. of 5 runs, 7 loops each)
Local VM (xla_force_host_platform_device_count=2)
263 ms ± 12.9 ms per loop (mean ± std. dev. of 5 runs, 7 loops each)
453 ms ± 26.2 ms per loop (mean ± std. dev. of 3 runs, 7 loops each)
252 ms ± 19.4 ms per loop (mean ± std. dev. of 5 runs, 7 loops each)
253 ms ± 16.7 ms per loop (mean ± std. dev. of 5 runs, 7 loops each)
153 ms ± 8.51 ms per loop (mean ± std. dev. of 5 runs, 7 loops each)
We can see that on the local VM the performance improves significantly and scales with xla_force_host_platform_device_count, since the VM has 10 threads while Google Colab only has 2.
Let me know if you need clarification. My questions may not be clear, but I noticed some unusual behavior with pmap and xla_force_host_platform_device_count, and I'm wondering whether any of it is expected.
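As a note on methodology: the "per loop" numbers above are %timeit output. Since JAX dispatches work asynchronously, timing a single call explicitly would need block_until_ready(), roughly like this sketch (the function and array size are placeholders):

```python
import time
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: jnp.fft.fft2(x).real)
x = jnp.ones((2048, 2048))

f(x).block_until_ready()      # warm-up: compile outside the timed region
start = time.perf_counter()
f(x).block_until_ready()      # block so async dispatch doesn't skew the result
print(f"{time.perf_counter() - start:.4f} s")
```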
Replies: 1 comment
- I would assume that normally JAX on CPU is not multi-processed, so only one CPU core is used at a time. While with ... The reason why ...