examples/models/llama/source_transformation/sdpa.py (39 changes: 27 additions & 12 deletions)
@@ -9,7 +9,7 @@
# Example script for exporting Llama2 to flatbuffer

import math
from typing import Tuple, Union
from typing import Tuple, Union, Optional

import torch

@@ -22,20 +22,24 @@
class SDPACustom(torch.nn.Module):
def __init__(
self,
kv_cache: Union[KVCache, QuantizedKVCache],
dim: int,
kv_cache: Optional[Union[KVCache, QuantizedKVCache]] = None,
dim: int = -1,
is_causal = True,
):
super().__init__()
# Custom op only supports float32 currently. Converting to/from float32 is
# faster than not having the op.
self.kv_cache = kv_cache
if not isinstance(kv_cache, QuantizedKVCache):
if kv_cache is None:
pass
elif not isinstance(kv_cache, QuantizedKVCache):
self.kv_cache = kv_cache.to(torch.float)
else:
assert (
kv_cache.cache_fp_type == torch.float32
), "Only float32 is supported for custom SDPA"
self.dim = dim
self.is_causal = is_causal

def forward(
self,
@@ -44,8 +48,8 @@ def forward(
k: torch.Tensor,
v: torch.Tensor,
bsz,
seqlen,
mask,
seqlen = None,
mask = None,
):
# Custom op only supports float32 currently. Converting to/from float32 is
# faster than not having the op.
@@ -54,9 +58,20 @@ def forward(
k = k.to(dtype=torch.float)
v = v.to(dtype=torch.float)

k_cache = self.kv_cache.k_cache
v_cache = self.kv_cache.v_cache
if hasattr(self.kv_cache, "quantized_cache_dtype"):
k_cache = self.kv_cache.k_cache if self.kv_cache is not None else None
v_cache = self.kv_cache.v_cache if self.kv_cache is not None else None

if self.kv_cache is None:
output = torch.ops.llama.custom_sdpa(
q,
k,
v,
input_pos,
None, # Attention mask
0, # dropout probability. Ignored by the code
self.is_causal, # is_causal
)
elif isinstance(self.kv_cache, QuantizedKVCache):
# Updates the quantized cache, scales, and zero points,
# and returns the dequantized kv cache.
# Not the most optimal path; optimizations to follow.
@@ -68,7 +83,7 @@ def forward(
input_pos[0].item(),
None, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal
self.is_causal, # is_causal
)
else:
output = torch.ops.llama.sdpa_with_kv_cache(
@@ -81,7 +96,7 @@
seqlen,
None, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal
self.is_causal, # is_causal
)
return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)

@@ -99,7 +114,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):


def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
from executorch.extension.llm.custom_ops import custom_ops # noqa
from executorch.extension.llm.custom_ops import custom_ops

_replace_sdpa_with_custom_op(module)
return module
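
For intuition about the cache-free branch added above: when kv_cache is None, SDPACustom.forward dispatches straight to torch.ops.llama.custom_sdpa with no cache update, which with is_causal=True amounts to plain causal attention. Below is a rough eager-mode reference in stock PyTorch (a sketch only: it does not exercise the ExecuTorch op, and the (batch, seq_len, n_heads, head_dim) layout is an assumption, not taken from this diff).

import torch
import torch.nn.functional as F

def reference_cache_free_sdpa(q, k, v, is_causal=True):
    # Assumed layout: (batch, seq_len, n_heads, head_dim). Transpose to the
    # (batch, n_heads, seq_len, head_dim) layout expected by
    # F.scaled_dot_product_attention, attend, then transpose back.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    return out.transpose(1, 2)

q = torch.randn(1, 8, 4, 16)  # (bsz, seqlen, n_heads, head_dim)
k = torch.randn(1, 8, 4, 16)
v = torch.randn(1, 8, 4, 16)
print(reference_cache_free_sdpa(q, k, v).shape)  # torch.Size([1, 8, 4, 16])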
examples/models/llama3_2_vision/vision_encoder/model.py (14 changes: 10 additions & 4 deletions)
@@ -15,7 +15,7 @@
replace_tiled_token_positional_embedding,
)
from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_encoder

from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha, replace_sdpa_with_custom_op

@dataclass
class VisionEncoderConfig:
@@ -47,7 +47,7 @@ class VisionEncoderConfig:


class FlamingoVisionEncoderModel(EagerModelBase):
def __init__(self, config: Optional[VisionEncoderConfig] = None):
def __init__(self, config: Optional[VisionEncoderConfig] = None, enable_source_transforms = True):
super().__init__()
if config is None:
config = demo_config
@@ -64,8 +64,8 @@ def __init__(self, config: Optional[VisionEncoderConfig] = None):
max_num_tiles=config.max_num_tiles,
in_channels=config.in_channels,
)
self.model = replace_tile_positional_embedding(self.model)
self.model = replace_tiled_token_positional_embedding(self.model)
if enable_source_transforms:
self.source_transform()
self.image = torch.randn(
1, 1, 4, 3, self.config.tile_size, self.config.tile_size
)
@@ -75,6 +75,12 @@ def __init__(self, config: Optional[VisionEncoderConfig] = None):
self.aspect_ratio,
)

def source_transform(self):
self.model = replace_tile_positional_embedding(self.model)
self.model = replace_tiled_token_positional_embedding(self.model)
self.model = replace_mha_with_inference_mha(self.model)
self.model = replace_sdpa_with_custom_op(self.model)

def get_eager_model(self, **kwargs):
return self.model
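
A usage sketch for the new enable_source_transforms flag (hypothetical driver code, not part of this diff; the import path is inferred from the file location, and applying source_transform() assumes the ExecuTorch custom SDPA ops are built): constructing the model with the flag off keeps the untouched torchtune modules as a reference, and source_transform() later applies the positional-embedding, MHA, and SDPA rewrites in one step.

import torch
from executorch.examples.models.llama3_2_vision.vision_encoder.model import (
    FlamingoVisionEncoderModel,
)

with torch.no_grad():
    wrapper = FlamingoVisionEncoderModel(enable_source_transforms=False)
    reference_out = wrapper.model(*wrapper.get_example_inputs())

    # Apply the tile/token positional-embedding, inference MHA, and custom SDPA rewrites.
    wrapper.source_transform()
    transformed_out = wrapper.model(*wrapper.get_example_inputs())

    # The rewrites are intended to be numerically equivalent to the eager modules.
    torch.testing.assert_close(transformed_out, reference_out)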

Flamingo vision encoder test (file path not shown in this view)
@@ -16,13 +16,56 @@
FlamingoVisionEncoderModel,
)
from torch.testing import assert_close

from executorch.exir import to_edge, to_edge_transform_and_lower, EdgeCompileConfig
from torch._inductor.package import package_aoti
from torch.nn.attention import SDPBackend
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
DuplicateDynamicQuantChainPass
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
XnnpackDynamicallyQuantizedPartitioner,
)

class FlamingoVisionEncoderTest(unittest.TestCase):
def setUp(self) -> None:
super().setUp()

def test_flamingo_vision_encoder(self) -> None:
def test_flamingo_vision_encoder_et(self) -> None:
with torch.no_grad():
vision_model = FlamingoVisionEncoderModel(enable_source_transforms=False)
encoder_no_source_transform_outputs = vision_model.model.forward(*vision_model.get_example_inputs())
vision_model.source_transform()
encoder = vision_model.model
encoder_source_transform_outputs = encoder.forward(*vision_model.get_example_inputs())
assert_close(encoder_source_transform_outputs, encoder_no_source_transform_outputs)

with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
training_output = torch.export.export_for_training(encoder, vision_model.get_example_inputs(), dynamic_shapes=vision_model.get_dynamic_shapes())
assert_close(encoder(*vision_model.get_example_inputs()), training_output.module()(*vision_model.get_example_inputs()))

dynamic_quantizer = XNNPACKQuantizer()
operator_config_dynamic = get_symmetric_quantization_config(
is_per_channel=True, is_dynamic=True
)
dynamic_quantizer.set_global(operator_config_dynamic)
prepare = prepare_pt2e(training_output.module(), dynamic_quantizer)
prepare(*vision_model.get_example_inputs())
convert = convert_pt2e(prepare)
DuplicateDynamicQuantChainPass()(convert)

export_output = torch.export.export(convert, vision_model.get_example_inputs(), dynamic_shapes=vision_model.get_dynamic_shapes())

edge = to_edge_transform_and_lower(export_output, partitioner=[
XnnpackDynamicallyQuantizedPartitioner(),
], compile_config=EdgeCompileConfig(_check_ir_validity=False))
edge.to_executorch()

def test_flamingo_vision_encoder_aoti(self) -> None:
model = FlamingoVisionEncoderModel()
encoder = model.model
eager_res = encoder.forward(*model.get_example_inputs())
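
Continuing test_flamingo_vision_encoder_et above, one possible follow-up after edge.to_executorch() is to run the lowered program through the in-process pybindings runtime. This is a sketch, not part of the diff: it reuses the test's local names (edge, vision_model) and assumes the pybindings runtime is built with the XNNPACK backend.

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)

# Keep the program manager returned by to_executorch() so its serialized
# buffer can be loaded into the runtime and run on the same example inputs.
et_program = edge.to_executorch()
et_module = _load_for_executorch_from_buffer(et_program.buffer)
et_outputs = et_module.forward(vision_model.get_example_inputs())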
extension/llm/modules/attention.py (24 changes: 23 additions & 1 deletion)
@@ -13,6 +13,7 @@
from torch import nn
from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
from torchtune.modules.kv_cache import KVCache
from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom

logger = logging.getLogger(__name__)

@@ -310,7 +311,9 @@ def false_fn(y):
self.kv_cache.v_cache.copy_(v)
self.kv_cache.cache_pos.copy_(cache_pos)

output = self._sdpa(q, k, v, b, s_x, mask=mask)
if input_pos is None:
input_pos = torch.tensor(0)
output = self._sdpa(input_pos, q, k, v, b, s_x, mask=mask)
return self.output_proj(output)


@@ -364,6 +367,7 @@ def forward(
k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)


output = self._attention_fn(
q,
k,
@@ -411,3 +415,21 @@ def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
"""
_replace_mha_with_inference_mha(module)
return module


def _replace_sdpa_with_custom_op(module: torch.nn.Module):
for name, child in module.named_children():
if isinstance(child, SDPA):
setattr(
module,
name,
SDPACustom(is_causal=child.is_causal),
)
else:
_replace_sdpa_with_custom_op(child)


def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
from executorch.extension.llm.custom_ops import custom_ops  # noqa: imported for op registration side effect

_replace_sdpa_with_custom_op(module)
return module
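
A hypothetical end-to-end use of the new pass (a sketch; it assumes the ExecuTorch custom SDPA ops are built): replace_mha_with_inference_mha has to run first so that the attention modules contain the SDPA children this pass looks for, matching the order used by the vision encoder's source_transform above.

import torch
from executorch.extension.llm.modules.attention import (
    replace_mha_with_inference_mha,
    replace_sdpa_with_custom_op,
)

def apply_attention_transforms(model: torch.nn.Module) -> torch.nn.Module:
    # 1) Swap torchtune MultiHeadAttention for the inference-friendly variant.
    model = replace_mha_with_inference_mha(model)
    # 2) Swap its inner SDPA modules for SDPACustom, whose cache-free path
    #    calls torch.ops.llama.custom_sdpa.
    return replace_sdpa_with_custom_op(model)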