Hi, I'm very new to JAX, but so far I've gotten a ton of mileage out of just using it as a supercharged NumPy. I'm trying to run logistic regressions on a series of traits for a group of subjects. Simple example:
Say
I would love to
Is there a better way to handle this situation?
Hi, thanks for the question! In this case, the best way to proceed is probably to keep operating on fixed-size vectors and account for the missing entries with a mask, rather than constructing intermediate values whose shapes depend on the data. For example, you could write a log-likelihood that skips the missing values (coded as `-9`), something like this:

```python
import jax.numpy as jnp

def loglikelihood(b, X, y):
    # True where y is observed, False where it holds the missing-value code -9
    valid_mask = (y != -9)
    # `predict` is the model's prediction function from the question (not shown here)
    p = predict(X, b)
    # Negative log-likelihood; masked-out entries contribute 0 to the sum
    return -1 * jnp.sum(valid_mask * (y * jnp.log(p) + (1 - y) * jnp.log(1 - p)))
```

Then you can `vmap` your minimization over traits without having to construct variable-shaped arrays.
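A minimal end-to-end sketch of this approach follows. It rests on a few assumptions: missing trait values are coded as `-9` (as in the masked log-likelihood above), all traits share one design matrix `X`, and a fixed number of plain gradient-descent steps is an acceptable stand-in for whatever minimizer you actually use. The names `masked_nll`, `fit`, `fit_all`, `steps` and `lr` are hypothetical. It also replaces `jnp.log(predict(...))` with `jax.nn.log_sigmoid` on the logits. That avoids `0 * log(0) = nan` when a prediction saturates, which would otherwise leak through the mask.

```python
import jax
import jax.numpy as jnp

MISSING = -9  # assumed missing-value code, as in the masked log-likelihood

def masked_nll(b, X, y):
    # Negative log-likelihood of a logistic model, ignoring missing entries
    valid = (y != MISSING)
    # Swap the -9 code for 0 so it never enters the arithmetic
    y_safe = jnp.where(valid, y, 0.0)
    logits = X @ b
    # log p and log(1 - p) computed stably from the logits
    ll = y_safe * jax.nn.log_sigmoid(logits) + (1 - y_safe) * jax.nn.log_sigmoid(-logits)
    # Masked entries contribute exactly 0 (and 0 gradient)
    return -jnp.sum(jnp.where(valid, ll, 0.0))

def fit(X, y, steps=500, lr=0.5):
    # Hypothetical minimizer: fixed-length gradient descent via lax.scan,
    # so every trait runs the same number of steps and shapes stay static
    grad_fn = jax.grad(masked_nll)
    n_valid = jnp.maximum(jnp.sum(y != MISSING), 1)

    def step(b, _):
        return b - lr * grad_fn(b, X, y) / n_valid, None

    b0 = jnp.zeros(X.shape[1])
    b, _ = jax.lax.scan(step, b0, None, length=steps)
    return b

# Share X across traits (in_axes=None) and map over the columns of Y (in_axes=1)
fit_all = jax.jit(jax.vmap(fit, in_axes=(None, 1)))
```

Calling `fit_all(X, Y)` with `X` of shape `(n_subjects, n_features)` and `Y` of shape `(n_subjects, n_traits)` returns one coefficient vector per trait, stacked into shape `(n_traits, n_features)`. Because the mask only zeroes terms of the sum, fitting a trait this way gives the same result, up to floating-point summation order, as dropping its missing rows and fitting on the remaining subjects.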