Fastest way to "transpose" list of jax arrays on gpu #9848
-
Hi all, I want to convert a list of lists of jax arrays into a list of jax arrays where the first two dimensions are swapped. Here are the details:

import time
import numpy as np
import jax
import jax.numpy as jnp

D = 8
batch_size = 32  # not defined in the original post; assumed value (the reply below uses B = 32)
tensor_list = [jnp.ones((2**i,2,2**(i+1))) for i in range(D)] + [jnp.ones((2**(i+1),2,2**i)) for i in reversed(range(D))]
# this is the list I start from (note that each jax array in this list of lists is an ndim=3 tensor, but with different shapes)
batch_list = [tensor_list] * batch_size
# this is the list I want to end up with
final_list = [jnp.ones((batch_size, t.shape[0], t.shape[1], t.shape[2]), dtype=t.dtype) for t in batch_list[0]]

# original approach, which is fast on cpu but slow on gpu, presumably because I copy from gpu (jax array) to cpu (numpy array)
def version_1(batch_list, final_list):
    for i, t_list in enumerate(batch_list):
        for j, t in enumerate(t_list):
            final_list[j][i] = t
            # equivalent to:
            # final_list[j][i] = batch_list[i][j]
    return final_list

# what I would naively/intuitively do, but it's still quite slow on the gpu
@jax.jit
def version_2(batch_list, final_list):
    for i, t_list in enumerate(batch_list):
        for j, t in enumerate(t_list):
            final_list[j] = final_list[j].at[i, :].set(t)
    return final_list

# fastest on the gpu, but I don't think it's pretty
@jax.jit
def version_3(batch_list, final_list):
    for i, t_list in enumerate(batch_list):
        for j, t in enumerate(t_list):
            final_list[j][i] = t
    for j, t in enumerate(final_list):
        final_list[j] = jnp.array(t)
    return final_list

When comparing the individual approaches I get the following times:

final_list = [np.empty((batch_size, t.shape[0], t.shape[1], t.shape[2]), dtype=t.dtype) for t in batch_list[0]]
start_time = time.time()
final_list = version_1(batch_list, final_list)
print(f"Version 1: {time.time()-start_time}")
final_list = [jnp.empty((batch_size, t.shape[0], t.shape[1], t.shape[2]), dtype=t.dtype) for t in batch_list[0]]
final_list1 = version_2(batch_list, final_list) # for compiling
start_time = time.time()
final_list1 = version_2(batch_list, final_list)
final_list1[0].block_until_ready()
print(f"Version 2: {time.time()-start_time}")
final_list = [[jnp.empty((t.shape[0], t.shape[1], t.shape[2]), dtype=t.dtype)]*batch_size for t in batch_list[0]]
final_list1 = version_3(batch_list, final_list) # for compiling
start_time = time.time()
final_list1 = version_3(batch_list, final_list)
final_list1[0].block_until_ready()
print(f"Version 3: {time.time()-start_time}")

returns (times in seconds)

Version 1: 0.005630 (on cpu)
Version 1: 0.039899 (on gpu)
Version 2: 0.008874 (on gpu)
Version 3: 0.002829 (on gpu)

Is there something better/faster I can do than version 3?
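
To illustrate the difference between version 1 and version 2, here is a minimal sketch. It assumes a small illustrative `batch_size` and a single tensor `t`, neither of which comes from the post. It shows that in-place assignment only works on a NumPy buffer (on a GPU backend this presumably pulls `t` to the host on every write), while a JAX buffer has to be updated functionally with `.at[...].set(...)`.

```python
import numpy as np
import jax.numpy as jnp

batch_size = 4  # small illustrative value (assumption)
t = jnp.ones((2, 2, 4))

# NumPy buffer: in-place assignment works, the JAX array is converted to NumPy data.
buf_np = np.empty((batch_size, *t.shape), dtype=t.dtype)
buf_np[0] = t

# JAX buffer: item assignment is rejected because JAX arrays are immutable.
buf_jax = jnp.zeros((batch_size, *t.shape), dtype=t.dtype)
try:
    buf_jax[0] = t
    in_place_ok = True
except TypeError:
    in_place_ok = False

# Functional update: returns a new array and leaves buf_jax untouched.
updated = buf_jax.at[0].set(t)
```

This is why version 1 needs `np.empty` buffers, while version 2 rebinds `final_list[j]` to the array returned by `.at[i, :].set(t)`.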
Replies: 3 comments 6 replies
-
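Could you clarify what exactly you want to do?
Do you want to convert an x: list[list[ndarray]] with

assert (
    len(x) == B
    and all(len(x[i]) == 2 * D for i in range(B))
    and all(x[i][j].ndim == 3 for i in range(B) for j in range(2 * D))
    and all(x[i][k].shape == x[j][k].shape for i in range(B) for j in range(B) for k in range(2 * D))
)

to a y: list[ndarray] with

assert len(y) == 2 * D and all(y[i].ndim == 4 and y[i].shape[0] == B for i in range(2 * D))

And do you need to write into an existing y, or do you just need to construct one?

Given such a specification, I think it is natural to write a function similar to your version 3:
First transpose the nested list in Python without manipulating the arrays, which …

import jax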
import jax.numpy as jnp

B = 32
D = 8
x = [[*(jnp.ones((2**i,2,2**(i+1))) for i in range(D)), *(jnp.ones((2**(i+1),2,2**i)) for i in reversed(range(D)))] for _ in range(B)]

@jax.jit
def transpose(x: list[list[jnp.ndarray]]) -> list[jnp.ndarray]:
    z = [t for t in zip(*x)]  # list[tuple[jnp.ndarray]]
    assert len(z) == 2 * D and len(z[0]) == B
    return [jnp.stack(t) for t in z]

@jax.jit
def transpose_simplified(x: list[list[jnp.ndarray]]) -> list[jnp.ndarray]:
    return [jnp.stack(t) for t in zip(*x)]

In one line, WDYT?

final_list = [[None for _ in range(B)] for _ in range(2 * D)]

@jax.jit
def version_3(batch_list, final_list):
    for i, t_list in enumerate(batch_list):
        for j, t in enumerate(t_list):
            final_list[j][i] = t
    for j, t in enumerate(final_list):
        final_list[j] = jnp.array(t)
    return final_list

And the result is (on a V100 GPU):

from typing import Any, Callable  # needed for the annotation on timer

def timer(f: Callable[[], Any]):
    from time import time
    f()  # warmup (includes compilation)
    t = time()
    for _ in range(5000):
        f()
    print((time() - t) / 5000)  # average seconds per call

timer(lambda: jax.block_until_ready(transpose(x)))
timer(lambda: jax.block_until_ready(transpose_simplified(x)))
timer(lambda: jax.block_until_ready(version_3(x, final_list)))
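
As a quick sanity check of the `zip`/`jnp.stack` approach above, here is a small sketch. The sizes `B = 3`, `D = 2` and the helper `make_tensor_list` are illustrative assumptions, not part of the thread. It confirms that every output array gains a leading batch axis and that `y[j][i]` matches `x[i][j]`.

```python
import jax
import jax.numpy as jnp

B, D = 3, 2  # small illustrative sizes (assumptions, not the values used above)

def make_tensor_list(scale):
    # Hypothetical helper: same shape pattern as in the question,
    # scaled so that each batch entry is distinguishable.
    up = [scale * jnp.ones((2**i, 2, 2**(i + 1))) for i in range(D)]
    down = [scale * jnp.ones((2**(i + 1), 2, 2**i)) for i in reversed(range(D))]
    return up + down

x = [make_tensor_list(float(b)) for b in range(B)]

@jax.jit
def transpose(x):
    # zip(*x) groups the j-th tensor of every batch entry; jnp.stack adds the batch axis.
    return [jnp.stack(t) for t in zip(*x)]

y = transpose(x)
shapes_ok = all(y[j].shape == (B, *x[0][j].shape) for j in range(2 * D))
values_ok = all(
    bool(jnp.array_equal(y[j][i], x[i][j])) for i in range(B) for j in range(2 * D)
)
```

The nested-list transpose happens at trace time in plain Python, so the compiled function should only contain the stacking operations.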
-
Thanks so much YouJiacheng! I have a quick follow-up question: imagine I now have two of these lists (x1, x2) which I want to transpose separately. They have equal lengths and shapes:

assert (
    len(x1) == len(x2) and len(x1[0]) == len(x2[0])
    and all(x1[i][k].shape == x2[i][k].shape for i in range(len(x1)) for k in range(len(x1[0])))
)

Can I do better than calling

Thanks again!
-
@frmetz

@jax.jit
def stack(x: list):
    # jax.tree_map is deprecated in recent JAX releases; jax.tree_util.tree_map is the equivalent call
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *x)

Thus, for your follow-up question, if you can easily

y1, y2 = stack(list(zip(x1, x2)))

BTW, I think you should change your data pattern: directly produce a tree of arrays with a batch axis, since compiling a function with a large list input will be painfully slow (batch size 256 takes 11 seconds to compile on my device). For example:

def generate_one_data(z: float):
    return [z * jnp.ones((2 ** i, 2)) for i in range(D)]

generate_batch_data = jax.vmap(generate_one_data)
xs = generate_batch_data(jnp.ones((B,)))
assert len(xs) == D and all(xs[i].shape == (B, 2 ** i, 2) for i in range(D))

def generate_one_data_with_dependency(carry: float, z: float):
    return carry + z, [carry * jnp.ones((2 ** i, 2)) for i in range(D)]

def generate_batch_data_with_dependency(init: float, zs):
    # scan returns (final_carry, stacked_outputs); only the stacked outputs are needed
    return jax.lax.scan(generate_one_data_with_dependency, init, zs)[1]

ys = generate_batch_data_with_dependency(0.0, jnp.ones((B,)))
assert len(ys) == D and all(ys[i].shape == (B, 2 ** i, 2) for i in range(D))
assert all(y[i][0][0] == i for y in ys for i in range(B))
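
A small sketch of the `stack` idea applied to two lists at once. The sizes `B = 3`, `D = 2` and the helper `make_list` are illustrative assumptions. It also counts the arrays passed into the jitted function versus the arrays in the batched result; the large number of separate inputs is presumably what drives the compile time mentioned above.

```python
import jax
import jax.numpy as jnp

B, D = 3, 2  # small illustrative sizes (assumptions)

def make_list(scale):
    # Hypothetical helper producing one un-batched data item.
    return [scale * jnp.ones((2**i, 2)) for i in range(D)]

x1 = [make_list(1.0) for _ in range(B)]
x2 = [make_list(2.0) for _ in range(B)]

@jax.jit
def stack(x):
    # tree_map applies the lambda leaf-wise across the B trees in x.
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *x)

pairs = list(zip(x1, x2))  # B trees, each of the form (list, list)
y1, y2 = stack(pairs)

n_inputs_unbatched = len(jax.tree_util.tree_leaves(pairs))   # B * 2 * D arrays
n_arrays_batched = len(jax.tree_util.tree_leaves((y1, y2)))  # 2 * D arrays
```

Each leaf of the input tree is a separate argument to the compiled function, so the input count grows linearly with the batch size, while the batched representation stays at `2 * D` arrays.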