Why does jitted GPU code require 100% of a CPU core? #16348
-
JIT tracing and compilation happen on the host CPU, while execution of the compiled program happens on the GPU. Since you mentioned "looping a jax.numpy eigenvalue calculation", I suspect your computation may be dominated by compilation (on CPU) rather than by execution (on GPU), in which case it would make sense for the computation to use the CPU for most of the time. It's hard to say more without seeing the code you're running!
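One way to check whether compilation or execution dominates is to time the two separately. The sketch below assumes a recent JAX version that provides the ahead-of-time `jax.jit(f).lower(...).compile()` API; `eig_loop`, the 64x64 matrix and the 10 iterations are hypothetical stand-ins, not the poster's code. Tracing and compilation run only on the host, so `t_compile` is CPU time, while `t_run` covers dispatching and executing the compiled program on the default device (the GPU, if one is available).

```python
import time

import jax
import jax.numpy as jnp


def eig_loop(x):
    # Hypothetical stand-in workload (not the poster's code): build a
    # symmetric matrix each iteration and accumulate its smallest eigenvalue.
    def body(i, acc):
        m = x + x.T + i  # adding the loop index makes each iteration differ
        return acc + jnp.linalg.eigh(m)[0][0]  # eigenvalues come back ascending

    return jax.lax.fori_loop(0, 10, body, jnp.float32(0.0))


x = jnp.ones((64, 64), dtype=jnp.float32)

# Tracing + XLA compilation: this step runs only on the host CPU.
t0 = time.perf_counter()
compiled = jax.jit(eig_loop).lower(x).compile()
t_compile = time.perf_counter() - t0

# Dispatch + execution of the already-compiled program on the default device.
t0 = time.perf_counter()
result = compiled(x).block_until_ready()
t_run = time.perf_counter() - t0

print(f"compile: {t_compile:.3f}s  run: {t_run:.3f}s")
```

If `t_compile` dwarfs `t_run`, the CPU load is mostly compilation; if the run phase dominates while a core stays busy, compilation alone is unlikely to explain it.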
-
I mean that, while the computation is running on the GPU, it also keeps a CPU core 100% occupied. I'm not sure why the CPU would be involved at all after tracing, but I'm also just getting familiar with JAX. If you wouldn't mind taking a look, here is example code:

```python
import jax
from jax import jit
import jax.numpy as jnp
from jax.lax import fori_loop
from jax import random

with jax.default_device(jax.devices('gpu')[0]):
    @jit
    def test_linalg():
        n = 1000
        def body_fun(i, seed):
            key = random.PRNGKey(seed)
            # generate random symmetric matrix
            el = random.normal(key, (n*(n-1)//2,))
            X = jnp.zeros((n,n))
            X = X.at[jnp.triu_indices_from(X, 1)].set(el)
            X = X + X.T
            # use first eigenvalue to generate new key
            output = jnp.linalg.eigh(X)
            return (output[0][0] * 100).astype(int)
        return fori_loop(0, 1000, body_fun, 0)

    print(test_linalg())
```

Thanks.
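JAX also dispatches work asynchronously: a jitted call returns to Python once the work is enqueued, and the host only waits for the device when the result is actually needed, e.g. in `block_until_ready()` or `print`. The sketch below reuses the structure of the poster's `test_linalg`, but turns `n` and the iteration count into hypothetical static parameters (shrunk to 100 and 20 so it also runs quickly on CPU). It times a first call (trace, compile, run) against a second, cached call, and splits the second call into dispatch time and total time.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp
from jax import random
from jax.lax import fori_loop


@partial(jax.jit, static_argnums=(0, 1))
def linalg_loop(n, iters):
    # Same loop as the poster's test_linalg; n and iters are hypothetical
    # static parameters added here so the sizes can be reduced.
    def body_fun(i, seed):
        key = random.PRNGKey(seed)
        el = random.normal(key, (n * (n - 1) // 2,))
        X = jnp.zeros((n, n))
        X = X.at[jnp.triu_indices(n, 1)].set(el)  # fill strict upper triangle
        X = X + X.T  # symmetrize
        w = jnp.linalg.eigh(X)[0]  # eigenvalues, ascending
        return (w[0] * 100).astype(jnp.int32)  # next seed

    return fori_loop(0, iters, body_fun, jnp.int32(0))


def timed_call(n, iters):
    t0 = time.perf_counter()
    out = linalg_loop(n, iters)  # returns once the work has been dispatched
    t_dispatch = time.perf_counter() - t0
    out = out.block_until_ready()  # the host waits for the device here
    t_total = time.perf_counter() - t0
    return out, t_dispatch, t_total


first = timed_call(100, 20)  # includes tracing + compilation on the host
second = timed_call(100, 20)  # compilation cache hit: dispatch + execution only
print(f"first call:  {first[2]:.3f}s total")
print(f"second call: {second[1]:.4f}s to dispatch, {second[2]:.3f}s total")
```

On a GPU, a small dispatch time relative to the total for the second call suggests the Python thread spent most of that call waiting on the device rather than doing host-side work; whether that wait itself keeps a core busy depends on the runtime and is not something this sketch can show.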
-
Two possibilities:
-
I'm finding that my jitted code still occupies 100% of a CPU core in the invoking process, even though nothing should be transferred back to the invoking workspace until the end of the computation. I've tested this on a simple example: a loop of jax.numpy eigenvalue calculations.
Is this expected, or am I doing something strange?
And if so, why would this be the case?
Thanks for any hints.