Applying vmap to associative_scan for a linear recurrence does not appear to scale well #9856
I am attempting to vmap an associative_scan for a linear recurrence over a batch of input sequences. The computational time (using a Colab GPU) seems to scale almost linearly with the size of the batch that is vmapped over. I had hoped that on a GPU this would be able to parallelize the independent associative scans and scale much better. Here is a simplified example:

import jax
import jax.numpy as np
from jax import jit, random
import matplotlib.pyplot as plt
import time
def scan_operator(ci, cj):
    """Operator to be used for scan and associative scan, which solves a linear
    recurrence with a diagonal transition matrix."""
    def A_op(Ai, Aj):
        return Ai * Aj
    def b_op(Aj, bi, bj):
        return Aj * bi + bj
    return np.stack([A_op(ci[0], cj[0]), b_op(cj[0], ci[1], cj[1])])

parallel_scan_operator = jax.vmap(scan_operator)

def make_scan_elements(Ab, Bb, u):
    """Initialize elements for associative scan"""
    A_elements = Ab * np.ones((u.shape[0], Ab.shape[0]))
    b_elements = jax.vmap(lambda ui: Bb[:, 0] * ui)(u)
    return np.stack([A_elements, b_elements], axis=1)

def parallel_scan(elems_diag):
    """Perform associative scan"""
    out = jax.lax.associative_scan(parallel_scan_operator, elems_diag)
    return out[:, 1]

@jit
def vmap_par_scan(elems):
    """vmap parallel scan"""
    out = jax.vmap(parallel_scan)(elems)
    return out
#Initialize parameters
N = 64
key = random.PRNGKey(0)
key, skey = random.split(key)
Lambda = np.ones(N)
B = random.normal(skey, shape=(N,1))
#Time associative scan for various batch sizes
bszs = np.arange(0, 600, 100)
bszs = bszs.at[0].set(1)
L = 1000
vmap_times = []
for bsz in bszs:
    key, skey = jax.random.split(key)
    us = jax.random.uniform(skey, (bsz, L))
    elems = jax.vmap(make_scan_elements, in_axes=(None, None, 0))(Lambda, B, us)
    out = vmap_par_scan(elems).block_until_ready()  # warm-up / compilation run
    start_time = time.time()
    out = vmap_par_scan(elems).block_until_ready()
    elapsed_time = time.time() - start_time
    vmap_times.append(elapsed_time)
plt.plot(bszs, vmap_times, 'o', label='Sequence_length={}'.format(L))
plt.xlabel('Batch size')
plt.ylabel('Time (seconds)')
plt.legend()
print('batch_sizes:', bszs)
print('Times (seconds):', vmap_times)

The print statement shows:

As a sanity check, I also looked at how computational time scales when vmapping a regular scan for the same operator:

#Make regular scan ops
def scan_op(carry, x):
    carry = scan_operator(carry, x)
    return carry, carry[1]

def scan_recurrence(xs):
    init = np.stack([np.ones(N), np.zeros(N)])
    _, outs = jax.lax.scan(scan_op, init, xs)
    return outs

@jit
def vmap_scan(xs):
    return jax.vmap(scan_recurrence)(xs)
#Time scan for various batch sizes
vmap_times = []
for bsz in bszs:
    key, skey = jax.random.split(key)
    us = jax.random.uniform(skey, (bsz, L))
    elems = jax.vmap(make_scan_elements, in_axes=(None, None, 0))(Lambda, B, us)
    out = vmap_scan(elems).block_until_ready()
    start_time = time.time()
    out = vmap_scan(elems).block_until_ready()
    elapsed_time = time.time() - start_time
    vmap_times.append(elapsed_time)
plt.plot(bszs, vmap_times, 'o', label='Sequence_length={}'.format(L))
plt.xlabel('Batch size')
plt.ylabel('Time (seconds)')
plt.legend()
print('batch_sizes:', bszs)
print('Times (seconds):', vmap_times)

The print statements show:

To be clear, the scan times are slower than the associative_scan times, as expected, but the scan times stay relatively constant as the batch size grows. I was hoping for similar behavior with associative_scan. Is there anything wrong with how I am attempting to use associative_scan, or anything I could change to improve this? Alternatively, are there opportunities to improve this in the associative_scan implementation?

Here is a link to a Colab notebook with the same code as above: https://colab.research.google.com/drive/1ZGLiSoS7ijs8BrytYQbIp52WZ0PM3ZND?usp=sharing
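One possible variant (a sketch only, not benchmarked in this post, and the names tuple_scan_operator / batched_parallel_scan are just for illustration): carry A and b as a tuple so the combine is purely elementwise, which lets jax.lax.associative_scan run directly over the sequence axis of the batched arrays via its axis argument, without the outer vmap:

def tuple_scan_operator(ci, cj):
    # Elementwise combine, so it broadcasts over both batch and sequence dimensions.
    Ai, bi = ci
    Aj, bj = cj
    return Aj * Ai, Aj * bi + bj

@jit
def batched_parallel_scan(A_elements, b_elements):
    # A_elements, b_elements: shape (batch, L, N); scan along the sequence axis.
    _, xs = jax.lax.associative_scan(tuple_scan_operator, (A_elements, b_elements), axis=1)
    return xs

Whether this formulation is actually faster than vmapping parallel_scan would need to be measured; it assumes the batched elements are split into separate A and b arrays of shape (batch, L, N) rather than stacked as above.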
Replies: 1 comment 1 reply
associative_scan is already parallel within the scan, and it may be running out of GPU capacity since L=1000 (larger than the batch size). I think you could measure the performance of scan and associative_scan with batch size 2000; the scan time is expected to catch up with associative_scan and be slower than it was at batch size 500.
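A minimal sketch of that measurement, assuming the helpers defined in the question above (make_scan_elements, vmap_par_scan, vmap_scan, Lambda, B, L, key) and the batch sizes mentioned here:

import time

for bsz in [500, 2000]:
    key, skey = jax.random.split(key)
    us = jax.random.uniform(skey, (bsz, L))
    elems = jax.vmap(make_scan_elements, in_axes=(None, None, 0))(Lambda, B, us)
    for name, fn in [('associative_scan', vmap_par_scan), ('scan', vmap_scan)]:
        fn(elems).block_until_ready()  # warm-up / compilation run
        start = time.time()
        fn(elems).block_until_ready()
        print(name, 'bsz =', bsz, 'time =', time.time() - start)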