Commit 05c1aa2

committed
Changes to sdpa and attention module to support vision encoder attention with no kv-cache
1 parent 957259e commit 05c1aa2

2 files changed (+25 additions, -11 deletions)

examples/models/llama/source_transformation/sdpa.py

Lines changed: 23 additions & 10 deletions
@@ -9,7 +9,7 @@
 # Example script for exporting Llama2 to flatbuffer

 import math
-from typing import Tuple, Union
+from typing import Tuple, Union, Optional

 import torch

@@ -22,14 +22,16 @@
 class SDPACustom(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: Union[KVCache, QuantizedKVCache],
-        dim: int,
+        kv_cache: Optional[Union[KVCache, QuantizedKVCache]] = None,
+        dim: int = -1,
     ):
         super().__init__()
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
         self.kv_cache = kv_cache
-        if not isinstance(kv_cache, QuantizedKVCache):
+        if kv_cache is None:
+            pass
+        elif not isinstance(kv_cache, QuantizedKVCache):
             self.kv_cache = kv_cache.to(torch.float)
         else:
             assert (
@@ -44,8 +46,8 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         bsz,
-        seqlen,
-        mask,
+        seqlen = None,
+        mask = None,
     ):
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
@@ -54,9 +56,20 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)

-        k_cache = self.kv_cache.k_cache
-        v_cache = self.kv_cache.v_cache
-        if hasattr(self.kv_cache, "quantized_cache_dtype"):
+        k_cache = self.kv_cache.k_cache if self.kv_cache is not None else None
+        v_cache = self.kv_cache.v_cache if self.kv_cache is not None else None
+
+        if self.kv_cache is None:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos,
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                False,  # is_causal
+            )
+        elif isinstance(self.kv_cache, QuantizedKVCache):
             # updated quantize cache, scale and zero points
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next
@@ -99,7 +112,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):


 def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
-    from executorch.extension.llm.custom_ops import custom_ops  # noqa
+    from executorch.extension.llm.custom_ops import custom_ops

     _replace_sdpa_with_custom_op(module)
     return module
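
For context, here is a minimal usage sketch of the new no-KV-cache branch. The import paths, tensor layout, and shapes are assumptions for illustration only and are not taken from this commit; the custom_ops import is included because it registers torch.ops.llama.custom_sdpa.

# Hypothetical sketch: exercising SDPACustom without a KV cache (assumed
# import paths and an assumed (bsz, seqlen, n_heads, head_dim) q/k/v layout).
import torch

from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401  (registers the custom SDPA op)
from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom

sdpa = SDPACustom()  # kv_cache defaults to None, so the new no-cache branch is taken

bsz, seqlen, n_heads, head_dim = 1, 64, 8, 64
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_heads, head_dim)
v = torch.randn(bsz, seqlen, n_heads, head_dim)

# With no cache, custom_sdpa is called with no attention mask and
# is_causal=False, i.e. full (bidirectional) attention over the sequence,
# which is what a vision encoder needs.
out = sdpa(0, q, k, v, bsz, seqlen)

This mirrors the updated call site in extension/llm/modules/attention.py below, where the attention layer now calls self._sdpa(0, q, k, v, b, s_x) with no mask.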

extension/llm/modules/attention.py

Lines changed: 2 additions & 1 deletion
@@ -310,7 +310,7 @@ def false_fn(y):
             self.kv_cache.v_cache.copy_(v)
             self.kv_cache.cache_pos.copy_(cache_pos)

-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        output = self._sdpa(0, q, k, v, b, s_x)
         return self.output_proj(output)


@@ -364,6 +364,7 @@ def forward(
         k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
         v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

+
         output = self._attention_fn(
             q,
             k,
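
In the updated call site above, the mask keyword is dropped and a leading 0 is passed, which lines up with the input_pos parameter of SDPACustom.forward. As a small sanity check (plain PyTorch, not part of the commit and not the ExecuTorch custom op), the sketch below illustrates that non-causal attention without a mask matches attention with an explicit all-True mask, which is the full-attention behaviour a vision encoder without a KV cache expects.

# Illustration only, using the standard PyTorch SDPA rather than
# torch.ops.llama.custom_sdpa: with is_causal=False and attn_mask=None,
# every position attends to every other position, the same result as an
# explicit all-True boolean mask.
import torch
import torch.nn.functional as F

b, heads, s, d = 1, 4, 16, 32
q = torch.randn(b, heads, s, d)
k = torch.randn(b, heads, s, d)
v = torch.randn(b, heads, s, d)

no_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False)
all_true = F.scaled_dot_product_attention(
    q, k, v, attn_mask=torch.ones(s, s, dtype=torch.bool), is_causal=False
)
assert torch.allclose(no_mask, all_true, atol=1e-5)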
