I would probably do this by changing your calculation so that it does not require constructing an explicit array of nonzero values, e.g.:

```python
import jax.numpy as jnp
from jax import jit

def mse_observed_loss_2(A, params):
    U, V = params['users'], params['items']
    estimator = -(U @ V.T)
    square_err_mat = jnp.multiply(A + estimator, A + estimator)
    # Mask out unobserved (zero) entries instead of extracting them,
    # so every array keeps a static shape under jit.
    nonzero = (A != 0)
    return jnp.where(nonzero, square_err_mat, 0).sum() / nonzero.sum()

# a, u, v: the ratings matrix and factor matrices from your notebook
jit(mse_observed_loss_2)(a, {'users': u, 'items': v})
# DeviceArray(80.666664, dtype=float32)
```

If you want a more compact version of this same logic, you can use the `where` argument of `mean`:

```python
def mse_observed_loss_3(A, params):
    U, V = params['users'], params['items']
    estimator = -(U @ V.T)
    # mean(where=...) averages only over the observed (nonzero) entries.
    return jnp.multiply(A + estimator, A + estimator).mean(where=(A != 0))
```
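Since the masked loss exists to train `U` and `V`, here is a minimal sketch (not from the thread) of taking jitted gradient steps with it. It also checks the jitted loss against an eager NumPy version that uses boolean indexing, which works fine outside `jit`. The names `masked_mse` and `sgd_step`, the toy data, the rank, and the learning rate are all illustrative assumptions.

```python
import numpy as np
import jax
import jax.numpy as jnp


def masked_mse(A, params):
    # Same idea as mse_observed_loss_3: mean squared error over A != 0 only.
    U, V = params['users'], params['items']
    err = A - U @ V.T
    return (err * err).mean(where=(A != 0))


@jax.jit
def sgd_step(A, params, lr):
    # Unobserved entries are excluded from the mean, so they contribute no gradient.
    loss, grads = jax.value_and_grad(masked_mse, argnums=1)(A, params)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


# Hypothetical toy data: 4 users x 3 items, 0 marks an unobserved rating.
A = jnp.array([[5., 0., 1.],
               [0., 3., 0.],
               [4., 0., 2.],
               [0., 1., 5.]])
key_u, key_v = jax.random.split(jax.random.PRNGKey(0))
params = {'users': jax.random.normal(key_u, (4, 2)),
          'items': jax.random.normal(key_v, (3, 2))}

# Eager NumPy reference: boolean indexing is fine outside jit.
A_np = np.asarray(A)
err_np = A_np - np.asarray(params['users']) @ np.asarray(params['items']).T
reference = float((err_np[A_np != 0] ** 2).mean())
initial = float(jax.jit(masked_mse)(A, params))

# A few hundred small gradient-descent steps (hypothetical learning rate).
for _ in range(200):
    params, _ = sgd_step(A, params, 0.01)
final = float(masked_mse(A, params))
print(initial, reference, final)
```

Because the mask is just a boolean array with the same shape as `A`, the traced computation has fixed shapes, which is what makes it compatible with `jit` and `grad`.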
I wanted to report back because I went ahead and did some performance testing (on CPU). Here's the leaderboard:
Edit:
Quick comments:
Everything is still in the same notebook here.
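For anyone reproducing CPU timings like these, here is a minimal sketch of a JAX micro-benchmark. It is not the harness used in the notebook; the two candidate losses, the problem size, and the roughly 10% observation density are hypothetical. Two details matter in JAX: the first call includes tracing and compilation, so it is run once as a warm-up, and dispatch is asynchronous, so each call is followed by `block_until_ready()` before the clock is read.

```python
import time
import numpy as np
import jax
import jax.numpy as jnp


def masked_mse_where(A, U, V):
    # Candidate 1: zero out unobserved entries with jnp.where.
    err = A - U @ V.T
    mask = A != 0
    return jnp.where(mask, err * err, 0).sum() / mask.sum()


def masked_mse_mean(A, U, V):
    # Candidate 2: restrict the mean with its `where` argument.
    err = A - U @ V.T
    return (err * err).mean(where=(A != 0))


def bench(fn, *args, repeats=20):
    """Median wall-clock seconds per call, excluding compilation."""
    jitted = jax.jit(fn)
    jitted(*args).block_until_ready()  # warm-up: triggers tracing + compilation
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        jitted(*args).block_until_ready()  # wait for async dispatch to finish
        times.append(time.perf_counter() - start)
    return float(np.median(times))


# Hypothetical problem size; roughly 10% of entries observed.
rng = np.random.default_rng(0)
n_users, n_items, rank = 1000, 500, 16
ratings = rng.integers(1, 6, (n_users, n_items))
observed = rng.random((n_users, n_items)) < 0.1
A = jnp.asarray(np.where(observed, ratings, 0).astype(np.float32))
U = jnp.asarray(rng.normal(size=(n_users, rank)).astype(np.float32))
V = jnp.asarray(rng.normal(size=(n_items, rank)).astype(np.float32))

results = {name: bench(fn, A, U, V)
           for name, fn in [('where', masked_mse_where),
                            ('mean(where=)', masked_mse_mean)]}
for name, secs in sorted(results.items(), key=lambda kv: kv[1]):
    print(f"{name:>14}: {secs * 1e3:.3f} ms")
```

Taking the median over several runs makes the ranking less sensitive to one-off scheduling noise than a single timing would.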
I was working on some matrix factorization demos in JAX, and I realized that if I want to use only the observed entries to compute the loss, I need to call `nonzero` inside my jit-ed loss function. That requires computing the number of nonzeros, which is a traced value under `jit` rather than a concrete one, and it yields the error:

This led to me playing with several different approaches to dealing with this problem, and in the end I have about 4.5ish ways to solve it.
I discuss and demonstrate in the Colab here.
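To make the failure concrete, here is a minimal sketch (not necessarily one of the approaches in the Colab) showing that `jnp.nonzero` without a static `size` cannot be traced by `jit`, plus two static-shape workarounds: computing the observed indices eagerly with NumPy and passing them in, or giving `jnp.nonzero` a static `size`. The function names, the toy data, and the assumption that the number of observed entries is known ahead of time are illustrative.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp


def count_observed(A):
    # Without `size`, jnp.nonzero's output length depends on the data,
    # so tracing it under jit raises ConcretizationTypeError.
    rows, _ = jnp.nonzero(A)
    return rows.shape[0]


@jax.jit
def loss_from_indices(A, U, V, rows, cols):
    # Workaround 1: indices come from outside jit; only observed entries are predicted.
    pred = jnp.sum(U[rows] * V[cols], axis=1)
    err = A[rows, cols] - pred
    return jnp.mean(err * err)


@partial(jax.jit, static_argnames='n_obs')
def loss_with_static_size(A, U, V, n_obs):
    # Workaround 2: a static `size` fixes jnp.nonzero's output shape.
    # Assumes n_obs equals the true count; otherwise padding/truncation changes the loss.
    rows, cols = jnp.nonzero(A, size=n_obs)
    pred = jnp.sum(U[rows] * V[cols], axis=1)
    err = A[rows, cols] - pred
    return jnp.mean(err * err)


# Hypothetical toy data: 0 marks an unobserved rating.
A = jnp.array([[5., 0., 1.],
               [0., 3., 0.],
               [4., 0., 2.]])
rng = np.random.default_rng(0)
U = jnp.asarray(rng.normal(size=(3, 2)).astype(np.float32))
V = jnp.asarray(rng.normal(size=(3, 2)).astype(np.float32))

try:
    jax.jit(count_observed)(A)
    nonzero_jit_failed = False
except jax.errors.ConcretizationTypeError:
    nonzero_jit_failed = True

rows, cols = np.nonzero(np.asarray(A))  # eager NumPy, outside jit
loss1 = float(loss_from_indices(A, U, V, rows, cols))
loss2 = float(loss_with_static_size(A, U, V, n_obs=len(rows)))
print(nonzero_jit_failed, loss1, loss2)
```

Both workarounds only predict the observed entries instead of forming the full `U @ V.T`, which may matter when the ratings matrix is large and sparse; the trade-off is that the number of observed entries has to be fixed before tracing.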
Question:
Additional questions: