Remat inside remat #20743
Unanswered · LeoXinhaoLee asked this question in General
Replies: 0 comments
Hi, I'm trying to scan through a sequence by groups of chunks, and use `remat` to avoid saving the intermediate variables produced while processing each group of chunks.

Specifically, in the code below, `compute_chunk` processes a chunk of tokens, and `compute_group` processes a group of chunks. I'm wondering:

1. To avoid saving all `Y`, do we need the `remat` decorator on `compute_chunk`? Or does the `remat` on `compute_group` already apply to `compute_chunk`? If it is needed, how should we set `prevent_cse` for `compute_chunk`?
2. The `remat` on `compute_chunk` will save the `W` of each chunk, since it is the output of `compute_chunk`. However, the `remat` on `compute_group` only saves the final `W` of a group of chunks. How does JAX handle this conflict?

Thank you very much for your time and help! @jakevdp