You shouldn't test jax.grad directly on this primitive, since grad only works on functions with scalar outputs. Test jax.jvp and jax.vjp instead.
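For example (a minimal sketch; `batch_norm` here is a hypothetical wrapper around the primitive, with an ordinary jnp stand-in body so the snippet is self-contained):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the wrapper around the custom primitive.
def batch_norm(x, scale, offset, epsilon=1e-5):
    mean = jnp.mean(x, axis=0)
    var = jnp.var(x, axis=0)
    return (x - mean) / jnp.sqrt(var + epsilon) * scale + offset

x = jnp.arange(12, dtype=jnp.float32).reshape(4, 3)
scale, offset = jnp.ones(3), jnp.zeros(3)

# Forward mode: push tangents through; this exercises the JVP rule.
out, out_dot = jax.jvp(
    batch_norm, (x, scale, offset),
    (jnp.ones_like(x), jnp.ones_like(scale), jnp.ones_like(offset)))

# Reverse mode: pull a cotangent back; this exercises the JVP rule
# followed by transposition.
out, vjp_fn = jax.vjp(batch_norm, x, scale, offset)
x_bar, scale_bar, offset_bar = vjp_fn(jnp.ones_like(out))
```

jax.test_util.check_grads(batch_norm, (x, scale, offset), order=1) is also useful here: it checks both modes against numerical differences.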

Separately, this registration from your code:

```python
def batch_norm_grad_rule(grad_operand, grad_scale, grad_offset, operand, scale,
                         mean, variance, grad_output, epsilon, feature_index):
    return batchNormGrad_p.bind(grad_operand, grad_scale, grad_offset, operand,
                                scale, mean, variance, grad_output,
                                epsilon, feature_index)

ad.deflinear(batchNorm_p, batch_norm_grad_rule)
```

This seems incorrect to me: batch_norm isn't a linear function, and ad.deflinear only works correctly for linear primitives. If the primitive isn't linear, you need to specify a JVP rule (via ad.primitive_jvps) and a transpose rule (via ad.primitive_transposes). A sketch of that pattern follows.
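Here's a minimal sketch of that pattern, following the structure of JAX's "How primitives work" tutorial rather than your batch-norm primitive; all names (multiply_add_p, etc.) are illustrative, and exact module paths can vary across JAX versions:

```python
import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import ad

# Toy primitive: multiply_add(x, y, z) = x * y + z.
multiply_add_p = core.Primitive("multiply_add")

def multiply_add(x, y, z):
    return multiply_add_p.bind(x, y, z)

multiply_add_p.def_impl(lambda x, y, z: x * y + z)
multiply_add_p.def_abstract_eval(
    lambda x, y, z: core.ShapedArray(x.shape, x.dtype))

def multiply_add_jvp(primals, tangents):
    x, y, z = primals
    x_dot, y_dot, z_dot = tangents

    # Symbolic-zero tangents must be instantiated before they reach bind.
    def instantiate(t, like):
        return jnp.zeros_like(like) if type(t) is ad.Zero else t

    primal_out = multiply_add(x, y, z)
    # The tangent is expressed with the primitive itself and is linear in
    # the tangent inputs; that linearity is what the transpose rule relies on.
    tangent_out = multiply_add(
        instantiate(x_dot, x), y,
        multiply_add(x, instantiate(y_dot, y), instantiate(z_dot, z)))
    return primal_out, tangent_out

ad.primitive_jvps[multiply_add_p] = multiply_add_jvp

def multiply_add_transpose(ct, x, y, z):
    # Reverse mode transposes the linear tangent computation. `ct` is the
    # output cotangent; return one cotangent per input, with None for inputs
    # that were known constants. (A production rule would also handle ad.Zero
    # cotangents.)
    if ad.is_undefined_primal(x):
        # The call was multiply_add(x_dot, y, ...): linear in x and z.
        return multiply_add(ct, y, jnp.zeros_like(ct)), None, ct
    else:
        # The call was multiply_add(x, y_dot, z_dot): linear in y and z.
        return None, multiply_add(x, ct, jnp.zeros_like(ct)), ct

ad.primitive_transposes[multiply_add_p] = multiply_add_transpose

# With both rules registered, forward and reverse mode both work:
print(jax.jvp(lambda a: multiply_add(a, 2., 1.), (3.,), (1.,)))  # (7.0, 2.0)
print(jax.grad(lambda a: multiply_add(a, 2., 1.))(3.))           # 2.0
```

For your batch-norm primitive, the same shape applies: the JVP rule computes output tangents with respect to operand, scale, and offset, and the transpose rule covers only the linear tangent computation, not the nonlinear primal.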
