Hello, and many thanks for this incredible framework. I've been working on rebuilding the core of a scientific software project to use Jax instead of PyTorch, in part to take advantage of the (IMO) more intuitive implicit differentiation. However, I'm having some difficulty figuring out the best way to vectorise a sparse-sparse matrix multiplication (e.g., bcoo_dot_general). To provide some context, I have 2 BCOO matrices that I'd like to multiply, each of which has 2 sparse dimensions and might have one or more dense dimensions. One example of a situation where this kind of matrix structure arises is batches of graphs.
The problem I'm having is that simply applying

vmap(partial(
    jax.experimental.sparse.bcoo_dot_general,
    dimension_numbers=(((-3,), (-3,)), ((), ()))
), in_axes=(-1, -1))

to a sparse matrix results in an error. Since I wasn't able to get it working with vmap, I put together a naive implementation of my own:

import jax
import numpy as np
import jax.numpy as jnp
from jax.experimental.sparse import BCOO
def random_sparse(key, shape, density=0.1):
    """
    Generate a random sparse matrix.
    """
    n = jnp.prod(jnp.array(shape))
    nse = int(density * n)
    k1, k2 = jax.random.split(key)
    indices = jax.random.choice(k1, a=n, shape=(nse,), replace=False)
    indices = jnp.stack(jnp.unravel_index(indices, shape), axis=-1)
    data = jax.random.normal(k2, (nse,))
    return BCOO((data, indices), shape=shape).sum_duplicates()

def to_batch(matrices):
    """
    Convert a sequence of sparse matrices to a batch of matrices using the
    batch-final, common-index COO format.

    .. note::
        This function is not intended to be compatible with JIT compilation.
    """
    batch_size = len(matrices)
    total_nse = sum([m.data.shape[0] for m in matrices])
    remaining_shape = matrices[0].data.shape[1:]
    indices = jnp.concatenate([m.indices for m in matrices], axis=0)
    data = jnp.zeros((total_nse, *remaining_shape, batch_size))
    start = 0
    for i, matrix in enumerate(matrices):
        end = start + matrix.data.shape[0]
        data = data.at[start:end, ..., i].set(matrix.data)
        start = end
    return BCOO(
        (data, indices),
        shape=(*matrices[0].shape, batch_size)
    ).sum_duplicates()
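
# For example (hypothetical shapes): batching two (10, 8) matrices this way
# yields a BCOO of shape (10, 8, 2) with n_sparse=2 and n_dense=1, whose data
# array carries one trailing column per batch element.
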
def _get_dense_dim_mm(lhs, rhs):
    """
    Get the dense dimensions of the output of the matrix multiplication.
    """
    lhs_dims = lhs.data.shape[1:]
    rhs_dims = rhs.data.shape[1:]
    # we don't check for broadcastability here
    return [max(l, r) for l, r in zip(lhs_dims, rhs_dims)]

def spspmm(lhs, rhs, inner_dims=(0, 0), outer_dims=(1, 1)):
    """
    Sparse-sparse matrix multiplication with vectorisation over dense
    dimensions.
    """
    # only support 2D sparse for now
    assert lhs.n_sparse == rhs.n_sparse == 2
    dense_dim_out = _get_dense_dim_mm(lhs, rhs)
    out_shape = (
        lhs.shape[outer_dims[0]],
        rhs.shape[outer_dims[1]],
        *dense_dim_out
    )
    out_nse = lhs.nse * rhs.nse  # memory use scales as product of NSEs
    # broadcast every stored element of lhs against every stored element of rhs
    lhs_data = lhs.data[None, ...]
    rhs_data = rhs.data[:, None, ...]
    lhs_contract_dim, rhs_contract_dim = inner_dims
    lhs_contract_idx = lhs.indices[:, lhs_contract_dim][None, :]
    rhs_contract_idx = rhs.indices[:, rhs_contract_dim][:, None]
    # a pair of stored elements contributes to the output only if their
    # contracting indices coincide
    out_nonzero = (lhs_contract_idx == rhs_contract_idx)
    extra_idx = [None] * len(dense_dim_out)
    out_nonzero = out_nonzero[tuple([...] + extra_idx)]
    out_data = jnp.where(out_nonzero, lhs_data * rhs_data, 0.)
    # output index of each pair: (lhs outer index, rhs outer index)
    lhs_indices = jnp.ones_like(lhs.indices).at[:, -2].set(
        lhs.indices[:, outer_dims[0]])
    rhs_indices = jnp.ones_like(rhs.indices).at[:, -1].set(
        rhs.indices[:, outer_dims[1]])
    out_indices = (lhs_indices[None, ...] * rhs_indices[:, None, ...])
    out_indices = out_indices.reshape(out_nse, -1)
    out_data = out_data.reshape(out_nse, *dense_dim_out)
    return BCOO((out_data, out_indices), shape=out_shape)

shape_A = (10, 8)
shape_B = (10, 18)
batch_size = 5
density = 0.1
batch_A = [
    random_sparse(
        jax.random.PRNGKey(np.random.randint(2 ** 32)),
        shape_A,
        density=density)
    for _ in range(batch_size)]
batch_B = [
    random_sparse(
        jax.random.PRNGKey(np.random.randint(2 ** 32)),
        shape_B,
        density=density)
    for _ in range(batch_size)]
A = to_batch(batch_A)
B = to_batch(batch_B)
out = spspmm(A, B).todense()
ref = np.stack([(a.T @ b).todense() for a, b in zip(batch_A, batch_B)], axis=-1)
assert out.shape == (shape_A[1], shape_B[1], batch_size)
assert np.allclose(out, ref)

But this naive implementation suffers from a critical flaw: the memory use scales as the product of the numbers of stored elements (nse) of the two factors. I think that what I'd ideally like is a way to simultaneously (i) "promise" the compiler exactly what indices should appear in the output, (ii) vectorise the matrix multiplication over the dense dimensions, and (iii) evaluate it only at the specified indices so as to save on memory (and do this all in a differentiable way). (I suppose I'm really asking more than one question here — sorry!) Is there a way to do this, or something like it, with Jax? I saw there was a … Any help or leads would be appreciated — thanks in advance!
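To make (i) and (iii) a bit more concrete, here is a rough sketch of the kind of access pattern I have in mind, ignoring input sparsity for the moment (sampled_matmul and its arguments are hypothetical names, and the operands are assumed dense purely for illustration):

import jax.numpy as jnp

def sampled_matmul(lhs, rhs, out_rows, out_cols):
    """
    Contract lhs and rhs over their leading dimension (as in lhs.T @ rhs for
    2D operands), but evaluate only at pre-specified output coordinates.
    lhs: (k, m, *dense); rhs: (k, n, *dense); out_rows, out_cols: (out_nse,).
    Returns the requested entries with shape (out_nse, *dense).
    """
    # gather only the slices that each requested output entry touches ...
    lhs_gathered = lhs[:, out_rows]   # (k, out_nse, *dense)
    rhs_gathered = rhs[:, out_cols]   # (k, out_nse, *dense)
    # ... and contract over the shared dimension, so memory scales with
    # k * out_nse rather than with the product of the input NSEs
    return jnp.sum(lhs_gathered * rhs_gathered, axis=0)

Something with these semantics, but taking BCOO operands, is essentially what I'm after.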
From your example, it's hard for me to understand the exact goal of your question (e.g. "2 BCOO matrices... each of which have 2 sparse dimensions and might have one or more dense dimensions"). Are you saying the matrices are structured something like this?
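For instance, something along these lines (a minimal sketch with made-up shapes, assuming a single trailing dense dimension):

import jax.numpy as jnp
from jax.experimental import sparse

# made-up example: a 10x8 matrix with a trailing length-5 dense dimension,
# i.e. n_sparse=2, n_dense=1
x = jnp.zeros((10, 8, 5)).at[0, 3].set(1.).at[2, 7].set(2.)
mat = sparse.BCOO.fromdense(x, n_dense=1)
print(mat.n_sparse, mat.n_dense)           # 2 1
print(mat.data.shape, mat.indices.shape)   # (2, 5) (2, 2)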
And then you're hoping to apply vmap over the last dimension? If so, then I wonder if you could restructure your problem to use n_batch=1 rather than n_dense=1: the reason for the existence of batch dimensions is to enable vmapping over sparse matrix operations (rough sketch below).

Side note: to this point, we haven't done much with the (trailing) dense dimensions, because they haven't seemed very useful in any practical application we've so far come across.
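To illustrate the n_batch=1 suggestion above (made-up shapes again; the leading axis becomes a true batch dimension rather than a trailing dense one):

import jax.numpy as jnp
from jax.experimental import sparse

# the same kind of data, but stored batch-leading with n_batch=1:
# data and indices both carry the batch axis up front
x = jnp.zeros((5, 10, 8)).at[:, 0, 3].set(1.).at[:, 2, 7].set(2.)
mat = sparse.BCOO.fromdense(x, n_batch=1)
print(mat.n_batch, mat.n_sparse, mat.n_dense)   # 1 2 0
print(mat.data.shape, mat.indices.shape)        # (5, 2) (5, 2, 2)

This batch-leading layout is what the batching rules for sparse operations are built around, which is what makes mapping over that axis possible in a way that a trailing dense dimension doesn't.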