Accumulation of Monte-Carlo gradients within a flax module to avoid OOM error #11528
Unanswered
etienne-thuillier
asked this question in Q&A
Replies: 1 comment 1 reply
-
Check this for how to convert a Flax submodule into a pure function:

```python
class Model(nn.Module):
    @nn.compact
    def __call__(self, z):
        ...
        # Build the decoder detached from the module tree (parent=None),
        # so decoder.apply behaves as an ordinary pure function.
        decoder = Decoder(sigma_floor=1.0e-3, parent=None)
        apply_fn = decoder.apply
        # Register the decoder's variables on this module, using decoder.init
        # as the initializer.
        decoder_params = self.param('mlp', decoder.init, z)
        if self.is_mutable_collection('params'):  # Optional: avoid a useless apply_fn call during initialization
            apply_fn = lambda _, x: dummy_eval_of_decoder(x)
            # The dummy eval only needs to match the real output's shape and dtype:
            #   dummy_eval_of_decoder(z).shape == apply_fn(decoder_params, z).shape
            #   dummy_eval_of_decoder(z).dtype == apply_fn(decoder_params, z).dtype
        # Now apply_fn is a pure function, fully compatible with JAX.
```

You can also use the lifted transformations provided by Flax.
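For the lifted-transformation route, here is a rough sketch of what that can look like, assuming a `Decoder` module like the one above and the Monte-Carlo samples stacked along the leading axis; `ScannedDecoder` and `step` are hypothetical names, not part of the original answer:

```python
import flax.linen as nn

class ScannedDecoder(nn.Module):
    """Sketch: evaluate the decoder one Monte-Carlo sample at a time with
    Flax's lifted scan, so only one sample's activations are live at once."""
    sigma_floor: float = 1.0e-3

    @nn.compact
    def __call__(self, samples):  # samples: [num_samples, ...]
        decoder = Decoder(sigma_floor=self.sigma_floor)

        def step(mdl, carry, z):
            # nn.scan expects a (carry, x) -> (carry, y) signature.
            return carry, mdl(z)

        scan = nn.scan(
            step,
            variable_broadcast='params',   # share one set of decoder params
            split_rngs={'params': False},
        )
        _, outputs = scan(decoder, None, samples)
        return outputs
```

Wrapping the decoder in `nn.remat` as well trades extra compute for lower peak memory during backpropagation.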
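And for the question in the title itself: once `apply_fn` is pure, per-sample gradients can be accumulated with `jax.lax.scan`, so peak memory no longer scales with the number of Monte-Carlo samples. This is only a sketch; it assumes `apply_fn` and `decoder_params` are in scope as in the snippet above, and `per_sample_loss` is a hypothetical placeholder for the model's actual objective:

```python
import jax
import jax.numpy as jnp

def per_sample_loss(decoder_params, z):
    # Hypothetical placeholder loss; substitute the model's real objective.
    out = apply_fn(decoder_params, z)
    return jnp.mean(out ** 2)

def mc_grads(decoder_params, samples):
    """Average gradients over Monte-Carlo samples, one sample at a time."""
    zeros = jax.tree_util.tree_map(jnp.zeros_like, decoder_params)

    def body(acc, z):
        g = jax.grad(per_sample_loss)(decoder_params, z)
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    total, _ = jax.lax.scan(body, zeros, samples)
    return jax.tree_util.tree_map(lambda g: g / samples.shape[0], total)
```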
1 reply
-
I noticed that I posted this in the wrong place, so I copied the question to Flax's Q&A:
google/flax#2301
I did not find a way to delete this thread...