From your example, it's hard for me to understand the exact goal of your question (e.g. "2 BCOO matrices... each of which have 2 sparse dimensions and might have one or more dense dimensions"). Are you saying the matrices are structured something like this?

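import numpy as np
from jax.experimental import sparse
# shape (3, 4, 5): two sparse dimensions followed by one dense dimension
mat_dense = np.random.rand(3, 4, 5)  # np.random.rand takes the dimensions as separate arguments, not a tuple
mat = sparse.BCOO.fromdense(mat_dense, n_dense=1)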

And then you're hoping to apply vmap over the last dimension?
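A minimal sketch of what that layout looks like, assuming a recent JAX release and using a seeded `np.random.default_rng` in place of `np.random.rand` so the shapes are reproducible. With `n_dense=1`, each stored element is a length-5 dense vector inside `data`, while `indices` only address the two sparse axes, so the last axis is not a separate index dimension in the sparse storage.

```python
import numpy as np
from jax.experimental import sparse

rng = np.random.default_rng(0)  # seeded so the shapes below are reproducible
mat_dense = rng.random((3, 4, 5))

# n_dense=1: axes 0 and 1 are sparse, axis 2 is a dense trailing block.
mat = sparse.BCOO.fromdense(mat_dense, n_dense=1)
print(mat.n_batch, mat.n_sparse, mat.n_dense)  # 0 2 1

# Every (row, col) position is nonzero here, so 3 * 4 = 12 elements are stored.
# Each one carries a full length-5 vector; a single (row, col) index is shared
# by all 5 entries of that vector.
print(mat.data.shape)     # (12, 5)
print(mat.indices.shape)  # (12, 2)
```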

If so, I wonder whether you could restructure your problem to use n_batch=1 rather than n_dense=1: batch dimensions exist precisely to enable vmapping over sparse matrix operations. When you say "Batches of graphs …
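A sketch of the `n_batch=1` alternative, assuming the leading axis indexes independent matrices (for example, one graph per entry). The names `vec` and `matvec` are hypothetical and only for illustration, and zeroing entries below 0.5 just makes the arrays sparse. It shows two routes that should agree: a BCOO whose leading axis is a batch dimension used directly in a matrix-vector product, and `jax.vmap` over a function that builds an unbatched BCOO per matrix. Under `vmap`, `nse` has to be passed explicitly because it cannot be computed from traced values.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import sparse

rng = np.random.default_rng(0)
# Three 4x5 matrices stacked along a leading axis; drop small entries so they are sparse.
mat_dense = rng.random((3, 4, 5))
mat_dense[mat_dense < 0.5] = 0.0
vec = jnp.arange(5.0)  # hypothetical right-hand side

# Route 1: leading axis as a batch dimension (1 batch, 2 sparse, 0 dense dims).
mat = sparse.BCOO.fromdense(mat_dense, n_batch=1)
print(mat.n_batch, mat.n_sparse, mat.n_dense)  # 1 2 0
out_batched = mat @ vec  # sparse mat-vec applied to every matrix in the batch

# Route 2: vmap over a function that works on a single 4x5 matrix.
def matvec(m):
    # nse must be static under vmap; 4 * 5 = 20 is a safe upper bound per matrix.
    return sparse.BCOO.fromdense(m, nse=20) @ vec

out_vmapped = jax.vmap(matvec)(jnp.asarray(mat_dense))
```

Both results should match the dense product `mat_dense @ vec`. Which route fits better likely depends on whether the batched matrix is built once up front or constructed inside code that is already being mapped.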

Answer selected by rciric
Category: Q&A
Labels: None yet
Participants: 2