How to change my jax code after I add a BatchNorm Layer? #10707
Answered by YouJiacheng
marcuswang6 asked this question in Q&A
Because of the `is_training` parameter of BatchNorm's `__call__()`, I have to change my JAX code. How should I do this?
Answered by YouJiacheng · May 14, 2022
Assuming you use Flax, make `is_training` a module field and build two apply functions:

```python
import flax.linen as nn

class Model(nn.Module):
    is_training: bool

    @nn.compact
    def __call__(self, x):
        return SomeModuleDependsOnWhetherTraining()(x, self.is_training)

train_model_forward = Model(True).apply
eval_model_forward = Model(False).apply
```
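To make the pattern concrete, here is a minimal runnable sketch with an actual `nn.BatchNorm` inside the submodule. `MLPBlock`, the layer sizes, and the input shape are hypothetical illustrations, not from the discussion. In Flax, BatchNorm is switched via `use_running_average`, and its moving averages live in a separate `batch_stats` collection that must be marked mutable during training:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLPBlock(nn.Module):  # hypothetical stand-in for the training-dependent submodule
    @nn.compact
    def __call__(self, x, is_training):
        x = nn.Dense(16)(x)
        # Flax's BatchNorm: batch statistics while training,
        # stored running averages at evaluation time.
        x = nn.BatchNorm(use_running_average=not is_training)(x)
        return nn.relu(x)

class Model(nn.Module):
    is_training: bool

    @nn.compact
    def __call__(self, x):
        return MLPBlock()(x, self.is_training)

x = jnp.ones((4, 8))  # arbitrary batch
variables = Model(True).init(jax.random.PRNGKey(0), x)  # {'params', 'batch_stats'}

# Training: the moving averages change, so mark 'batch_stats' mutable.
out, updates = Model(True).apply(variables, x, mutable=['batch_stats'])
variables = {**variables, **updates}

# Evaluation: a pure function of the final variables.
out = Model(False).apply(variables, x)
```

Baking `is_training` into the module as a field keeps `__call__`'s signature fixed, so the train and eval apply functions can each be jitted without any static argument.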
Haiku is similar:

```python
import haiku as hk

class MyModel(hk.Module):
    def __init__(self):
        super().__init__()
        self.bn = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.99)
        self.linear = hk.Linear(16)

    def __call__(self, x, is_training):
        return self.bn(self.linear(x), is_training)

# hk.BatchNorm keeps its moving averages in Haiku state, so the
# function must be transformed with hk.transform_with_state.
train_forward = hk.transform_with_state(lambda x: MyModel()(x, True))
eval_forward = hk.transform_with_state(lambda x: MyModel()(x, False))

params, state = train_forward.init(rng_key, x=x)
out, state = train_forward.apply(params, state, rng_key, x=x)
out, _ = eval_forward.apply(params, state, rng_key, x=x)
```
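Since `hk.transform_with_state` forwards extra arguments to the wrapped function, a variant (a sketch of an alternative, not part of the original answer) is to pass `is_training` at call time instead of building two separate transforms; the seed and input shape below are arbitrary assumptions:

```python
import jax
import jax.numpy as jnp
import haiku as hk

# One transform; the mode is supplied per call.
forward = hk.transform_with_state(lambda x, is_training: MyModel()(x, is_training))

rng_key = jax.random.PRNGKey(42)  # arbitrary seed
x = jnp.ones((4, 8))              # arbitrary input batch

params, state = forward.init(rng_key, x=x, is_training=True)
out, state = forward.apply(params, state, rng_key, x=x, is_training=True)   # train step
out, _ = forward.apply(params, state, rng_key, x=x, is_training=False)      # evaluation
```

If the apply function is wrapped in `jax.jit`, `is_training` must then be marked static (e.g. via `static_argnums`) so that each mode is traced separately.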
Answer selected by marcuswang6