I'll start by acknowledging that this is part of … I'll also acknowledge that this was my first attempt at playing with … Here's what I'm seeing that came as a surprise to me:

…

I was expecting the sparse elementwise square (`s_mult(m, m)`) to return the same result as the dense version. Here's the code to reproduce:

```python
from jax.experimental import sparse
import jax.numpy as jnp

a = jnp.array([[0., 1., 0.],
               [3., 0., 0.],
               [0., 0., 4.]])
u = jnp.array([[0., 4.], [0., 2.], [0., 0.]])
v = jnp.array([[2., 1.], [2., 4.], [0., 5.]])

# Elementwise product; sparsify() lets the same function accept BCOO inputs.
def s_mult_proto(x, y):
    return jnp.multiply(x, y)

s_mult = sparse.sparsify(s_mult_proto)

# Dense computation
m = a - u @ v.T
print(m)
print(s_mult_proto(m, m))

# The same computation with BCOO sparse matrices
m = (sparse.BCOO.fromdense(a)
     - (sparse.BCOO.fromdense(u) @ sparse.BCOO.fromdense(v).T))
print(m.todense())
print(s_mult(m, m).todense())  # expected to match the dense result above
```
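For reference, the values that both the dense and the sparse elementwise square should agree on can be worked out without JAX. The plain-Python sketch below recomputes `m = a - u @ v.T` and `m * m` from the same inputs using lists of lists; the helper `matmul_transpose` is a hypothetical name introduced only for this illustration.

```python
# Plain-Python reference for the dense computation in the snippet above:
# m = a - u @ v.T, followed by the elementwise square m * m.
a = [[0., 1., 0.],
     [3., 0., 0.],
     [0., 0., 4.]]
u = [[0., 4.], [0., 2.], [0., 0.]]
v = [[2., 1.], [2., 4.], [0., 5.]]


def matmul_transpose(x, y):
    """Return x @ y.T for matrices stored as lists of lists (hypothetical helper)."""
    return [[sum(xi * yi for xi, yi in zip(row_x, row_y)) for row_y in y]
            for row_x in x]


uvt = matmul_transpose(u, v)
m = [[a[i][j] - uvt[i][j] for j in range(3)] for i in range(3)]
m_squared = [[x * x for x in row] for row in m]  # elementwise, not a matrix product

print(m)          # [[-4.0, -15.0, -20.0], [1.0, -8.0, -10.0], [0.0, 0.0, 4.0]]
print(m_squared)  # [[16.0, 225.0, 400.0], [1.0, 64.0, 100.0], [0.0, 0.0, 16.0]]
```

Any correct sparse path should reproduce `m_squared` once converted back to dense with `.todense()`.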
Thanks for the report! This looks like a bug in the newly added sparse lowering of `lax.mul`. For the moment, I think you can work around it by doing

…

after which `s_mult(m, m).todense()` will display the correct result. I've opened #8889 to track the bug. Thanks!

Hm, I may not understand the workaround:

…

Yields:

…

One follow-up I have: I'm unsure how I'd integrate the suggested trick into a function that I want to support both sparse and dense operations:

…

Which I want to use with:

…

Apologies if I'm being... dense. (😉)
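For the follow-up about a single function that accepts both sparse and dense inputs, one pattern is to write the body once with `jax.numpy`, wrap it with `sparse.sparsify`, and normalize any `BCOO` arguments at the boundary with an `isinstance` check. The thread's actual workaround snippet did not survive, so the normalization used here (`BCOO.sum_duplicates()`) is an assumption standing in for it; swap in whatever step the workaround really performs. The names `mult`, `_mult_impl`, and `_canonicalize` are hypothetical, and this sketch runs eagerly (not under `jit`).

```python
import jax.numpy as jnp
from jax.experimental import sparse


def _canonicalize(x):
    """Hypothetical normalization applied only to sparse inputs.

    Assumption: stands in for the lost workaround by merging duplicate
    indices of a BCOO matrix; dense arrays pass through unchanged.
    """
    if isinstance(x, sparse.BCOO):
        return x.sum_duplicates()
    return x


@sparse.sparsify
def _mult_impl(x, y):
    # Written once with jax.numpy; sparsify lets it accept BCOO or dense arrays.
    return jnp.multiply(x, y)


def mult(x, y):
    """Elementwise product that accepts dense arrays or BCOO matrices."""
    return _mult_impl(_canonicalize(x), _canonicalize(y))


dense = jnp.array([[0., 1., 0.],
                   [3., 0., 0.],
                   [0., 0., 4.]])
sp = sparse.BCOO.fromdense(dense)

print(mult(dense, dense))      # dense inputs give a dense array
print(mult(sp, sp).todense())  # BCOO inputs give a BCOO result
```

Keeping the sparse-specific step in one small helper means the arithmetic itself is never duplicated between the sparse and dense code paths.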