Title: Pipeline synchronization issue in Flash Attention - S1 pipeline held during O0 rescale
Description:
While working on a Flash Attention implementation using CUTLASS pipelines, I ran into a synchronization behavior that I don't understand.
In the correction function, the S1 pipeline stage appears to be held (not released) until after the O0 rescale (rescale(O0)) completes. The same happens with S0 and the O1 rescale. This seems to create unnecessary pipeline stalls, especially with causal/local masks.
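To make the suspected cost concrete, here is a toy timing model comparing two orderings: releasing the S1 stage right after the O1 rescale, versus holding it until the O0 rescale completes. Everything in it is a hypothetical illustration: `StageCosts`, `simulate_single_stage`, and the cost numbers are made up for this sketch, it models a single-stage pipeline, and it does not use or reproduce the CUTLASS pipeline API.

```cpp
#include <algorithm>
#include <cassert>

// Toy timing model of a single-stage producer/consumer pipeline (hypothetical;
// not the CUTLASS pipeline API). The producer stands in for the warps that
// fill the S1 stage; the consumer stands in for the correction warps, which
// wait on S1, rescale O1, then rescale O0 before the next iteration.
// All costs are abstract time units chosen for illustration, not measurements.
struct StageCosts {
  int produce;     // producer time to fill the S1 stage once it is free
  int rescale_o1;  // consumer time for the O1 rescale
  int rescale_o0;  // consumer time for the O0 rescale
};

// Returns the time at which the consumer finishes its last iteration.
// release_after_o0 == false: release S1 right after the O1 rescale.
// release_after_o0 == true:  hold S1 until the O0 rescale completes.
int simulate_single_stage(const StageCosts& c, int iterations, bool release_after_o0) {
  assert(iterations > 0 && c.produce >= 0 && c.rescale_o1 >= 0 && c.rescale_o0 >= 0);

  int producer_done = 0;  // time the producer last signaled "stage full"
  int released = 0;       // time the consumer last released the stage
  int consumer_done = 0;  // time the consumer finished its previous iteration

  for (int i = 0; i < iterations; ++i) {
    // Producer acquire: the single stage must be released before it is refilled,
    // and the producer cannot start before finishing its previous fill.
    int producer_start = std::max(producer_done, released);
    producer_done = producer_start + c.produce;

    // Consumer wait: needs the stage full and its own previous iteration done.
    int consumer_start = std::max(producer_done, consumer_done);
    int after_o1 = consumer_start + c.rescale_o1;
    int after_o0 = after_o1 + c.rescale_o0;

    // The only difference between the two orderings is when the release happens.
    released = release_after_o0 ? after_o0 : after_o1;
    consumer_done = after_o0;
  }
  return consumer_done;
}
```

With `StageCosts{4, 2, 2}` and four iterations, releasing after the O1 rescale finishes at 26 units and holding until after the O0 rescale finishes at 32, because the early release lets the next S1 fill overlap the O0 rescale. The model assumes, as the question does, that nothing touched during the O0 rescale depends on the S1 stage; if such a dependency exists, the early release would be incorrect regardless of the time it saves.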
Code context:
```cpp
// After consuming S1 pipeline state
pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);
correction_rescale(scale, uint32_t(TmemAllocation::O1));
// ...
// During O0 rescale - S1 is NOT released here
// pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); // Commented out
// ++pipeline_s1_c_consumer_state;
// S1 only gets released after O0 rescale
correction_rescale(scale, uint32_t(TmemAllocation::O0));
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
++pipeline_s1_c_consumer_state;
```
Questions:
- Why can't S1 be released immediately after it's consumed, before the O0 rescale?
- Is there a dependency between S1 and O0 that I'm missing?