Commit 9bdd6dc
ssjia
Update on "[ET-VK][ez] Fuse update_cache + custom_sdpa into sdpa_with_kv_cache"
SDPA used to be handled by a custom op `sdpa_with_kv_cache`, but it was eventually split (D62301837) into update_cache and custom_sdpa ops.
However, having a single fused op is useful for Vulkan since it allows more control over how the cache tensors are stored and represented. Essentially, it makes it easier to manage the cache tensors and opens up opportunities for future optimizations. This diff introduces a fusion pass that does 2 things:
1. Combine update_cache and custom_sdpa back into sdpa_with_kv_cache
2. Ensure all references to the cache_pos symint use the same node - this prevents the select_at_dim_as_symint op from being called every time it is used.
Differential Revision: [D86340339](https://our.internmc.facebook.com/intern/diff/D86340339/)
[ghstack-poisoned]2 files changed
+20
-19
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
6 | 6 | | |
7 | 7 | | |
8 | 8 | | |
| 9 | + | |
| 10 | + | |
9 | 11 | | |
10 | 12 | | |
11 | 13 | | |
| |||
15 | 17 | | |
16 | 18 | | |
17 | 19 | | |
18 | | - | |
19 | 20 | | |
20 | 21 | | |
21 | 22 | | |
22 | | - | |
23 | | - | |
| 23 | + | |
24 | 24 | | |
25 | | - | |
26 | | - | |
27 | | - | |
28 | | - | |
29 | | - | |
30 | | - | |
31 | 25 | | |
| 26 | + | |
| 27 | + | |
32 | 28 | | |
33 | | - | |
34 | | - | |
35 | | - | |
36 | 29 | | |
37 | | - | |
38 | | - | |
39 | | - | |
40 | | - | |
41 | | - | |
42 | | - | |
| 30 | + | |
| 31 | + | |
43 | 32 | | |
44 | 33 | | |
45 | 34 | | |
| |||
97 | 86 | | |
98 | 87 | | |
99 | 88 | | |
100 | | - | |
| 89 | + | |
101 | 90 | | |
102 | 91 | | |
103 | 92 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
373 | 373 | | |
374 | 374 | | |
375 | 375 | | |
| 376 | + | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
376 | 388 | | |
377 | 389 | | |
378 | 390 | | |
| |||
0 commit comments