Skip to content

Commit c2ee97e

Browse files
committed
Changes to sdpa and attention module to support vision encoder attention with no kv-cache
1 parent 957259e commit c2ee97e

File tree

4 files changed

+89
-14
lines changed

4 files changed

+89
-14
lines changed

examples/models/llama/source_transformation/sdpa.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# Example script for exporting Llama2 to flatbuffer
1010

1111
import math
12-
from typing import Tuple, Union
12+
from typing import Tuple, Union, Optional
1313

1414
import torch
1515

@@ -22,14 +22,16 @@
2222
class SDPACustom(torch.nn.Module):
2323
def __init__(
2424
self,
25-
kv_cache: Union[KVCache, QuantizedKVCache],
26-
dim: int,
25+
kv_cache: Optional[Union[KVCache, QuantizedKVCache]] = None,
26+
dim: int = -1,
2727
):
2828
super().__init__()
2929
# Custom op only supports float32 currently. Converting to/from float32 is
3030
# faster than not having the op.
3131
self.kv_cache = kv_cache
32-
if not isinstance(kv_cache, QuantizedKVCache):
32+
if kv_cache is None:
33+
pass
34+
elif not isinstance(kv_cache, QuantizedKVCache):
3335
self.kv_cache = kv_cache.to(torch.float)
3436
else:
3537
assert (
@@ -44,8 +46,8 @@ def forward(
4446
k: torch.Tensor,
4547
v: torch.Tensor,
4648
bsz,
47-
seqlen,
48-
mask,
49+
seqlen = None,
50+
mask = None,
4951
):
5052
# Custom op only supports float32 currently. Converting to/from float32 is
5153
# faster than not having the op.
@@ -54,9 +56,20 @@ def forward(
5456
k = k.to(dtype=torch.float)
5557
v = v.to(dtype=torch.float)
5658

57-
k_cache = self.kv_cache.k_cache
58-
v_cache = self.kv_cache.v_cache
59-
if hasattr(self.kv_cache, "quantized_cache_dtype"):
59+
k_cache = self.kv_cache.k_cache if self.kv_cache is not None else None
60+
v_cache = self.kv_cache.v_cache if self.kv_cache is not None else None
61+
62+
if self.kv_cache is None:
63+
output = torch.ops.llama.custom_sdpa(
64+
q,
65+
k,
66+
v,
67+
input_pos,
68+
None, # Attention mask
69+
0, # dropout probability. Ignored by the code
70+
False, # is_causal
71+
)
72+
elif isinstance(self.kv_cache, QuantizedKVCache):
6073
# updated quantize cache, scale and zero points
6174
# returns dequantized kv cache
6275
# Not most optimal. Optimizations to follow next
@@ -99,7 +112,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
99112

100113

101114
def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
102-
from executorch.extension.llm.custom_ops import custom_ops # noqa
115+
from executorch.extension.llm.custom_ops import custom_ops
103116

104117
_replace_sdpa_with_custom_op(module)
105118
return module

examples/models/llama3_2_vision/vision_encoder/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
replace_tiled_token_positional_embedding,
1616
)
1717
from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_encoder
18-
18+
from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha, replace_sdpa_with_custom_op
1919

2020
@dataclass
2121
class VisionEncoderConfig:
@@ -66,6 +66,8 @@ def __init__(self, config: Optional[VisionEncoderConfig] = None):
6666
)
6767
self.model = replace_tile_positional_embedding(self.model)
6868
self.model = replace_tiled_token_positional_embedding(self.model)
69+
self.model = replace_mha_with_inference_mha(self.model)
70+
self.model = replace_sdpa_with_custom_op(self.model)
6971
self.image = torch.randn(
7072
1, 1, 4, 3, self.config.tile_size, self.config.tile_size
7173
)

examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,53 @@
1616
FlamingoVisionEncoderModel,
1717
)
1818
from torch.testing import assert_close
19-
19+
from executorch.exir import to_edge, to_edge_transform_and_lower, EdgeCompileConfig
20+
from torch._inductor.package import package_aoti
21+
from torch.nn.attention import SDPBackend
22+
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
23+
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
24+
get_symmetric_quantization_config,
25+
XNNPACKQuantizer,
26+
)
27+
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
28+
DuplicateDynamicQuantChainPass
29+
)
30+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
31+
XnnpackDynamicallyQuantizedPartitioner,
32+
)
2033

2134
class FlamingoVisionEncoderTest(unittest.TestCase):
2235
def setUp(self) -> None:
2336
super().setUp()
2437

25-
def test_flamingo_vision_encoder(self) -> None:
38+
def test_flamingo_vision_encoder_et(self) -> None:
39+
with torch.no_grad():
40+
model = FlamingoVisionEncoderModel()
41+
encoder = model.model
42+
encoder.forward(*model.get_example_inputs())
43+
44+
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
45+
training_output = torch.export.export_for_training(encoder, model.get_example_inputs(), dynamic_shapes=model.get_dynamic_shapes())
46+
assert_close(encoder(*model.get_example_inputs()), training_output.module()(*model.get_example_inputs()))
47+
48+
dynamic_quantizer = XNNPACKQuantizer()
49+
operator_config_dynamic = get_symmetric_quantization_config(
50+
is_per_channel=True, is_dynamic=True
51+
)
52+
dynamic_quantizer.set_global(operator_config_dynamic)
53+
prepare = prepare_pt2e(training_output.module(), dynamic_quantizer)
54+
prepare(*model.get_example_inputs())
55+
convert = convert_pt2e(prepare)
56+
DuplicateDynamicQuantChainPass()(convert)
57+
58+
export_output = torch.export.export(convert, model.get_example_inputs(), dynamic_shapes=model.get_dynamic_shapes())
59+
60+
edge = to_edge_transform_and_lower(export_output, partitioner=[
61+
XnnpackDynamicallyQuantizedPartitioner(),
62+
], compile_config=EdgeCompileConfig(_check_ir_validity=False))
63+
edge.to_executorch()
64+
65+
def test_flamingo_vision_encoder_aoti(self) -> None:
2666
model = FlamingoVisionEncoderModel()
2767
encoder = model.model
2868
eager_res = encoder.forward(*model.get_example_inputs())

extension/llm/modules/attention.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from torch import nn
1414
from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
1515
from torchtune.modules.kv_cache import KVCache
16+
from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom
1617

1718
logger = logging.getLogger(__name__)
1819

@@ -310,7 +311,7 @@ def false_fn(y):
310311
self.kv_cache.v_cache.copy_(v)
311312
self.kv_cache.cache_pos.copy_(cache_pos)
312313

313-
output = self._sdpa(q, k, v, b, s_x, mask=mask)
314+
output = self._sdpa(0, q, k, v, b, s_x)
314315
return self.output_proj(output)
315316

316317

@@ -364,6 +365,7 @@ def forward(
364365
k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
365366
v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
366367

368+
367369
output = self._attention_fn(
368370
q,
369371
k,
@@ -411,3 +413,21 @@ def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
411413
"""
412414
_replace_mha_with_inference_mha(module)
413415
return module
416+
417+
418+
def _replace_sdpa_with_custom_op(module: torch.nn.Module):
419+
for name, child in module.named_children():
420+
if isinstance(child, SDPA):
421+
setattr(
422+
module,
423+
name,
424+
SDPACustom(),
425+
)
426+
else:
427+
_replace_sdpa_with_custom_op(child)
428+
429+
430+
def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
431+
from executorch.extension.llm.custom_ops import custom_ops
432+
_replace_sdpa_with_custom_op(module)
433+
return module

0 commit comments

Comments
 (0)