I found this awesome project recently and I'm trying to use the fla layers in a non-LLM task where we have a very long sequence and only the hidden state of the last "token" is useful. The current recurrent kernels, for example gated_deltanet, always return the hidden state of every token, which allocates a huge amount of memory. Is there any way to avoid this allocation other than calling the kernel token by token in a for loop?
@Fadelis98 Hey
This is not true: we only materialize the last hidden state, not one per token.
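For context, here is a minimal pure-PyTorch sketch of a simplified (ungated) delta-rule recurrence. It carries a single running state of shape `[B, H, K, V]` across time steps and never stacks per-token states; what it returns per token are the outputs of shape `[B, T, H, V]`. The function name, tensor layout, and the `initial_state` / `output_final_state` arguments are illustrative assumptions, not fla's exact kernel interface, and the gating of gated_deltanet is omitted.

```python
import torch

def delta_rule_reference(q, k, v, beta, initial_state=None, output_final_state=True):
    """q, k: [B, T, H, K]; v: [B, T, H, V]; beta: [B, T, H]. Illustrative only."""
    B, T, H, K = q.shape
    V = v.shape[-1]
    # a single running state is carried across time steps; no per-token states are kept
    S = initial_state if initial_state is not None else q.new_zeros(B, H, K, V)
    o = q.new_empty(B, T, H, V)
    for t in range(T):
        k_t, v_t = k[:, t], v[:, t]              # [B, H, K], [B, H, V]
        b_t = beta[:, t].unsqueeze(-1)           # [B, H, 1]
        # delta-rule update: S <- S + beta_t * k_t (v_t - S^T k_t)^T
        v_err = v_t - torch.einsum('bhk,bhkv->bhv', k_t, S)
        S = S + torch.einsum('bhk,bhv->bhkv', b_t * k_t, v_err)
        # per-token output o_t = S^T q_t
        o[:, t] = torch.einsum('bhk,bhkv->bhv', q[:, t], S)
    return (o, S) if output_final_state else (o, None)
```

The chunked Triton kernels compute the same recurrence block-wise, but the point is the same: only one state per (batch, head) is kept, so state memory does not grow with the sequence length.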
For multi-layer models, the next layer still needs all the per-token hidden states as input, so it may not be worth changing every kernel just for this feature, which would also cause Triton to compile the kernels multiple times.
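A user-side workaround, if only the final state is ultimately needed, is to drive the stack chunk by chunk: per-token outputs then only ever exist for the current chunk, while each layer's recurrent state is carried across chunks. The sketch below is hypothetical; the `layer(h, initial_state=..., output_final_state=True)` call signature, `last_state_forward`, and `chunk_size` are assumptions to be adapted to the actual layer or kernel interface.

```python
import torch

@torch.no_grad()
def last_state_forward(layers, x, chunk_size=2048):
    """x: [B, T, D]. Return only the final recurrent state of every layer.

    Per-token activations exist only for the current chunk, so peak memory
    scales with chunk_size rather than the full sequence length T.
    """
    states = [None] * len(layers)                # one running state per layer
    for start in range(0, x.shape[1], chunk_size):
        h = x[:, start:start + chunk_size]       # chunk of per-token inputs
        for i, layer in enumerate(layers):
            # hypothetical interface: returns (per-token outputs for this chunk,
            # recurrent state after the last token of the chunk)
            h, states[i] = layer(h, initial_state=states[i], output_final_state=True)
    return states
```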