
Commit c58b2c0

Changes to sdpa and attention module to support vision encoder attention with no kv-cache

1 parent 957259e

File tree (4 files changed: +99 -17 lines changed)

examples/models/llama/source_transformation/sdpa.py
examples/models/llama3_2_vision/vision_encoder/model.py
examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py
extension/llm/modules/attention.py

examples/models/llama/source_transformation/sdpa.py

Lines changed: 23 additions & 10 deletions

@@ -9,7 +9,7 @@
 # Example script for exporting Llama2 to flatbuffer

 import math
-from typing import Tuple, Union
+from typing import Tuple, Union, Optional

 import torch

@@ -22,14 +22,16 @@
 class SDPACustom(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: Union[KVCache, QuantizedKVCache],
-        dim: int,
+        kv_cache: Optional[Union[KVCache, QuantizedKVCache]] = None,
+        dim: int = -1,
     ):
         super().__init__()
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
         self.kv_cache = kv_cache
-        if not isinstance(kv_cache, QuantizedKVCache):
+        if kv_cache is None:
+            pass
+        elif not isinstance(kv_cache, QuantizedKVCache):
             self.kv_cache = kv_cache.to(torch.float)
         else:
             assert (
@@ -44,8 +46,8 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         bsz,
-        seqlen,
-        mask,
+        seqlen=None,
+        mask=None,
     ):
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
@@ -54,9 +56,20 @@
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)

-        k_cache = self.kv_cache.k_cache
-        v_cache = self.kv_cache.v_cache
-        if hasattr(self.kv_cache, "quantized_cache_dtype"):
+        k_cache = self.kv_cache.k_cache if self.kv_cache is not None else None
+        v_cache = self.kv_cache.v_cache if self.kv_cache is not None else None
+
+        if self.kv_cache is None:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos,
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                False,  # is_causal
+            )
+        elif isinstance(self.kv_cache, QuantizedKVCache):
             # updated quantize cache, scale and zero points
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next
@@ -99,7 +112,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):


 def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
-    from executorch.extension.llm.custom_ops import custom_ops  # noqa
+    from executorch.extension.llm.custom_ops import custom_ops

     _replace_sdpa_with_custom_op(module)
     return module
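
For context, a minimal sketch of the new cache-free path. The tensor shapes and the `[bsz, seqlen, n_heads, head_dim]` layout are assumptions for illustration, not taken from this diff; `custom_ops` must be imported first so that `torch.ops.llama.custom_sdpa` is registered.

```python
import torch

from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401 -- registers torch.ops.llama.custom_sdpa
from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom

# With kv_cache=None, forward() skips the cache update entirely and calls
# custom_sdpa with no mask and is_causal=False -- the vision-encoder case.
sdpa = SDPACustom()

bsz, seqlen, n_heads, head_dim = 1, 32, 8, 64  # illustrative shapes only
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_heads, head_dim)
v = torch.randn(bsz, seqlen, n_heads, head_dim)

# Mirrors the call site added in extension/llm/modules/attention.py:
# self._sdpa(0, q, k, v, b, s_x); mask is left at its new None default.
out = sdpa(0, q, k, v, bsz, seqlen)
```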

examples/models/llama3_2_vision/vision_encoder/model.py

Lines changed: 10 additions & 4 deletions

@@ -15,7 +15,7 @@
     replace_tiled_token_positional_embedding,
 )
 from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_encoder
-
+from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha, replace_sdpa_with_custom_op

 @dataclass
 class VisionEncoderConfig:
@@ -47,7 +47,7 @@ class VisionEncoderConfig:


 class FlamingoVisionEncoderModel(EagerModelBase):
-    def __init__(self, config: Optional[VisionEncoderConfig] = None):
+    def __init__(self, config: Optional[VisionEncoderConfig] = None, enable_source_transforms=True):
         super().__init__()
         if config is None:
             config = demo_config
@@ -64,8 +64,8 @@ def __init__(self, config: Optional[VisionEncoderConfig] = None):
             max_num_tiles=config.max_num_tiles,
             in_channels=config.in_channels,
         )
-        self.model = replace_tile_positional_embedding(self.model)
-        self.model = replace_tiled_token_positional_embedding(self.model)
+        if enable_source_transforms:
+            self.source_transofrm()
         self.image = torch.randn(
             1, 1, 4, 3, self.config.tile_size, self.config.tile_size
         )
@@ -75,6 +75,12 @@ def __init__(self, config: Optional[VisionEncoderConfig] = None):
             self.aspect_ratio,
         )

+    def source_transofrm(self):
+        self.model = replace_tile_positional_embedding(self.model)
+        self.model = replace_tiled_token_positional_embedding(self.model)
+        self.model = replace_mha_with_inference_mha(self.model)
+        self.model = replace_sdpa_with_custom_op(self.model)
+
     def get_eager_model(self, **kwargs):
         return self.model
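
A usage sketch of the new constructor flag, mirroring the parity check in the test below. The import path is assumed from this file's location, and the method name `source_transofrm` is spelled here exactly as committed.

```python
import torch
from torch.testing import assert_close

from executorch.examples.models.llama3_2_vision.vision_encoder.model import (
    FlamingoVisionEncoderModel,
)

with torch.no_grad():
    # Build the encoder without any source transforms applied...
    m = FlamingoVisionEncoderModel(enable_source_transforms=False)
    reference = m.model(*m.get_example_inputs())

    # ...then apply them in place and check numerical parity.
    m.source_transofrm()  # method name as committed
    transformed = m.model(*m.get_example_inputs())
    assert_close(transformed, reference)
```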

examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py

Lines changed: 45 additions & 2 deletions

@@ -16,13 +16,56 @@
     FlamingoVisionEncoderModel,
 )
 from torch.testing import assert_close
-
+from executorch.exir import to_edge, to_edge_transform_and_lower, EdgeCompileConfig
+from torch._inductor.package import package_aoti
+from torch.nn.attention import SDPBackend
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
+    DuplicateDynamicQuantChainPass,
+)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+    XnnpackDynamicallyQuantizedPartitioner,
+)

 class FlamingoVisionEncoderTest(unittest.TestCase):
     def setUp(self) -> None:
         super().setUp()

-    def test_flamingo_vision_encoder(self) -> None:
+    def test_flamingo_vision_encoder_et(self) -> None:
+        with torch.no_grad():
+            vision_model = FlamingoVisionEncoderModel(enable_source_transforms=False)
+            encoder_no_source_transform_outputs = vision_model.model.forward(*vision_model.get_example_inputs())
+            vision_model.source_transofrm()
+            encoder = vision_model.model
+            encoder_source_transform_outputs = encoder.forward(*vision_model.get_example_inputs())
+            assert_close(encoder_source_transform_outputs, encoder_no_source_transform_outputs)
+
+        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
+            training_output = torch.export.export_for_training(encoder, vision_model.get_example_inputs(), dynamic_shapes=vision_model.get_dynamic_shapes())
+            assert_close(encoder(*vision_model.get_example_inputs()), training_output.module()(*vision_model.get_example_inputs()))
+
+            dynamic_quantizer = XNNPACKQuantizer()
+            operator_config_dynamic = get_symmetric_quantization_config(
+                is_per_channel=True, is_dynamic=True
+            )
+            dynamic_quantizer.set_global(operator_config_dynamic)
+            prepare = prepare_pt2e(training_output.module(), dynamic_quantizer)
+            prepare(*vision_model.get_example_inputs())
+            convert = convert_pt2e(prepare)
+            DuplicateDynamicQuantChainPass()(convert)
+
+            export_output = torch.export.export(convert, vision_model.get_example_inputs(), dynamic_shapes=vision_model.get_dynamic_shapes())
+
+            edge = to_edge_transform_and_lower(export_output, partitioner=[
+                XnnpackDynamicallyQuantizedPartitioner(),
+            ], compile_config=EdgeCompileConfig(_check_ir_validity=False))
+            edge.to_executorch()
+
+    def test_flamingo_vision_encoder_aoti(self) -> None:
         model = FlamingoVisionEncoderModel()
         encoder = model.model
         eager_res = encoder.forward(*model.get_example_inputs())
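
The new `_et` test stops at `edge.to_executorch()` and discards the result. If one wanted the serialized program, a hypothetical continuation might look like the following; the filename is illustrative.

```python
# Continuing from the test above: `edge` is the result of
# to_edge_transform_and_lower(...).
et_program = edge.to_executorch()

# Write the flatbuffer out for the ExecuTorch runtime (illustrative path).
with open("/tmp/vision_encoder_xnnpack.pte", "wb") as f:
    f.write(et_program.buffer)
```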

extension/llm/modules/attention.py

Lines changed: 21 additions & 1 deletion

@@ -13,6 +13,7 @@
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
+from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom

 logger = logging.getLogger(__name__)

@@ -310,7 +311,7 @@ def false_fn(y):
             self.kv_cache.v_cache.copy_(v)
             self.kv_cache.cache_pos.copy_(cache_pos)

-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        output = self._sdpa(0, q, k, v, b, s_x)
         return self.output_proj(output)

@@ -364,6 +365,7 @@ def forward(
         k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
         v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

+
         output = self._attention_fn(
             q,
             k,

@@ -411,3 +413,21 @@ def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
     """
     _replace_mha_with_inference_mha(module)
     return module
+
+
+def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPACustom(),
+            )
+        else:
+            _replace_sdpa_with_custom_op(child)
+
+
+def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+    from executorch.extension.llm.custom_ops import custom_ops
+    _replace_sdpa_with_custom_op(module)
+    return module
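
Taken together, the two public helpers in this module are meant to be applied in sequence, as `FlamingoVisionEncoderModel.source_transofrm` now does. A minimal sketch of that order on any torchtune-style module tree containing `SDPA` children; `apply_attention_transforms` is a hypothetical helper, not part of this commit.

```python
import torch.nn as nn

from executorch.extension.llm.modules.attention import (
    replace_mha_with_inference_mha,
    replace_sdpa_with_custom_op,
)

def apply_attention_transforms(model: nn.Module) -> nn.Module:
    # First swap torchtune MultiHeadAttention for the inference-friendly
    # variant defined in this module, which owns the SDPA submodules...
    model = replace_mha_with_inference_mha(model)
    # ...then swap each SDPA child for SDPACustom() with no kv-cache, so
    # attention lowers to torch.ops.llama.custom_sdpa (non-causal, no mask).
    return replace_sdpa_with_custom_op(model)
```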
