diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 0b1946f0cb6..8472e66b9c1 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -54,6 +54,7 @@
 )
 from .source_transformation.quantized_kv_cache import (
     replace_kv_cache_with_quantized_kv_cache,
+    replace_torchtune_kv_cache_with_quantized_kv_cache,
 )
 from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm

@@ -65,10 +66,15 @@
     replace_sdpa_with_coreml_sdpa,
     replace_sdpa_with_custom_op,
     replace_sdpa_with_flex_sdpa,
+    replace_sdpa_with_sdpa_only_custom_op,
     replace_sdpa_with_simple_sdpa,
 )
+
+from .source_transformation.torchtune.attention import replace_mha_with_inference_mha
+
 from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb
+
 IS_FBCODE = True  # os.environ.get("FBCODE_PLATFORM", False)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)

@@ -237,7 +243,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--use_sdpa_with_kv_cache",
         default=False,
         action="store_true",
-        help="Whether to use sdpa_with_kv_cache update op when using kv cache",
+        help="Whether to use the custom SDPA + KV cache update op when the KV cache is enabled.",
     )
     parser.add_argument(
         "--disable_dynamic_shape",
@@ -589,6 +595,18 @@ def _validate_args(args):
     if args.num_sharding > 0 and not args.qnn:
         raise ValueError("Model shard is only supported with qnn backend now.")

+    if args.model in TORCHTUNE_DEFINED_MODELS:
+        if args.use_sdpa_with_kv_cache:
+            if not (args.use_kv_cache and args.quantize_kv_cache):
+                raise ValueError(
+                    f"TorchTune-defined {args.model} only works with the custom SDPA op and a quantized KV cache at the moment. Please enable use_kv_cache and quantize_kv_cache when use_sdpa_with_kv_cache is enabled."
+                )
+        if args.use_kv_cache:
+            if not args.quantize_kv_cache:
+                raise ValueError(
+                    f"TorchTune-defined {args.model} only works with a quantized KV cache at the moment. Please enable quantize_kv_cache when use_kv_cache is enabled."
+                )
+

 def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
@@ -892,6 +910,7 @@ def _load_llama_model(
 def _get_source_transforms(  # noqa
     modelname: str, dtype_override: Optional[DType], args
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
+    is_torchtune_model = modelname in TORCHTUNE_DEFINED_MODELS
     transforms = []

     if args.use_spin_quant:
@@ -943,12 +962,29 @@ def _get_source_transforms(  # noqa
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

+    if is_torchtune_model:
+        transforms.append(replace_mha_with_inference_mha)
     if args.use_sdpa_with_kv_cache:
-        transforms.append(replace_sdpa_with_custom_op)
+        if is_torchtune_model:
+            assert (
+                args.use_kv_cache and args.quantize_kv_cache
+            ), "use_sdpa_with_kv_cache requires use_kv_cache=True and quantize_kv_cache=True for TorchTune-defined models at the moment."
+            transforms.append(replace_sdpa_with_sdpa_only_custom_op)
+        else:
+            transforms.append(replace_sdpa_with_custom_op)

     if args.quantize_kv_cache:
         assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
-        transforms.append(replace_kv_cache_with_quantized_kv_cache)
+        if is_torchtune_model:
+            transforms.append(
+                lambda module: replace_torchtune_kv_cache_with_quantized_kv_cache(
+                    module,
+                    is_transposed=not args.use_sdpa_with_kv_cache,
+                    enable_dynamic_shape=args.enable_dynamic_shape,
+                )
+            )
+        else:
+            transforms.append(replace_kv_cache_with_quantized_kv_cache)

     if args.use_kv_cache:
         if args.qnn:
@@ -983,4 +1019,8 @@ def _get_source_transforms(  # noqa
     if args.vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)

+    print(
+        f"Performing the following source transformations: {[transform.__name__ for transform in transforms]}"
+    )
+
     return transforms
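For context, the list returned by _get_source_transforms is applied to the model one callable at a time. Below is a minimal sketch (not part of the patch) of the torchtune path wired up above, assuming use_sdpa_with_kv_cache, use_kv_cache, and quantize_kv_cache are all enabled, that `model` is an already-instantiated torchtune-defined nn.Module, and that the absolute import paths follow from the relative imports in this file:

import torch.nn as nn

from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
    replace_torchtune_kv_cache_with_quantized_kv_cache,
)
from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_sdpa_only_custom_op,
)
from executorch.examples.models.llama.source_transformation.torchtune.attention import (
    replace_mha_with_inference_mha,
)


def apply_torchtune_source_transforms(
    model: nn.Module, enable_dynamic_shape: bool = True
) -> nn.Module:
    transforms = [
        # Swap torchtune MultiHeadAttention for its inference-friendly variant.
        replace_mha_with_inference_mha,
        # Use the custom SDPA op without the fused KV cache update.
        replace_sdpa_with_sdpa_only_custom_op,
        # Quantize the (non-transposed) torchtune KV caches; the cache update is
        # then handled by the quantized cache rather than the SDPA op.
        lambda m: replace_torchtune_kv_cache_with_quantized_kv_cache(
            m, is_transposed=False, enable_dynamic_shape=enable_dynamic_shape
        ),
    ]
    for transform in transforms:
        model = transform(model)
    return model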
diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py
index 6d92a45e800..e9399195c74 100644
--- a/examples/models/llama/source_transformation/quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -11,6 +11,7 @@
 import torch.nn as nn
 from executorch.examples.models.llama.llama_transformer import KVCache
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+from torchtune.modules.kv_cache import KVCache as TorchTuneKVCache


 """
@@ -207,8 +208,31 @@ def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
             kv_cache.enable_dynamic_shape,
         )

+    @classmethod
+    def from_torchtune_float(
+        cls,
+        kv_cache,
+        cache_type: QuantizedCacheType,
+        is_transposed: bool,
+        enable_dynamic_shape: bool,
+    ):
+        cache_shape = kv_cache.k_cache.shape
+        if is_transposed:
+            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
+        else:
+            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+        return cls(
+            max_batch_size,
+            max_seq_length,
+            n_heads,
+            head_dim,
+            cache_type,
+            is_transposed,
+            enable_dynamic_shape,
+        )
+

-def replace_kv_cache_with_quantized_kv_cache(module):
+def replace_kv_cache_with_quantized_kv_cache(module: nn.Module) -> nn.Module:
     logging.warning(
         "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
     )
@@ -222,3 +246,39 @@ def replace_kv_cache_with_quantized_kv_cache(module):
         else:
             replace_kv_cache_with_quantized_kv_cache(child)
     return module
+
+
+def replace_torchtune_kv_cache_with_quantized_kv_cache(
+    module: nn.Module, is_transposed: bool, enable_dynamic_shape: bool
+) -> nn.Module:
+    """
+    Replace TorchTune KVCache with ExecuTorch's quantized KVCache.
+
+    Args:
+        module: the module whose TorchTune KVCache submodules should be replaced.
+        is_transposed: whether the k and v caches are stored transposed, i.e. with shape (B, H, S, D) rather than (B, S, H, D). Should be set to False when the custom SDPA op source transform is enabled.
+        enable_dynamic_shape: whether dynamic shapes are enabled.
+
+    Returns:
+        The passed-in module, modified in place.
+    """
+    logging.warning(
+        "Replacing TorchTune KVCache with QuantizedKVCache. This modifies the model in place."
+    )
+    for name, child in module.named_children():
+        if isinstance(child, TorchTuneKVCache):
+            setattr(
+                module,
+                name,
+                QuantizedKVCache.from_torchtune_float(
+                    child,
+                    QuantizedCacheType.AffineAsymmetric,
+                    is_transposed,
+                    enable_dynamic_shape,
+                ),
+            )
+        else:
+            replace_torchtune_kv_cache_with_quantized_kv_cache(
+                child, is_transposed, enable_dynamic_shape
+            )
+    return module
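For context, a small self-contained sketch (not part of the patch) of the layout convention that from_torchtune_float and replace_torchtune_kv_cache_with_quantized_kv_cache rely on: the same 4-D k/v cache tensor is read as (batch, heads, seq, head_dim) when is_transposed is True and as (batch, seq, heads, head_dim) otherwise. The sizes below are arbitrary and only illustrative:

import torch


def unpack_cache_shape(k_cache: torch.Tensor, is_transposed: bool):
    if is_transposed:
        # (batch, heads, seq, head_dim) layout.
        max_batch_size, n_heads, max_seq_length, head_dim = k_cache.shape
    else:
        # (batch, seq, heads, head_dim) layout, expected when the custom SDPA op is used.
        max_batch_size, max_seq_length, n_heads, head_dim = k_cache.shape
    # Return dims in the order from_torchtune_float passes them to the constructor.
    return max_batch_size, max_seq_length, n_heads, head_dim


k_cache = torch.zeros(1, 128, 8, 64)  # non-transposed: batch=1, seq=128, heads=8, dim=64
print(unpack_cache_shape(k_cache, is_transposed=False))  # (1, 128, 8, 64)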
diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index f8362648f32..2ac33d82616 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -80,7 +80,7 @@ def forward(
             input_pos[0].item(),
             seqlen,
             None,  # Attention mask
-            0,  # dropout probability. Ignored by the code
+            0,  # Dropout probability, ignored by the code
             True,  # is_causal
         )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
@@ -105,6 +105,65 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
     return module


+class SDPAOnlyCustom(torch.nn.Module):
+    """
+    Just the custom SDPA op, no KV cache update included. Can only be used
+    in conjunction with a quantized KV cache.
+    """
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz: int,
+        seqlen: int,
+        mask: torch.Tensor = None,
+    ):
+        # Custom op only supports float32 currently. Converting to/from float32 is
+        # faster than not having the op.
+        input_dtype = q.dtype
+        q = q.to(dtype=torch.float)
+        k = k.to(dtype=torch.float)
+        v = v.to(dtype=torch.float)
+        output = torch.ops.llama.custom_sdpa(
+            q,
+            k,
+            v,
+            input_pos[0].item(),
+            None,  # Attention mask
+            0,  # Dropout probability, ignored by the code.
+            True,  # is_causal
+        )
+        return output.view(bsz, seqlen, -1).to(dtype=input_dtype)
+
+
+def _replace_sdpa_with_sdpa_only_custom_op(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            assert (
+                child.kv_cache.cache_fp_type == torch.float32
+            ), "Only float32 is supported for custom SDPA"
+            setattr(
+                module,
+                name,
+                SDPAOnlyCustom(),
+            )
+        else:
+            _replace_sdpa_with_sdpa_only_custom_op(child)
+
+
+def replace_sdpa_with_sdpa_only_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+    _replace_sdpa_with_sdpa_only_custom_op(module)
+    return module
+
+
 class SDPASimple(torch.nn.Module):

     def __init__(
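For context, a hedged usage sketch (not part of the patch) of the new SDPAOnlyCustom module. It assumes the ExecuTorch LLM custom ops are built and registered; the registration import below is an assumption and may differ between ExecuTorch versions. The (batch, seq, heads, head_dim) layout for q/k/v mirrors the non-transposed cache layout used elsewhere in this change; in a real model, k and v would come from the dequantized KV cache rather than fresh random tensors:

import torch

from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401  # registers llama.custom_sdpa (assumed path)
from executorch.examples.models.llama.source_transformation.sdpa import SDPAOnlyCustom

bsz, seqlen, n_heads, head_dim = 1, 16, 8, 64
sdpa = SDPAOnlyCustom()

# q/k/v in (batch, seq, heads, head_dim) layout; no cache update happens here.
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_heads, head_dim)
v = torch.randn(bsz, seqlen, n_heads, head_dim)
input_pos = torch.tensor([0])  # start position of this chunk

out = sdpa(input_pos, q, k, v, bsz, seqlen)
print(out.shape)  # (bsz, seqlen, n_heads * head_dim), per the view() in forward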