Poor performance of Matrix Decomposition with jax.lax.scan
#19718
-
This is fundamentally a performance question that is relevant to a scientific application I work on, but I've recast it into a simpler, more accessible form for a performance study. Essentially, I need to do a matrix decomposition, but it is a decomposition that isn't currently available out of the box in numpy/scipy (LDL^T). Regardless of the origin, it's led me to think about how to optimize performance-critical code in JAX. Here is the example: LU decomposition. You can get it from `jax.scipy.linalg` in one line.
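For reference, here is a minimal sketch of that one-liner, which is the baseline I compare against below (the random test matrix is just a placeholder, not from my actual benchmark):

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu

# Built-in LU with partial pivoting; returns the permutation matrix P,
# unit-lower-triangular L, and upper-triangular U, with A = P @ L @ U.
A = jax.random.normal(jax.random.PRNGKey(0), (256, 256))
P, L, U = lu(A)
print(jnp.max(jnp.abs(P @ L @ U - A)))  # should be ~float32 roundoff
```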
And here is my custom implementation:

```python
from jax import jit, lax
import jax.numpy as numpy  # the standalone script aliases jax.numpy as numpy

@jit
def LU_partial_pivoting(A):
    """LU_partial_pivoting(A)
    Factor a matrix A into its LU decomposition with partial pivoting.
    """
    N = A.shape[0]
    # Initial list of pivots (aka none, everything in order)
    perms = numpy.arange(0, N)
    # Initial L matrix (all zeros):
    L = numpy.zeros_like(A)
    carry = (
        A,
        L,
        perms
    )
    carry, _ = lax.scan(lu_iteration, carry, perms)
    U = carry[0]
    L = carry[1]
    # Create the permutation matrix:
    P = numpy.zeros_like(A)
    P = P.at[carry[2], perms].add(1.0)
    return P, L, U
```

With the supporting functions here:

```python
def pivot(_A, _k, _kp):
    '''
    Perform a pivot operation on rows if needed
    '''
    temp = _A[_k]
    _A = _A.at[_k].set(_A[_kp])
    _A = _A.at[_kp].set(temp)
    return _A

def no_pivot(_A, _k, _kp):
    '''
    Dummy function for the no-pivot case to make the conditional easy
    '''
    return _A

def form_gauss_vector(_A, _k):
    """
    Form the Gauss vector and update the matrix value as the return
    """
    # Start with the values at every point in the target column
    gauss_vector = _A[:, _k]
    target_row = _A[_k, :]
    a_nn = _A[_k, _k]
    # Scale the Gauss vector
    gauss_vector = gauss_vector / a_nn
    # This should be the biggest value at or below the diagonal
    # in this column, from the pivoting.
    # Set to 0 the entries above and including the diagonal:
    mask = numpy.arange(gauss_vector.shape[0]) > _k
    gauss_vector = numpy.where(mask, gauss_vector, 0.0)
    # Update the matrix by subtracting off the Gauss vector.
    # Use outer product shaping to get the whole thing:
    _A = _A - numpy.outer(gauss_vector, target_row)
    return _A, gauss_vector
pivot = jit(pivot, donate_argnums=0)
no_pivot = jit(no_pivot, donate_argnums=0)
form_gauss_vector = jit(form_gauss_vector, donate_argnums=0)
@jit
def lu_iteration(carry, _k):
    # This is the function we iterate over with scan:
    # First, find the largest entry in A[k:, k] and
    # permute its row to A[k, k].
    # Keep track of this permutation too! It is the pivot matrix.
    # Doing this over the full column means we can use a statically-sized
    # mask and a where operation to dynamically mask based on index.
    # Surely it's faster to only look over the "real" region of interest,
    # but that would trigger a recompile every time, which isn't ideal.
    # Pick out the current matrix from the carry object:
    _A, _L = carry[0], carry[1]
    # Current pivots
    pivots = carry[2]
    # First, decide if we are going to pivot:
    full_region = _A[:, _k]  # Take the whole column
    # Only look at the area at or below the diagonal:
    mask = numpy.arange(full_region.shape[0]) >= _k
    full_region = numpy.where(mask, full_region, 0.0)
    # Do we need to pivot?
    _kp = numpy.abs(full_region).argmax()
    # Switch the pivots:
    temp = pivots[_k]
    pivots = pivots.at[_k].set(pivots[_kp])
    pivots = pivots.at[_kp].set(temp)
    # Apply a pivot if needed, and do both A and L:
    _A = lax.cond(_kp != _k, pivot, no_pivot, _A, _k, _kp)
    _L = lax.cond(_kp != _k, pivot, no_pivot, _L, _k, _kp)
    # Get the update and form the Gauss vector:
    _A, gauss_vector = form_gauss_vector(_A, _k)
    # Update _L:
    _L = _L.at[:, _k].set(gauss_vector)
    # Need to put ones on the diagonal of L:
    _L = _L.at[_k, _k].set(1.0)
    return (_A, _L, pivots), None
lu_iteration = jit(lu_iteration, donate_argnums=0)
```

The full code, if you want to run it, is standalone at this script:

Here is how the performance compares on a small NVIDIA GPU: for small matrices, the runtime appears to be independent of matrix size and my custom implementation is about 15x slower. After about N=128, the runtime starts growing with size as expected, but the gap opens up to 90x slower for the JAX/scan implementation once the problem size is large enough that other overheads are small.

I already know, if you look closely, that I leave a factor of 2 on the table here: because dynamically shaped arrays would trigger re-JIT, I use static shapes for the Gaussian elimination and pivoting steps, even though that means moving a lot of zeros around for large N (a small sketch of that trade-off is at the end of this post). But there doesn't really seem to be a better way to handle an algorithm that needs dynamic shapes other than zero padding. Further, I can't really get around using some sort of loop: the Gaussian elimination steps have to proceed in order, so parallelization is hard for an inherently sequential function.

I'd love to understand the fundamental limitations here and make this closer to the optimized version. I don't pretend that this code would ever compete with vendor-optimized LU decomposition or the decades of CS research into writing those kernels well. But 90x slower seemed like a lot to me! Does anyone on this discussion board spot any obvious performance issues with the code?
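To make concrete what I mean by the static-shape trade-off, here is a minimal, self-contained sketch (not from my benchmark script): the masked, full-length column version compiles once and is reused for every step index, while a slice whose length depends on the step index has to treat that index as static and recompiles for every new value.

```python
from functools import partial

import jax
import jax.numpy as jnp

@jax.jit
def masked_column(A, k):
    # Static shape: always length N, with entries above row k zeroed out.
    col = A[:, k]
    mask = jnp.arange(col.shape[0]) >= k
    return jnp.where(mask, col, 0.0)

@partial(jax.jit, static_argnums=1)
def sliced_column(A, k):
    # Dynamic shape: the output length is N - k, so k must be static and
    # every distinct k compiles a separate program.
    return A[k:, k]

A = jnp.ones((256, 256))
for k in range(4):
    masked_column(A, k)   # traced and compiled once, reused for every k
    sliced_column(A, k)   # recompiles for each new value of k
```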
Replies: 2 comments
-
Hi - thanks for the question. The fundamental problem here is that `scan` has some unavoidable overhead on GPU (see e.g. #16106 (comment) for a general discussion), and I'm not sure what to recommend beyond "don't use scan for this kind of operation on GPU". I suspect the reason for the super-linear scaling with `N` is that you end up operating on intermediate arrays of size `O[N]` within each of the `N` iterations, due to the static shaping requirement.

Perhaps this would be a case where pallas could be useful? (ping @sharadmv, who might have ideas there)
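If it helps to quantify that, here's a rough sketch (using my own stand-in step function, not your LU update) of one way to isolate the per-step scan overhead: run the same jitted work either through `lax.scan` or unrolled into a Python loop at trace time, and compare wall-clock times after warm-up.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

def step(carry, _):
    # Placeholder O(N^2) per-step work; stands in for one elimination step.
    A = carry
    A = A - 0.001 * (A @ A)
    return A, None

@partial(jax.jit, static_argnums=1)
def scanned(A, n_steps):
    A, _ = lax.scan(step, A, None, length=n_steps)
    return A

@partial(jax.jit, static_argnums=1)
def unrolled(A, n_steps):
    # Same work, unrolled into one large XLA program at trace time.
    for _ in range(n_steps):
        A, _ = step(A, None)
    return A

A = jax.random.normal(jax.random.PRNGKey(0), (256, 256))
n = 256

scanned(A, n).block_until_ready()    # warm up / compile
unrolled(A, n).block_until_ready()   # warm up / compile (much longer compile)

t0 = time.perf_counter()
scanned(A, n).block_until_ready()
t1 = time.perf_counter()
unrolled(A, n).block_until_ready()
t2 = time.perf_counter()
print(f"scan:     {t1 - t0:.4f} s")
print(f"unrolled: {t2 - t1:.4f} s")
```

(`lax.scan` also takes an `unroll=` argument, which is a middle ground between these two extremes, at the cost of longer compile times.)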
-
Arg! I had a follow-up written and I lost it when I marked the answer. It boiled down to:
But mostly (1) thank you for your answer.