Lowering of normalization call when using flax #19921
Unanswered
pratnali-aws asked this question in Q&A
Replies: 0 comments
Hello All,
This is a newbie question:
I have the following simple model that uses BatchNormalization:
and it is lowered to
I see that nn.Conv gets lowered to
convolution.23 = f32[1,64,64,32]{3,2,1,0} convolution(Arg_6.20, Arg_5.19), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f
but nn.BatchNorm is lowered into a sequence of ops. I get a sense that it is a translation of the implementation here.
Is there some way to avoid this, especially since HLO has a native batch_norm_training op?
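For reference, here is a minimal sketch of how this kind of lowering can be inspected. It is not the model from this question: manual_batch_norm is a hypothetical hand-written normalization in plain JAX (it is not Flax's nn.BatchNorm), and the input shape is an assumption chosen to match the convolution output above. It only illustrates why a normalization written in terms of mean, variance and rsqrt would show up as several primitive ops in the lowered text.

```python
import jax
import jax.numpy as jnp


def manual_batch_norm(x, scale, bias, eps=1e-5):
    # Hypothetical helper (assumption, not Flax's code):
    # normalize over every axis except the last (feature) axis.
    axes = tuple(range(x.ndim - 1))
    mean = jnp.mean(x, axis=axes, keepdims=True)
    var = jnp.var(x, axis=axes, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(var + eps) * scale + bias


# Assumed shapes: NHWC activations with 32 features.
x = jnp.ones((1, 64, 64, 32), jnp.float32)
scale = jnp.ones((32,), jnp.float32)
bias = jnp.zeros((32,), jnp.float32)

# Lower without executing, then dump the compiler IR as text.
lowered = jax.jit(manual_batch_norm).lower(x, scale, bias)
hlo_text = lowered.as_text()
print(hlo_text)
```

In the printed text I would expect reductions and an rsqrt in place of one fused batch-norm op, which looks like the same pattern I am seeing with nn.BatchNorm.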