Creating Jax Primitive that uses mhlo.BatchNormTraining op #13228
-
I'm trying to create a JAX Primitive using XLA's mhlo::BatchNormTraining op. The code is pretty simple:

```python
from jax import lax, dtypes
from jax.interpreters import ad, batching, mlir
from jax._src.lax import lax as lx
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir import ir
from functools import partial
import numpy as np


def batchnorm_weak_types_rule(batch_mean, batch_var, operand, scale, offset, *, epsilon, feature_index):
    return (operand.weak_type, batch_mean.weak_type, batch_var.weak_type)


def batchnorm_shape_rule(batch_mean, batch_var, operand, scale, offset, *, epsilon, feature_index):
    return (operand.shape, batch_mean.shape, batch_var.shape)


def batchnorm_named_shape_rule(batch_mean, batch_var, operand, scale, offset, *, epsilon, feature_index):
    return (operand.named_shape, batch_mean.named_shape, batch_var.named_shape)


def batchnorm_dtype_rule(batch_mean, batch_var, operand, scale, offset, *, epsilon, feature_index):
    if not dtypes.issubdtype(operand.dtype, np.number):
        raise TypeError("XLA BatchNorm does not accept dtype {}. Accepted dtypes are subtypes "
                        "of number.".format(np.dtype(operand.dtype).name))
    if not dtypes.issubdtype(batch_mean.dtype, np.number):
        raise TypeError("XLA BatchNorm does not accept dtype {}. Accepted dtypes are subtypes "
                        "of number.".format(np.dtype(batch_mean.dtype).name))
    if not dtypes.issubdtype(batch_var.dtype, np.number):
        raise TypeError("XLA BatchNorm does not accept dtype {}. Accepted dtypes are subtypes "
                        "of number.".format(np.dtype(batch_var.dtype).name))
    return (dtypes.canonicalize_dtype(operand.dtype),
            dtypes.canonicalize_dtype(batch_mean.dtype),
            dtypes.canonicalize_dtype(batch_var.dtype))


batchNorm_p = lax.standard_primitive(
    batchnorm_named_shape_rule, batchnorm_dtype_rule,
    'local_batch_norm')
batchNorm_p.multiple_results = True
batchNorm_p.def_abstract_eval(
    partial(lx.standard_multi_result_abstract_eval, batchNorm_p, batchnorm_shape_rule,
            batchnorm_dtype_rule, batchnorm_weak_types_rule,
            batchnorm_named_shape_rule))

# not yet checked, probably wrong
batching.primitive_batchers[batchNorm_p] = lx._reduce_batch_rule

# not yet checked, probably wrong
batchNormGrad_p = lax.standard_primitive(
    # _conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule,
    # lx.standard_named_shape_rule, _conv_general_dilated_dtype_rule,
    lx._reduce_op_shape_rule, lx._reduce_number_dtype_rule,
    'local_batch_norm_grad')


def batch_norm_train(batch_mean, batch_var, operand, scale, offset, *, epsilon, feature_index):
    return batchNorm_p.bind(batch_mean, batch_var, operand, scale, offset,
                            epsilon=epsilon, feature_index=feature_index)


def batch_norm_grad_rule(grad_operand, grad_scale, grad_offset, operand, scale, mean, variance, grad_output, epsilon, feature_index):
    return batchNormGrad_p.bind(grad_operand, grad_scale, grad_offset, operand, scale, mean,
                                variance, grad_output, epsilon, feature_index)


ad.deflinear(batchNorm_p, batch_norm_grad_rule)


def lowerBatchNorm(ctx, batch_mean, batch_var, operand, scale, offset, *, epsilon, feature_index):
    aval_out = ctx.avals_in[2]
    aval_out_mean = ctx.avals_in[3]
    aval_out_var = ctx.avals_in[4]
    return mhlo.BatchNormTrainingOp(
        mlir.aval_to_ir_type(aval_out), mlir.aval_to_ir_type(aval_out_mean),
        mlir.aval_to_ir_type(aval_out_var), operand, scale, offset,
        ir.FloatAttr.get(ir.F32Type.get(), epsilon),
        mlir.i64_attr(feature_index)).results


def lowerBatchNormGrad(ctx, grad_operand, grad_scale, grad_offset, operand, scale, mean, variance, grad_output, *, epsilon, feature_index):
    return mhlo.BatchNormGradOp(grad_operand, grad_scale, grad_offset, operand, scale,
                                mean, variance, grad_output, epsilon, feature_index).results


mlir.register_lowering(batchNorm_p, lowerBatchNorm)
mlir.register_lowering(batchNormGrad_p, lowerBatchNormGrad)
```

The forward step seems to work and returns the expected results.
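For reference, a minimal sketch of how the forward step can be exercised, assuming the definitions above behave as posted; the shapes, the NHWC layout, and `feature_index=3` are made-up for illustration and are not part of the original post:

```python
import jax.numpy as jnp

x = jnp.ones((8, 16, 16, 4), dtype=jnp.float32)   # hypothetical NHWC operand, 4 channels
scale = jnp.ones((4,), dtype=jnp.float32)          # per-channel scale
offset = jnp.zeros((4,), dtype=jnp.float32)        # per-channel offset
mean = jnp.zeros((4,), dtype=jnp.float32)          # placeholder input: only shape/dtype matter
var = jnp.ones((4,), dtype=jnp.float32)            # placeholder input: only shape/dtype matter

# The primitive has multiple_results=True, so three arrays come back.
out, batch_mean, batch_var = batch_norm_train(
    mean, var, x, scale, offset, epsilon=1e-5, feature_index=3)
```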
I'm currently struggling with the api.grad() function; I get an error that looks like this (the shape is simplified from the original message):
I've tried to find references for grad applied to primitives with multiple results, but couldn't find many. The examples I did find call api.jvp() and api.vjp() directly, and I'm not sure that's the right way to go. Any help would be appreciated :)
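For concreteness, the failing reverse-mode call is presumably of this general form (a guessed repro, reusing the made-up arrays from the sketch above):

```python
import jax

def loss(operand):
    # Take only the normalized output and reduce it to a scalar for grad.
    out, _, _ = batch_norm_train(mean, var, operand, scale, offset,
                                 epsilon=1e-5, feature_index=3)
    return out.sum()

jax.grad(loss)(x)  # the kind of call that raises the error described above
```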
-
You shouldn't test `grad` directly on this primitive (`grad` only works on functions with scalar outputs). You should be testing `jax.jvp` and `jax.vjp`.

```python
def batch_norm_grad_rule(grad_operand, grad_scale, grad_offset, operand, scale, mean, variance, grad_output, epsilon, feature_index):
    return batchNormGrad_p.bind(grad_operand, grad_scale, grad_offset, operand, scale, mean, variance, grad_output, epsilon, feature_index)
ad.deflinear(batchNorm_p, batch_norm_grad_rule)
```

This seems incorrect to me. `batch_norm` isn't a linear function (`deflinear` only works correctly for linear functions). If the primitive isn't linear, you need to specify a JVP rule (via `ad.primitive_jvps`) and a transpose rule (via …). However, since you don't have a …
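One possible way to wire this up, sketched under the assumption that the primitive's operands stay `(batch_mean, batch_var, operand, scale, offset)` as in the original post: register a rule in `ad.primitive_jvps` that computes the primal outputs with the primitive and derives the tangents from a plain-`jnp` reference implementation. `_batch_norm_reference` and `_batch_norm_jvp` are hypothetical helper names, and this is only one of several ways to define the rule:

```python
import jax
import jax.numpy as jnp
from functools import partial
from jax import lax
from jax.interpreters import ad

def _batch_norm_reference(operand, scale, offset, *, epsilon, feature_index):
    # Plain-jnp batch norm (training mode), used only to obtain tangents via jax.jvp.
    axes = tuple(i for i in range(operand.ndim) if i != feature_index)
    mean = jnp.mean(operand, axis=axes)
    var = jnp.var(operand, axis=axes)
    shape = [1] * operand.ndim
    shape[feature_index] = operand.shape[feature_index]
    norm = (operand - mean.reshape(shape)) * lax.rsqrt(var.reshape(shape) + epsilon)
    return norm * scale.reshape(shape) + offset.reshape(shape), mean, var

def _batch_norm_jvp(primals, tangents, *, epsilon, feature_index):
    batch_mean, batch_var, operand, scale, offset = primals
    # Tangents w.r.t. the mean/var placeholder inputs are ignored here.
    _, _, operand_dot, scale_dot, offset_dot = tangents
    operand_dot, scale_dot, offset_dot = map(
        ad.instantiate_zeros, (operand_dot, scale_dot, offset_dot))
    # Primal outputs still come from the mhlo-backed primitive.
    primals_out = batch_norm_train(batch_mean, batch_var, operand, scale, offset,
                                   epsilon=epsilon, feature_index=feature_index)
    # Tangent outputs come from differentiating the reference implementation.
    ref = partial(_batch_norm_reference, epsilon=epsilon, feature_index=feature_index)
    _, tangents_out = jax.jvp(ref, (operand, scale, offset),
                              (operand_dot, scale_dot, offset_dot))
    return primals_out, tangents_out

# Overrides the JVP half of the earlier ad.deflinear registration.
ad.primitive_jvps[batchNorm_p] = _batch_norm_jvp
```

Because the tangents here are built entirely from ordinary `jnp` operations, `jax.vjp` (and `jax.grad` of a scalar function of the outputs) can then be obtained through linearization without a transpose rule for `batchNorm_p` itself; a transpose rule registered in `ad.primitive_transposes` would be needed only if the JVP rule bound another non-transposable primitive such as `batchNormGrad_p`.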
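With a JVP rule like the sketch above in place, testing with `jax.jvp` and `jax.vjp` rather than `jax.grad` might look roughly like this (again reusing the made-up arrays from the earlier sketch):

```python
import jax
import jax.numpy as jnp

# Differentiate w.r.t. the operand only, through the normalized output.
f = lambda operand: batch_norm_train(mean, var, operand, scale, offset,
                                     epsilon=1e-5, feature_index=3)[0]

# Forward mode: push a tangent through the output.
out, out_dot = jax.jvp(f, (x,), (jnp.ones_like(x),))

# Reverse mode: pull a cotangent back to the operand.
out, f_vjp = jax.vjp(f, x)
(x_bar,) = f_vjp(jnp.ones_like(out))
```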