-
This is test code comparing the reshape function in NumPy and JAX. It reshapes 2
The JAX code runs on a V100 card with 30 GB of memory, and the NumPy code runs on an Intel(R) Xeon(R) CPU E5-2630 v4 @ 2.20GHz. Why does this happen, and how can I solve it?
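A minimal sketch of the kind of comparison being described (the array size, dtype, and timing style here are assumptions, not the original code, which is cut off above):

```python
import time

import numpy as np
import jax.numpy as jnp

n = 10_000  # assumed size; the original dimensions are not shown above

x_np = np.zeros((n, n), dtype=np.float32)
x_jax = jnp.zeros((n, n), dtype=jnp.float32)

t0 = time.perf_counter()
y_np = x_np.reshape((n * n,))
print(f'NumPy reshape: {time.perf_counter() - t0:.6f} s')

t0 = time.perf_counter()
y_jax = jnp.reshape(x_jax, (n * n,)).block_until_ready()  # wait for the async result
print(f'JAX reshape:   {time.perf_counter() - t0:.6f} s')
```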
-
Possible causes:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=0)  # use donate_argnums to enable in-place update
def modify(a):
    return a.at[0].set(10)

x = jnp.arange(4)
y = x.reshape((2, 2))
modify(x)
print(y)  # unmodified
```

```python
import numpy as np

x = np.arange(4)
y = x.reshape((2, 2))
x[0] = 10
print(y)  # modified
```
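The snippets above show the underlying difference: NumPy's `reshape` returns a view that shares memory with the original array when possible, while JAX arrays are immutable and updates like `.at[...].set(...)` produce new arrays, so the reshaped `y` never observes the change. A quick way to confirm the NumPy side of this (a minimal sketch, not from the original reply):

```python
import numpy as np

x = np.arange(4)
y = x.reshape((2, 2))

# reshape returned a view: both names refer to the same underlying buffer
print(np.shares_memory(x, y))  # True

x[0] = 10
print(y[0, 0])  # 10 -- the view sees the in-place write
```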
-
Hello - this is entirely expected behavior for microbenchmarks of small, individual operations like `reshape`, though it will not generally result in slower execution when the `reshape` is used in more realistic situations (i.e. a sequence of JIT-compiled operations). For information on why this may be expected, see FAQ: Is JAX Faster Than NumPy?. Also, when writing microbenchmarks like this with JAX, you need to be careful about effects of asynchronous dispatch, data transfer, and other issues. For more information on how to write good microbenchmarks, see FAQ: Benchmarking JAX Code.
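As a concrete illustration of the asynchronous-dispatch point, a JAX microbenchmark should warm up the JIT-compiled function first (to exclude compilation time) and call `block_until_ready()` on the result before stopping the timer; otherwise it only measures how long it takes to enqueue the work. A minimal sketch along the lines of the benchmarking FAQ (the function and shape here are illustrative assumptions):

```python
import time

import jax
import jax.numpy as jnp

x = jnp.ones((1000, 1000))

@jax.jit
def f(a):
    # a small sequence of operations, closer to realistic JIT-compiled use
    return jnp.reshape(a, (-1,)).sum()

f(x).block_until_ready()  # warm-up run: triggers and excludes compilation

t0 = time.perf_counter()
f(x).block_until_ready()  # block so the timer measures the actual computation
print(f'elapsed: {time.perf_counter() - t0:.6f} s')
```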
-
@jakevdp Why does buffer donation require the same shape instead of just the same size?

```python
from functools import partial
from contextlib import contextmanager

import jax

@contextmanager
def timer(msg):
    from time import time
    t = time()
    yield
    print(f'{msg}: {time() - t}')

@partial(jax.jit, donate_argnums=0, static_argnums=1)
def reshape(a, ord):
    return jax.numpy.reshape(a, ord)

d = 100
v = jax.numpy.zeros((d * d, d * d))
for i in range(3):
    print(f'---- step {i + 1} -----')
    with timer('reshape time:'):
        v = reshape(v, (d, d, d, d))
        v.block_until_ready()
        v = reshape(v, (d * d, d * d))
        v.block_until_ready()
```

Running this, I got a warning.
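For comparison, donation can be applied when the donated input and the output have the same shape and dtype, since the output can then alias the donated input buffer; on a GPU/TPU backend (where donation is supported) this should run without the donation warning. A minimal sketch (the function is an illustrative assumption, not part of the question above):

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=0)
def scale(a):
    # output shape/dtype match the donated input, so its buffer can be reused
    return a * 2.0

v = jnp.zeros((100, 100))
v = scale(v)            # donated buffer is usable here
v.block_until_ready()
```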