diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index 59bfbe6f951..87a0f08d800 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -9,7 +9,7 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import math
-from typing import Tuple, Union
+from typing import Tuple, Union, Optional
 
 import torch
 
@@ -22,20 +22,24 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: Union[KVCache, QuantizedKVCache],
-        dim: int,
+        kv_cache: Optional[Union[KVCache, QuantizedKVCache]] = None,
+        dim: int = -1,
+        is_causal: bool = True,
     ):
         super().__init__()
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
         self.kv_cache = kv_cache
-        if not isinstance(kv_cache, QuantizedKVCache):
+        if kv_cache is None:
+            pass
+        elif not isinstance(kv_cache, QuantizedKVCache):
             self.kv_cache = kv_cache.to(torch.float)
         else:
             assert (
                 kv_cache.cache_fp_type == torch.float32
             ), "Only float32 is supported for custom SDPA"
         self.dim = dim
+        self.is_causal = is_causal
 
     def forward(
         self,
@@ -44,8 +48,8 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         bsz,
-        seqlen,
-        mask,
+        seqlen=None,
+        mask=None,
     ):
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
@@ -54,9 +58,20 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
 
-        k_cache = self.kv_cache.k_cache
-        v_cache = self.kv_cache.v_cache
-        if hasattr(self.kv_cache, "quantized_cache_dtype"):
+        k_cache = self.kv_cache.k_cache if self.kv_cache is not None else None
+        v_cache = self.kv_cache.v_cache if self.kv_cache is not None else None
+
+        if self.kv_cache is None:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos,
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                self.is_causal,  # is_causal
+            )
+        elif isinstance(self.kv_cache, QuantizedKVCache):
             # updated quantize cache, scale and zero points
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next
@@ -68,7 +83,7 @@ def forward(
                 input_pos[0].item(),
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
             )
         else:
             output = torch.ops.llama.sdpa_with_kv_cache(
@@ -81,7 +96,7 @@ def forward(
                 seqlen,
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
             )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
@@ -99,7 +114,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
 
 
 def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
-    from executorch.extension.llm.custom_ops import custom_ops  # noqa
+    from executorch.extension.llm.custom_ops import custom_ops
 
     _replace_sdpa_with_custom_op(module)
     return module
diff --git a/examples/models/llama3_2_vision/vision_encoder/model.py b/examples/models/llama3_2_vision/vision_encoder/model.py
index 79becd16205..438585bb2f8 100644
--- a/examples/models/llama3_2_vision/vision_encoder/model.py
+++ b/examples/models/llama3_2_vision/vision_encoder/model.py
@@ -15,7 +15,7 @@
     replace_tiled_token_positional_embedding,
 )
 from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_encoder
-
+from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha, replace_sdpa_with_custom_op
 
 @dataclass
 class VisionEncoderConfig:
@@ -47,7 +47,7 @@ class VisionEncoderConfig:
 
 
 class FlamingoVisionEncoderModel(EagerModelBase):
-    def __init__(self, config: Optional[VisionEncoderConfig] = None):
+    def __init__(self, config: Optional[VisionEncoderConfig] = None, enable_source_transforms: bool = True):
         super().__init__()
         if config is None:
             config = demo_config
@@ -64,8 +64,8 @@ def __init__(self, config: Optional[VisionEncoderConfig] = None):
             max_num_tiles=config.max_num_tiles,
             in_channels=config.in_channels,
         )
-        self.model = replace_tile_positional_embedding(self.model)
-        self.model = replace_tiled_token_positional_embedding(self.model)
+        if enable_source_transforms:
+            self.source_transform()
         self.image = torch.randn(
             1, 1, 4, 3, self.config.tile_size, self.config.tile_size
         )
@@ -75,6 +75,12 @@ def __init__(self, config: Optional[VisionEncoderConfig] = None):
             self.aspect_ratio,
         )
 
+    def source_transform(self):
+        self.model = replace_tile_positional_embedding(self.model)
+        self.model = replace_tiled_token_positional_embedding(self.model)
+        self.model = replace_mha_with_inference_mha(self.model)
+        self.model = replace_sdpa_with_custom_op(self.model)
+
     def get_eager_model(self, **kwargs):
         return self.model
 
diff --git a/examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py b/examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py
index c0207968c56..428b3344760 100644
--- a/examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py
+++ b/examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py
@@ -16,13 +16,56 @@
     FlamingoVisionEncoderModel,
 )
 from torch.testing import assert_close
-
+from executorch.exir import to_edge, to_edge_transform_and_lower, EdgeCompileConfig
+from torch._inductor.package import package_aoti
+from torch.nn.attention import SDPBackend
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
+    DuplicateDynamicQuantChainPass,
+)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+    XnnpackDynamicallyQuantizedPartitioner,
+)
 
 class FlamingoVisionEncoderTest(unittest.TestCase):
     def setUp(self) -> None:
         super().setUp()
 
-    def test_flamingo_vision_encoder(self) -> None:
+    def test_flamingo_vision_encoder_et(self) -> None:
+        with torch.no_grad():
+            vision_model = FlamingoVisionEncoderModel(enable_source_transforms=False)
+            encoder_no_source_transform_outputs = vision_model.model.forward(*vision_model.get_example_inputs())
+            vision_model.source_transform()
+            encoder = vision_model.model
+            encoder_source_transform_outputs = encoder.forward(*vision_model.get_example_inputs())
+            assert_close(encoder_source_transform_outputs, encoder_no_source_transform_outputs)
+
+        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
+            training_output = torch.export.export_for_training(encoder, vision_model.get_example_inputs(), dynamic_shapes=vision_model.get_dynamic_shapes())
+            assert_close(encoder(*vision_model.get_example_inputs()), training_output.module()(*vision_model.get_example_inputs()))
+
+            dynamic_quantizer = XNNPACKQuantizer()
+            operator_config_dynamic = get_symmetric_quantization_config(
+                is_per_channel=True, is_dynamic=True
+            )
+            dynamic_quantizer.set_global(operator_config_dynamic)
+            prepare = prepare_pt2e(training_output.module(), dynamic_quantizer)
+            prepare(*vision_model.get_example_inputs())
+            convert = convert_pt2e(prepare)
+            DuplicateDynamicQuantChainPass()(convert)
+
+            export_output = torch.export.export(convert, vision_model.get_example_inputs(), dynamic_shapes=vision_model.get_dynamic_shapes())
+
+            edge = to_edge_transform_and_lower(export_output, partitioner=[
+                XnnpackDynamicallyQuantizedPartitioner(),
+            ], compile_config=EdgeCompileConfig(_check_ir_validity=False))
+            edge.to_executorch()
+
+    def test_flamingo_vision_encoder_aoti(self) -> None:
         model = FlamingoVisionEncoderModel()
         encoder = model.model
         eager_res = encoder.forward(*model.get_example_inputs())
diff --git a/extension/llm/modules/attention.py b/extension/llm/modules/attention.py
index 60183801b42..232f106b38d 100644
--- a/extension/llm/modules/attention.py
+++ b/extension/llm/modules/attention.py
@@ -13,6 +13,7 @@
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
+from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom
 
 logger = logging.getLogger(__name__)
 
@@ -310,7 +311,9 @@ def false_fn(y):
             self.kv_cache.v_cache.copy_(v)
             self.kv_cache.cache_pos.copy_(cache_pos)
 
-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        if input_pos is None:
+            input_pos = torch.tensor(0)
+        output = self._sdpa(input_pos, q, k, v, b, s_x, mask=mask)
         return self.output_proj(output)
 
 
@@ -364,6 +367,7 @@ def forward(
             k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
             v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
 
+
         output = self._attention_fn(
             q,
             k,
@@ -411,3 +415,21 @@ def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
     """
     _replace_mha_with_inference_mha(module)
     return module
+
+
+def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPACustom(is_causal=child.is_causal),
+            )
+        else:
+            _replace_sdpa_with_custom_op(child)
+
+
+def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+    from executorch.extension.llm.custom_ops import custom_ops
+    _replace_sdpa_with_custom_op(module)
+    return module
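
# ---------------------------------------------------------------------------
# Usage sketch (not part of the patch above). It mirrors the new unit test:
# build the vision encoder with source transforms disabled, record reference
# outputs, then call source_transform() -- which swaps in the tile/token
# positional embeddings, the inference MHA, and the custom SDPA op -- and
# check that the outputs still match. The import path for
# FlamingoVisionEncoderModel is assumed from the file layout in this diff.
# ---------------------------------------------------------------------------
import torch
from torch.testing import assert_close

from executorch.examples.models.llama3_2_vision.vision_encoder.model import (
    FlamingoVisionEncoderModel,
)

with torch.no_grad():
    # Constructing with enable_source_transforms=False leaves the eager
    # torchtune modules in place so they can serve as the numerical reference.
    vision_model = FlamingoVisionEncoderModel(enable_source_transforms=False)
    reference = vision_model.model.forward(*vision_model.get_example_inputs())

    # Apply the transforms introduced by this patch and compare outputs.
    vision_model.source_transform()
    transformed = vision_model.model.forward(*vision_model.get_example_inputs())
    assert_close(transformed, reference)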