Sparse extraction - memory usage issue #16291
-
I am having a bit of trouble with a …

```python
K = jsparse.BCOO((sK, ijK), shape=(ndof, ndof))
K_free = K[free,:][:,free]
```

It works, but when tracking memory usage, the … Is there a better, more memory-efficient way of extracting the values I need?

Here is the function whose gradient I am computing with respect to `x`:

```python
import numpy as np
import jax.numpy as jnp
from jax import value_and_grad
from jax.experimental import sparse as jsparse

# lk() and klusolve() are defined elsewhere (not shown).
def obj_compliance(x, nelx, nely, free, penal, Emax, Emin, f, H, Hs, ijK):
    ndof = 2*(nelx+1)*(nely+1)
    u = jnp.zeros(ndof)
    xPhys = H.T @ (x.T/Hs)
    KE = lk()
    # one stiffness value per entry of ijK
    sK = ((KE.flatten()[np.newaxis]).T*(Emin+(xPhys)**penal*(Emax-Emin))).flatten(order='F')
    K = jsparse.BCOO((sK, ijK), shape=(ndof, ndof))
    K_free = K[free,:][:,free]  # the extraction this question is about
    u_solve = klusolve(K_free.indices[:,0], K_free.indices[:,1], K_free.data, f[free, 0])
    compliance = f.T @ u.at[free].add(u_solve)
    return compliance.sum()

(compliance, dc) = value_and_grad(obj_compliance, argnums=0)(x, nelx, nely, free, penal, Emax, Emin, f, H, Hs, ijK)
```

Thanks!
-
The issue is that "extracting values from sparse arrays" is fundamentally a set-join operation, and XLA has no primitives for set arithmetic. As a result, `BCOO` has to construct the operation using available primitives. For lack of a better option, it essentially does `jnp.any(query[:, None] == indices[None, :], -1)`, which has very poor memory scaling as the size of the query and indices increases.

I don't really have any better suggestions (if I did, I would have put them in the `BCOO` implementation!), but you may be able to do better for your specific application by writing the low-level index manipulations directly. For example, it looks like you're only concerned with extracting values along the diagonal, which you may be able to exploit to do so more efficiently.

For what it's worth, this is one of the (many) reasons why JAX sparse is still under …

-

Hi, Adam. I have the same problem as you. Do you have an elegant solution? @adambomandel
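To make the memory-scaling point in the answer above concrete, here is a toy 1-D version of the broadcast membership test it describes. This is a simplification: the real `BCOO` code works on multi-dimensional indices and may differ in detail, and whether XLA actually materializes the full mask depends on fusion. The problem size `nelx = nely = 100` is a hypothetical value; `ndof` and the number of stored entries follow from the question's code (an 8x8 `KE` per element gives 64 entries per element).

```python
import jax.numpy as jnp

def member_mask_broadcast(query, indices):
    """For each entry of `query`, report whether it appears in `indices`.

    query[:, None] == indices[None, :] builds a (len(query), len(indices))
    boolean array before the reduction, so peak memory grows like
    len(query) * len(indices).
    """
    return jnp.any(query[:, None] == indices[None, :], axis=-1)

def broadcast_mask_bytes(n_query, n_indices):
    # one byte per boolean in the full comparison matrix
    return n_query * n_indices

# Toy example
mask = member_mask_broadcast(jnp.array([1, 5, 7]), jnp.array([7, 1, 2, 2]))
# mask -> [True, False, True]

# Hypothetical problem size, using the formulas from the question's code
nelx, nely = 100, 100
ndof = 2 * (nelx + 1) * (nely + 1)  # upper bound on the number of queried rows
nse = 64 * nelx * nely              # stored entries in ijK (8x8 KE per element)
print(broadcast_mask_bytes(ndof, nse) / 1e9, "GB if fully materialized")
```

At that size the comparison matrix alone would be about 13 GB, which is why a query against every stored index is a poor fit for large stiffness matrices.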
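One way to follow the suggestion of writing the index manipulations directly is a minimal sketch, assuming that `ijK` and `free` are concrete NumPy arrays that do not change between iterations (true in the posted code, since neither depends on `x`) and that `free` contains no repeated dofs. The helper names `precompute_free_extraction` and `extract_free` are hypothetical. The idea is to work out once on the host which stored entries land in the free block and how their indices are renumbered; inside the differentiated function, extraction is then a plain gather on `sK`, with memory proportional to the number of stored entries.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import sparse as jsparse

def precompute_free_extraction(ijK, free, ndof):
    """Host-side NumPy preprocessing, run once outside the traced function.

    Assumes ijK has shape (nse, 2) and `free` holds unique dof numbers.
    Returns the positions of stored entries that fall in K[free][:, free]
    and their (row, col) indices renumbered within the free block.
    """
    ijK = np.asarray(ijK)
    free = np.asarray(free)
    # new_pos[d] = position of dof d inside `free`, or -1 if d is fixed
    new_pos = np.full(ndof, -1, dtype=np.int64)
    new_pos[free] = np.arange(free.size)
    rows = new_pos[ijK[:, 0]]
    cols = new_pos[ijK[:, 1]]
    keep = np.nonzero((rows >= 0) & (cols >= 0))[0]
    free_indices = np.stack([rows[keep], cols[keep]], axis=1).astype(np.int32)
    return keep, free_indices

def extract_free(sK, keep, free_indices, nfree):
    """Traced part: a gather on the values only, differentiable w.r.t. sK.

    Duplicate (row, col) pairs from ijK are kept, not summed;
    BCOO.todense() sums them.
    """
    return jsparse.BCOO((sK[keep], free_indices), shape=(nfree, nfree))

# Toy check against dense fancy indexing (hypothetical small data)
ndof = 4
ijK = np.array([[0, 0], [0, 1], [1, 1], [2, 3], [3, 3], [1, 3], [3, 1], [1, 1]])
sK = jnp.arange(1.0, 9.0)  # 8 stored values, with a duplicate at (1, 1)
free = np.array([1, 3])

keep, free_indices = precompute_free_extraction(ijK, free, ndof)
K_free = extract_free(sK, keep, free_indices, free.size)

K_dense = np.zeros((ndof, ndof))
np.add.at(K_dense, (ijK[:, 0], ijK[:, 1]), np.asarray(sK))
print(np.allclose(K_free.todense(), K_dense[np.ix_(free, free)]))

# Gradients flow through the gather; entries outside the free block get 0
g = jax.grad(lambda s: extract_free(s, keep, free_indices, free.size).todense().sum())(sK)
```

In `obj_compliance`, this would mean computing `keep, free_indices` once before calling `value_and_grad` and replacing the two `K` lines with `K_free = extract_free(sK, keep, free_indices, free.size)`, or passing `sK[keep]`, `free_indices[:, 0]` and `free_indices[:, 1]` straight to `klusolve`. Whether `klusolve` accepts duplicate `(row, col)` pairs is an assumption worth checking, since this sketch does not merge them.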