Updating the element of the DeviceArray #12692
Answered by soraros
yiminghwang asked this question in Q&A
Hi there, I am trying to update the elements of a DeviceArray, but it runs much slower than a NumPy ndarray. Is there any way to improve the performance?

    import time

    import jax.numpy as jnp
    import numpy as np

    m, n = 100, 100

    def func1():
        A = jnp.zeros((m, n))
        start = time.time()
        for i in range(m):
            for j in range(n):
                A.at[i, j].set(i * j)
        end = time.time()
        return end - start

    def func2():
        B = np.zeros((m, n))
        start = time.time()
        for i in range(m):
            for j in range(n):
                B[i, j] = i * j
        end = time.time()
        return end - start

The above is just a simple test.
Answered by soraros on Oct 7, 2022
Answer selected by jakevdp

Maybe you could share more information about your real™ use case? There are some (hopefully useful) general suggestions I can give now:

.at[...].set(...) is functional, meaning you should write A = A.at[i, j].set(i * j)

    from functools import partial

    import jax.numpy as jnp
    from jax import Array, jit

    @partial(jit, static_argnums=(0,))
    def f(shape: tuple[int, int]) -> Array:
        # I[i, j] == i and J[i, j] == j, so one update covers every element.
        I, J = jnp.indices(shape)
        A = jnp.zeros(shape).at[I, J].set(I * J)
        return A

or even simpler for your toy example

    A = jnp.fromfunction(jnp.multiply, (m, n))
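
The point about .at[...].set(...) being functional can be checked directly. The following is a minimal sketch, not code from the thread: the names unchanged, updated, rows, cols, and table are illustrative, and the broadcasting step at the end is an alternative way to fill the array without a Python loop that I am assuming fits the toy example, alongside the approaches suggested in the answer.

```python
import jax.numpy as jnp

m, n = 100, 100

# .at[...].set(...) returns a new array; the original is left unchanged.
A = jnp.zeros((m, n))
A.at[2, 3].set(6.0)          # result discarded, so A is still all zeros
unchanged = float(A[2, 3])

A = A.at[2, 3].set(6.0)      # rebinding the name keeps the update
updated = float(A[2, 3])

# Fill the whole table with A[i, j] = i * j in one vectorized step.
rows = jnp.arange(m)[:, None]    # shape (m, 1)
cols = jnp.arange(n)[None, :]    # shape (1, n)
table = rows * cols              # broadcasts to shape (m, n)
```

This also suggests why func1 in the question is not a fair comparison: each call in the loop builds a new array that is thrown away, so A is never modified, whereas the NumPy loop updates B in place.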