Commit e5fdb66

Changes to sdpa and attention module to support vision encoder attention with no kv-cache
1 parent 68c0208 commit e5fdb66

4 files changed: +105 −19 lines

examples/models/llama/source_transformation/sdpa.py

Lines changed: 27 additions & 12 deletions
@@ -9,7 +9,7 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import math
-from typing import Tuple, Union
+from typing import Tuple, Union, Optional
 
 import torch
 
@@ -22,20 +22,24 @@
 class SDPACustom(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: Union[KVCache, QuantizedKVCache],
-        dim: int,
+        kv_cache: Optional[Union[KVCache, QuantizedKVCache]] = None,
+        dim: int = -1,
+        is_causal=True,
     ):
         super().__init__()
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
         self.kv_cache = kv_cache
-        if not isinstance(kv_cache, QuantizedKVCache):
+        if kv_cache is None:
+            pass
+        elif not isinstance(kv_cache, QuantizedKVCache):
             self.kv_cache = kv_cache.to(torch.float)
         else:
             assert (
                 kv_cache.cache_fp_type == torch.float32
             ), "Only float32 is supported for custom SDPA"
         self.dim = dim
+        self.is_causal = is_causal
 
     def forward(
         self,
@@ -44,8 +48,8 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         bsz,
-        seqlen,
-        mask,
+        seqlen=None,
+        mask=None,
     ):
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
@@ -54,9 +58,20 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
 
-        k_cache = self.kv_cache.k_cache
-        v_cache = self.kv_cache.v_cache
-        if hasattr(self.kv_cache, "quantized_cache_dtype"):
+        k_cache = self.kv_cache.k_cache if self.kv_cache is not None else None
+        v_cache = self.kv_cache.v_cache if self.kv_cache is not None else None
+
+        if self.kv_cache is None:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos,
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                self.is_causal,  # is_causal
+            )
+        elif isinstance(self.kv_cache, QuantizedKVCache):
             # updated quantize cache, scale and zero points
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next
@@ -68,7 +83,7 @@ def forward(
                 input_pos[0].item(),
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
             )
         else:
             output = torch.ops.llama.sdpa_with_kv_cache(
@@ -81,7 +96,7 @@ def forward(
                 seqlen,
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
             )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
@@ -99,7 +114,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
 
 
 def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
-    from executorch.extension.llm.custom_ops import custom_ops  # noqa
+    from executorch.extension.llm.custom_ops import custom_ops
 
     _replace_sdpa_with_custom_op(module)
     return module
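
A quick orientation on what this change enables: with kv_cache=None, SDPACustom skips all cache handling and routes straight to torch.ops.llama.custom_sdpa, with causality controlled by the new is_causal flag. A minimal sketch of constructing such an instance for bidirectional (vision encoder) attention; this assumes an ExecuTorch development install, and the commented-out call additionally needs the custom-op library loaded:

    import torch
    from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom

    # Cache-free, non-causal instance: forward() takes the new kv_cache-is-None
    # branch and calls torch.ops.llama.custom_sdpa with no cache update.
    # dim=-1 (the new default) lets the final .view(bsz, seqlen, -1) infer the
    # embedding dimension.
    vision_sdpa = SDPACustom(kv_cache=None, dim=-1, is_causal=False)

    # Calling it requires the custom op to be registered first, e.g.:
    #   from executorch.extension.llm.custom_ops import custom_ops
    #   out = vision_sdpa(input_pos, q, k, v, bsz, seqlen)  # q/k/v per the op's layout

The decoder paths are unchanged apart from reading self.is_causal (default True) instead of a hard-coded True.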

examples/models/llama3_2_vision/vision_encoder/model.py

Lines changed: 10 additions & 4 deletions
@@ -15,7 +15,7 @@
     replace_tiled_token_positional_embedding,
 )
 from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_encoder
-
+from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha, replace_sdpa_with_custom_op
 
 @dataclass
 class VisionEncoderConfig:
@@ -47,7 +47,7 @@ class VisionEncoderConfig:
 
 
 class FlamingoVisionEncoderModel(EagerModelBase):
-    def __init__(self, config: Optional[VisionEncoderConfig] = None):
+    def __init__(self, config: Optional[VisionEncoderConfig] = None, enable_source_transforms=True):
         super().__init__()
         if config is None:
             config = demo_config
@@ -64,8 +64,8 @@ def __init__(self, config: Optional[VisionEncoderConfig] = None):
             max_num_tiles=config.max_num_tiles,
             in_channels=config.in_channels,
         )
-        self.model = replace_tile_positional_embedding(self.model)
-        self.model = replace_tiled_token_positional_embedding(self.model)
+        if enable_source_transforms:
+            self.source_transofrm()
         self.image = torch.randn(
             1, 1, 4, 3, self.config.tile_size, self.config.tile_size
         )
@@ -75,6 +75,12 @@ def __init__(self, config: Optional[VisionEncoderConfig] = None):
             self.aspect_ratio,
        )
 
+    def source_transofrm(self):
+        self.model = replace_tile_positional_embedding(self.model)
+        self.model = replace_tiled_token_positional_embedding(self.model)
+        self.model = replace_mha_with_inference_mha(self.model)
+        self.model = replace_sdpa_with_custom_op(self.model)
+
     def get_eager_model(self, **kwargs):
         return self.model
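
The enable_source_transforms flag exists so callers can compare the untransformed eager model against the transformed one, which the new test does. A sketch of that pattern (assuming torchtune and the vision-encoder dependencies are installed; note the method name is spelled source_transofrm in this commit):

    from executorch.examples.models.llama3_2_vision.vision_encoder.model import (
        FlamingoVisionEncoderModel,
    )

    # Build without rewrites to capture baseline eager outputs.
    vision_model = FlamingoVisionEncoderModel(enable_source_transforms=False)
    baseline = vision_model.model.forward(*vision_model.get_example_inputs())

    # Apply the positional-embedding, MHA, and SDPA replacements in place.
    vision_model.source_transofrm()
    transformed = vision_model.model.forward(*vision_model.get_example_inputs())
    # The two outputs are expected to match numerically (see the test below).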

examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py

Lines changed: 45 additions & 2 deletions
@@ -16,13 +16,56 @@
     FlamingoVisionEncoderModel,
 )
 from torch.testing import assert_close
-
+from executorch.exir import to_edge, to_edge_transform_and_lower, EdgeCompileConfig
+from torch._inductor.package import package_aoti
+from torch.nn.attention import SDPBackend
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
+    DuplicateDynamicQuantChainPass,
+)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+    XnnpackDynamicallyQuantizedPartitioner,
+)
 
 class FlamingoVisionEncoderTest(unittest.TestCase):
     def setUp(self) -> None:
         super().setUp()
 
-    def test_flamingo_vision_encoder(self) -> None:
+    def test_flamingo_vision_encoder_et(self) -> None:
+        with torch.no_grad():
+            vision_model = FlamingoVisionEncoderModel(enable_source_transforms=False)
+            encoder_no_source_transform_outputs = vision_model.model.forward(*vision_model.get_example_inputs())
+            vision_model.source_transofrm()
+            encoder = vision_model.model
+            encoder_source_transform_outputs = encoder.forward(*vision_model.get_example_inputs())
+            assert_close(encoder_source_transform_outputs, encoder_no_source_transform_outputs)
+
+        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
+            training_output = torch.export.export_for_training(encoder, vision_model.get_example_inputs(), dynamic_shapes=vision_model.get_dynamic_shapes())
+            assert_close(encoder(*vision_model.get_example_inputs()), training_output.module()(*vision_model.get_example_inputs()))
+
+            dynamic_quantizer = XNNPACKQuantizer()
+            operator_config_dynamic = get_symmetric_quantization_config(
+                is_per_channel=True, is_dynamic=True
+            )
+            dynamic_quantizer.set_global(operator_config_dynamic)
+            prepare = prepare_pt2e(training_output.module(), dynamic_quantizer)
+            prepare(*vision_model.get_example_inputs())
+            convert = convert_pt2e(prepare)
+            DuplicateDynamicQuantChainPass()(convert)
+
+            export_output = torch.export.export(convert, vision_model.get_example_inputs(), dynamic_shapes=vision_model.get_dynamic_shapes())
+
+            edge = to_edge_transform_and_lower(export_output, partitioner=[
+                XnnpackDynamicallyQuantizedPartitioner(),
+            ], compile_config=EdgeCompileConfig(_check_ir_validity=False))
+            edge.to_executorch()
+
+    def test_flamingo_vision_encoder_aoti(self) -> None:
         model = FlamingoVisionEncoderModel()
         encoder = model.model
         eager_res = encoder.forward(*model.get_example_inputs())
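
The new *_et test follows the standard PT2E dynamic-quantization recipe before lowering to XNNPACK. Distilled to a skeleton (model and example_inputs here are placeholder names, not identifiers from the commit):

    import torch
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    def quantize_pt2e(model: torch.nn.Module, example_inputs: tuple) -> torch.nn.Module:
        # 1. Capture an ATen-level graph that quantization passes can annotate.
        captured = torch.export.export_for_training(model, example_inputs).module()
        # 2. Insert observers for dynamic, per-channel int8 quantization.
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
            get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
        )
        prepared = prepare_pt2e(captured, quantizer)
        prepared(*example_inputs)  # one calibration pass
        # 3. Rewrite observed ops into quantize/dequantize pairs.
        return convert_pt2e(prepared)

After this, the test runs DuplicateDynamicQuantChainPass, re-exports, and lowers with XnnpackDynamicallyQuantizedPartitioner, exactly as shown above.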

extension/llm/modules/attention.py

Lines changed: 23 additions & 1 deletion
@@ -13,6 +13,7 @@
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
+from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom
 
 logger = logging.getLogger(__name__)
 
@@ -310,7 +311,9 @@ def false_fn(y):
             self.kv_cache.v_cache.copy_(v)
             self.kv_cache.cache_pos.copy_(cache_pos)
 
-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        if input_pos is None:
+            input_pos = torch.tensor(0)
+        output = self._sdpa(input_pos, q, k, v, b, s_x, mask=mask)
         return self.output_proj(output)
 
 
@@ -364,6 +367,7 @@ def forward(
         k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
         v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
 
+
         output = self._attention_fn(
             q,
             k,
@@ -411,3 +415,21 @@ def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
     """
     _replace_mha_with_inference_mha(module)
     return module
+
+
+def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPACustom(is_causal=child.is_causal),
+            )
+        else:
+            _replace_sdpa_with_custom_op(child)
+
+
+def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+    from executorch.extension.llm.custom_ops import custom_ops
+    _replace_sdpa_with_custom_op(module)
+    return module
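
The new pass recursively walks the module tree and swaps every torchtune SDPA submodule for SDPACustom, carrying over only its is_causal flag (kv_cache stays None and dim stays -1, i.e. the cache-free path added in sdpa.py). A sketch of how the two replacement passes are meant to compose, mirroring source_transofrm() in model.py (the ordering rationale is my reading, not stated in the commit):

    import torch
    from executorch.extension.llm.modules.attention import (
        replace_mha_with_inference_mha,
        replace_sdpa_with_custom_op,
    )

    def apply_attention_transforms(model: torch.nn.Module) -> torch.nn.Module:
        # Replace torchtune MultiHeadAttention with the inference variant first,
        # so the SDPA children the second pass looks for actually exist.
        model = replace_mha_with_inference_mha(model)
        # Then swap each SDPA child for the ExecuTorch custom-op wrapper.
        return replace_sdpa_with_custom_op(model)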
