From 3979fc8580f01f4a7d946f13de3a5b87b0a7cb56 Mon Sep 17 00:00:00 2001
From: Tarun Karuturi
Date: Mon, 6 Jan 2025 13:52:57 -0800
Subject: [PATCH 1/2] Changes to SDPA to support no kv cache export

[ghstack-poisoned]
---
 .../llama/source_transformation/sdpa.py       | 39 +++++++++++++------
 1 file changed, 27 insertions(+), 12 deletions(-)

diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index 59bfbe6f951..87a0f08d800 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -9,7 +9,7 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import math
-from typing import Tuple, Union
+from typing import Tuple, Union, Optional
 
 import torch
 
@@ -22,20 +22,24 @@
 class SDPACustom(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: Union[KVCache, QuantizedKVCache],
-        dim: int,
+        kv_cache: Optional[Union[KVCache, QuantizedKVCache]] = None,
+        dim: int = -1,
+        is_causal = True,
     ):
         super().__init__()
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
         self.kv_cache = kv_cache
-        if not isinstance(kv_cache, QuantizedKVCache):
+        if kv_cache is None:
+            pass
+        elif not isinstance(kv_cache, QuantizedKVCache):
             self.kv_cache = kv_cache.to(torch.float)
         else:
             assert (
                 kv_cache.cache_fp_type == torch.float32
             ), "Only float32 is supported for custom SDPA"
         self.dim = dim
+        self.is_causal = is_causal
 
     def forward(
         self,
@@ -44,8 +48,8 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         bsz,
-        seqlen,
-        mask,
+        seqlen = None,
+        mask = None,
     ):
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
@@ -54,9 +58,20 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
 
-        k_cache = self.kv_cache.k_cache
-        v_cache = self.kv_cache.v_cache
-        if hasattr(self.kv_cache, "quantized_cache_dtype"):
+        k_cache = self.kv_cache.k_cache if self.kv_cache is not None else None
+        v_cache = self.kv_cache.v_cache if self.kv_cache is not None else None
+
+        if self.kv_cache is None:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos,
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                self.is_causal,  # is_causal
+            )
+        elif isinstance(self.kv_cache, QuantizedKVCache):
             # updated quantize cache, scale and zero points
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next
@@ -68,7 +83,7 @@ def forward(
                 input_pos[0].item(),
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
             )
         else:
             output = torch.ops.llama.sdpa_with_kv_cache(
@@ -81,7 +96,7 @@ def forward(
                 seqlen,
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
             )
 
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
@@ -99,7 +114,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
 
 
 def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
-    from executorch.extension.llm.custom_ops import custom_ops  # noqa
+    from executorch.extension.llm.custom_ops import custom_ops
 
     _replace_sdpa_with_custom_op(module)
     return module

From 84e7573cda0a911477980834326e1a44fc2826d2 Mon Sep 17 00:00:00 2001
From: Tarun Karuturi
Date: Mon, 6 Jan 2025 14:35:15 -0800
Subject: [PATCH 2/2] Update on "Changes to SDPA to support no kv cache export"

Differential Revision: [D67878163](https://our.internmc.facebook.com/intern/diff/D67878163)

[ghstack-poisoned]
---
 examples/models/llama/source_transformation/sdpa.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index 87a0f08d800..eff6ee5aec7 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -9,7 +9,7 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import math
-from typing import Tuple, Union, Optional
+from typing import Optional, Tuple, Union
 
 import torch
 
@@ -24,7 +24,7 @@ def __init__(
         self,
         kv_cache: Optional[Union[KVCache, QuantizedKVCache]] = None,
         dim: int = -1,
-        is_causal = True,
+        is_causal=True,
     ):
         super().__init__()
         # Custom op only supports float32 currently. Converting to/from float32 is
@@ -48,8 +48,8 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         bsz,
-        seqlen = None,
-        mask = None,
+        seqlen=None,
+        mask=None,
     ):
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
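
Below is a minimal, untested sketch of how the new no-KV-cache branch could be exercised once these patches land. It assumes the executorch Python package is installed (so that importing custom_ops registers torch.ops.llama.custom_sdpa, as replace_sdpa_with_custom_op does), assumes the module path executorch.examples.models.llama.source_transformation.sdpa, and assumes a [bsz, seqlen, n_heads, head_dim] layout for q/k/v; the sizes are illustrative, and the positional argument order follows the existing forward() signature in sdpa.py (input_pos, q, k, v, bsz, seqlen, mask).

# Sketch only: exercises SDPACustom with kv_cache=None, the branch added in
# PATCH 1/2. Import paths, tensor layout, and sizes are assumptions.
import torch

from executorch.extension.llm.custom_ops import custom_ops  # noqa  (registers torch.ops.llama.*)
from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom

bsz, seqlen, n_heads, head_dim = 1, 16, 8, 64

# No KV cache: forward() routes to torch.ops.llama.custom_sdpa with
# self.is_causal instead of reading or updating a cache.
sdpa = SDPACustom(kv_cache=None, dim=n_heads * head_dim, is_causal=True)

q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_heads, head_dim)
v = torch.randn(bsz, seqlen, n_heads, head_dim)
input_pos = torch.tensor([0])

# seqlen and dim must still be supplied: the result is reshaped with
# output.view(bsz, seqlen, self.dim) before being returned.
out = sdpa(input_pos, q, k, v, bsz, seqlen)
print(out.shape)  # expected: torch.Size([1, 16, 512])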