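Haiku is similar: define the model once, then build one transformed function for training and another for evaluation.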

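import haiku as hk
import jax

class MyModel(hk.Module):
    def __init__(self, name=None):
        super().__init__(name=name)  # every hk.Module subclass must call this
        self.bn = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.99)
        self.linear = hk.Linear(16)

    def __call__(self, x, is_training):
        return self.bn(self.linear(x), is_training)

# BatchNorm keeps its moving averages in Haiku state, so use transform_with_state
# (plain hk.transform rejects functions that use state).
train_forward = hk.transform_with_state(lambda x: MyModel()(x, True))
eval_forward = hk.transform_with_state(lambda x: MyModel()(x, False))

rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, (8, 4))  # example batch
# Both transforms create identically named modules, so they share params and state.
params, state = train_forward.init(rng_key, x=x)
out, state = train_forward.apply(params, state, rng_key, x=x)  # updates the moving averages
eval_out, _ = eval_forward.apply(params, state, rng_key, x=x)   # reads the moving averages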

Replies: 2 comments, 4 replies

1 reply
@marcuswang6
3 replies
@YouJiacheng
@marcuswang6
@nalzok
Answer selected by marcuswang6
Category: Q&A
Labels: none yet
3 participants