[llama-mm] Add torch.cond to replace if condition in MHA (#6869)
* [llama-mm] Add torch.cond to replace if condition in MHA
Summary:
In torchtune's MultiHeadAttention we have this logic:
If `y` is not None, compute the values of `k` and `v` from `y` and update the KVCache.
Otherwise (if `y` is None), read the values of `k` and `v` from the KVCache.
This data-dependent branch on an optional input cannot be handled by the export flow (torch.export). Here I'm proposing a rewrite:
If `y` is not entirely nan (not a number), compute the values of `k` and `v` from `y` and update the KVCache.
Otherwise (if all values of `y` are nan), read the values of `k` and `v` from the KVCache.
This rewrite allows the module to satisfy the requirement of `torch.cond` and avoid specialization:
* The operands to `torch.cond` must have the same shape in the true branch and in the false branch.
This means we will have to change this logic in torchtune:
```
if encoder_input is not None:
    encoder_embed = self.encoder(**encoder_input)
output = self.decoder(
    tokens=tokens,
    mask=mask,
    encoder_input=encoder_embed,
    encoder_mask=encoder_mask,
    input_pos=input_pos,
)
```
To be:
```
if encoder_input is not None:
    encoder_embed = self.encoder(**encoder_input)
else:
    # `encoder_input` is None here, so fill a placeholder of the encoder
    # output shape with nan (`encoder_embed_shape` is illustrative; the
    # shape must be known ahead of time).
    encoder_embed = torch.full(encoder_embed_shape, torch.nan)
output = self.decoder(
    tokens=tokens,
    mask=mask,
    encoder_input=encoder_embed,
    encoder_mask=encoder_mask,
    input_pos=input_pos,
)
```
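To make the MHA side of the rewrite concrete, here is a minimal, self-contained sketch of branching on an all-nan `y` with `torch.cond`. The names, shapes, and projections are illustrative only (they are not the actual torchtune/ExecuTorch code), and the real change also has to update the KVCache, which needs extra care inside `torch.cond` branches:
```
import torch

# Illustrative sizes only.
B, S, H, D = 2, 8, 4, 16  # batch, seq_len, num_heads, head_dim

k_proj = torch.nn.Linear(H * D, H * D)
v_proj = torch.nn.Linear(H * D, H * D)

def _kv_from_y(y, k_cache, v_cache):
    # True branch: `y` holds real values, so project it into k/v.
    k = k_proj(y).view(B, S, H, D)
    v = v_proj(y).view(B, S, H, D)
    return k, v

def _kv_from_cache(y, k_cache, v_cache):
    # False branch: `y` is all nan, so fall back to the cached k/v.
    # Outputs have the same shape/dtype as the true branch and are cloned
    # so the branch does not alias its inputs.
    return k_cache.clone(), v_cache.clone()

def kv(y, k_cache, v_cache):
    # The predicate is a tensor check instead of `y is None`, so both code
    # paths see identically shaped operands and the branch is traceable.
    return torch.cond(
        torch.isnan(y).all().logical_not(),
        _kv_from_y,
        _kv_from_cache,
        (y, k_cache, v_cache),
    )

k_cache = torch.randn(B, S, H, D)
v_cache = torch.randn(B, S, H, D)

y = torch.randn(B, S, H * D)            # first call: real input
k, v = kv(y, k_cache, v_cache)

y_nan = torch.full_like(y, torch.nan)   # later calls: nan placeholder
k2, v2 = kv(y_nan, k_cache, v_cache)
```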
Test Plan: Rely on unit tests
Reviewers:
Subscribers:
Tasks:
Tags:
* Add test
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Different from the vanilla torchtune MHA, we rewrite the if condition with torch.cond, so we need to make sure both versions give the same results around that branch.
For the first run of MHA we provide `y` (self.x), but for the second run it is a tensor full of nan.
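As a rough illustration of what such a check might look like (the projection, shapes, and function names are placeholders, not the actual unit test added in this PR): run the vanilla None-based logic and the torch.cond-based logic side by side, once with a real `y` and once with `None` / an all-nan `y`, and compare the results:
```
import torch

torch.manual_seed(0)
B, S, E = 2, 8, 64                        # illustrative batch, seq_len, embed_dim
proj = torch.nn.Linear(E, 2 * E)          # joint k/v projection, illustrative

def kv_vanilla(y, k_cache, v_cache):
    # torchtune-style logic: a Python `if` on None.
    if y is None:
        return k_cache, v_cache
    k, v = proj(y).chunk(2, dim=-1)
    return k.contiguous(), v.contiguous()

def kv_cond(y, k_cache, v_cache):
    # Rewritten logic: branch on "is `y` all nan" via torch.cond.
    def use_y(y, k_cache, v_cache):
        k, v = proj(y).chunk(2, dim=-1)
        return k.contiguous(), v.contiguous()

    def use_cache(y, k_cache, v_cache):
        return k_cache.clone(), v_cache.clone()

    pred = torch.isnan(y).all().logical_not()
    return torch.cond(pred, use_y, use_cache, (y, k_cache, v_cache))

k_cache, v_cache = torch.randn(B, S, E), torch.randn(B, S, E)
y = torch.randn(B, S, E)

# First run: real `y`, both versions compute k/v from `y`.
for a, b in zip(kv_vanilla(y, k_cache, v_cache), kv_cond(y, k_cache, v_cache)):
    torch.testing.assert_close(a, b)

# Second run: `None` vs. an all-nan `y`, both fall back to the cache.
for a, b in zip(
    kv_vanilla(None, k_cache, v_cache),
    kv_cond(torch.full_like(y, torch.nan), k_cache, v_cache),
):
    torch.testing.assert_close(a, b)
```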