Commit 4c138ef
[llama-mm] Add torch.cond to replace if condition in MHA
Summary:
In torchtune's MultiHeadAttention we have this logic:
If `y` is not None, calculate the values of `k` and `v` from y and
update the KVCache.
Otherwise (if `y` is None), retrieve the value of `k` and `v` from
KVCache.
torch.export cannot capture this logic: the `y is None` check is
resolved at trace time, so the exported graph specializes on a single
branch. Here I'm proposing a rewrite:
If `y` is not entirely NaN (not a number), calculate the values of `k`
and `v` from `y` and update the KVCache.
Otherwise (if every value of `y` is NaN), retrieve the values of `k`
and `v` from the KVCache.
This rewrite allows the module to satisfy the requirement of
`torch.cond` and avoid specialization:
* The operands to `torch.cond` should have the same shape for the true
branch and the false branch.
This means we will have to change this logic in torchtune:
```
if encoder_input is not None:
    encoder_embed = self.encoder(**encoder_input)
output = self.decoder(
    tokens=tokens,
    mask=mask,
    encoder_input=encoder_embed,
    encoder_mask=encoder_mask,
    input_pos=input_pos,
)
```
To be:
```
if encoder_input is not None:
    encoder_embed = self.encoder(**encoder_input)
else:
    encoder_embed = torch.full_like(encoder_input, torch.nan)
output = self.decoder(
    tokens=tokens,
    mask=mask,
    encoder_input=encoder_embed,
    encoder_mask=encoder_mask,
    input_pos=input_pos,
)
```
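One way the placeholder could be built and detected is sketched below. The helper names `encoder_embed_or_nan` and `is_placeholder` are hypothetical, and passing the embedding shape explicitly is an assumption made for this sketch; the commit message does not say where the shape comes from when no encoder input is given.

```python
import torch


def encoder_embed_or_nan(encoder_embed, batch_size, seq_len, embed_dim,
                         dtype=torch.float32):
    # Pass a real embedding through unchanged.
    if encoder_embed is not None:
        return encoder_embed
    # Otherwise build an all-NaN tensor with the shape the decoder expects,
    # so the decoder always receives a tensor of the same shape.
    return torch.full((batch_size, seq_len, embed_dim), torch.nan, dtype=dtype)


def is_placeholder(t):
    # True only when every element is NaN, mirroring the check in the rewrite.
    return bool(torch.isnan(t).all())


placeholder = encoder_embed_or_nan(None, batch_size=2, seq_len=3, embed_dim=4)
real = torch.zeros(2, 3, 4)
passed_through = encoder_embed_or_nan(real, batch_size=2, seq_len=3, embed_dim=4)
```

Here `is_placeholder(placeholder)` holds while `is_placeholder(real)` does not, so a genuine embedding is only mistaken for the placeholder if every one of its values is NaN.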
Test Plan: Rely on unit tests
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 71612a6 commit 4c138ef
3 files changed, +51 -12 lines changed
- extension/llm/modules
- test