This PR addresses an issue created by Tom Parnell with the following description:
"Currently in the PT2C warmup logic we essentially perform warmup twice, one with as_concat=False and again with as_concat=True. It was implemented this way because we see some differences between "normal" batches and batches that were created from concatenation. The warmup logic essentially tries to cover both of these two cases.
Specifically, the differences between normal batches and post-concat batches are as follows:
1. Post-concat batches always have contiguous PKV tensors, whereas "normal" batches have contiguous PKV tensors almost all of the time but very occasionally (e.g., after the very first token is generated) have non-contiguous PKV tensors.
2. Post-concat batches contain the decoder_attention_mask tensor (for encoder-decoder models) whereas for normal batches it is set to None.
The issue asks for the following work: can we make some small code changes to regularize these two cases?
Since the PKV tensors are only rarely non-contiguous, can't we just force them to be contiguous before calling forward? There is some latency penalty to doing this, but since most of the time it is not needed, we might be OK.
Can we also define the decoder_attention_mask for "normal" batches? Again, perhaps there is some small latency overhead from this which needs to be evaluated.
These changes may incur a potential latency cost but will have the benefit of halving the warmup time. The work here is to (a) implement these changes and (b) verify that the latency overhead is minimal."
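For the first change, forcing the PKV tensors to be contiguous before calling forward, a minimal sketch could look like the following. The helper name and the tuple-of-per-layer-tensors PKV layout are assumptions for illustration, not the code in this PR.

```python
from typing import Tuple

import torch

# Hypothetical alias for the usual past-key-value layout:
# one tuple of key/value tensors per decoder layer.
PastKeyValues = Tuple[Tuple[torch.Tensor, ...], ...]


def ensure_contiguous_pkv(past_key_values: PastKeyValues) -> PastKeyValues:
    """Return past-key-values with every tensor guaranteed contiguous.

    Tensor.contiguous() returns the tensor itself when it is already
    contiguous, so the copy (and its latency cost) is only paid in the
    rare case where a non-contiguous tensor shows up.
    """
    return tuple(
        tuple(t.contiguous() for t in layer)
        for layer in past_key_values
    )
```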
This PR implements both of the small code changes described in the issue; a similar sketch of the second change, defining the decoder_attention_mask for "normal" batches, follows below.
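Here the function name, signature, shape, and dtype are assumptions; the point is simply that an all-ones mask should behave the same as leaving the mask as None, which most encoder-decoder implementations treat as "attend to every position".

```python
import torch


def build_default_decoder_attention_mask(
    batch_size: int, seq_length: int, device: torch.device
) -> torch.Tensor:
    """Materialize an all-ones decoder attention mask.

    Passing this explicit mask for "normal" batches keeps the forward
    signature identical to the post-concat case, at the cost of one
    small tensor allocation per call.
    """
    return torch.ones((batch_size, seq_length), dtype=torch.long, device=device)
```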
Co-authored-by: Thomas Parnell <[email protected]>