Question regarding the cookbook #200
Hi there, thanks for the great repo!
I was working through the Neural Tangents Cookbook and am a bit confused by the loss_fn (reproduced below):
```python
import jax.numpy as jnp

def loss_fn(predict_fn, ys, t, xs=None):
  # Posterior mean and covariance of the NTK-GP predictions at time t.
  mean, cov = predict_fn(t=t, get='ntk', x_test=xs, compute_cov=True)
  # Flatten outputs: (time, points, outputs) -> (time, points * outputs).
  mean = jnp.reshape(mean, mean.shape[:1] + (-1,))
  # Per-point predictive variances from the diagonal of the covariance.
  var = jnp.diagonal(cov, axis1=1, axis2=2)
  ys = jnp.reshape(ys, (1, -1))
  mean_predictions = 0.5 * jnp.mean(ys ** 2 - 2 * mean * ys + var + mean ** 2,
                                    axis=1)
  return mean_predictions
```
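For concreteness, the expression inside the `jnp.mean` rearranges to `(ys - mean) ** 2 + var`, i.e. squared error plus predictive variance. A small NumPy sketch checking that identity (the shapes and names here are illustrative, not taken from the cookbook):

```python
import numpy as np

# Hypothetical shapes: 5 time steps, 15 flattened outputs.
rng = np.random.default_rng(0)
mean = rng.normal(size=(5, 15))            # posterior mean per time step
var = rng.uniform(0.1, 1.0, size=(5, 15))  # diagonal of posterior covariance
ys = rng.normal(size=(1, 15))              # flattened labels, broadcast over time

# The expression as written in loss_fn ...
loss_as_written = 0.5 * np.mean(ys ** 2 - 2 * mean * ys + var + mean ** 2, axis=1)
# ... equals mean squared error plus mean predictive variance.
loss_rearranged = 0.5 * np.mean((ys - mean) ** 2 + var, axis=1)

assert np.allclose(loss_as_written, loss_rearranged)
```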
It looks like this function is later used to compute the training and test losses for plotting. What confuses me is that the calculation of `mean_predictions` (for each test point) contains `var`, making it effectively the sum of the squared error between prediction and label and the predictive variance. It makes sense to include the variance as part of the performance (or loss), but why this specific form (e.g., why
Thanks again!