diff --git a/.ci/scripts/test_llama.sh b/.ci/scripts/test_llama.sh index 550a09e4c6f..9bb881ce8eb 100644 --- a/.ci/scripts/test_llama.sh +++ b/.ci/scripts/test_llama.sh @@ -112,6 +112,8 @@ fi if [[ "${MODE}" =~ .*quantize_kv.* ]]; then QUANTIZE_KV_CACHE=ON + # quantize_kv cache transform uses custom kv cache update op + CUSTOM=ON else QUANTIZE_KV_CACHE=OFF fi diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index ba397273b62..33237f3bebe 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -374,7 +374,7 @@ def get_custom_quant_ios_dtype( """ This function is specific for llama inputs and outputs """ - if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name: + if node.op == "placeholder" and "attention_kv_cache_past_" in node.name: return kv_dtype # Tag index put node before copy node, because copy is a skipped node in qnn diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 69980990cfd..d0860700362 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -665,6 +665,8 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 # export_to_edge builder_exported = _prepare_for_llama_export(args).export() + builder_exported.run_canonical_optimizations() + if args.export_only: exit() diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index aaef3cd9804..d5661ae4004 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -232,22 +232,16 @@ def __init__( max_seq_length: int, n_heads: int, head_dim: int, - transpose_cache: bool, enable_dynamic_shape: bool, dtype=torch.float32, ): super().__init__() self.max_seq_length = max_seq_length - self.is_transposed = transpose_cache - if transpose_cache: - cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) - else: - cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) self.max_batch_size = max_batch_size self.n_heads = n_heads self.head_dim = head_dim - self.transpose_cache = transpose_cache self.enable_dynamic_shape = enable_dynamic_shape self.register_buffer( "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") @@ -259,12 +253,12 @@ def __init__( def update( self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache + # input_pos: [S], k_val: [B, H, S, D] if self.enable_dynamic_shape: start_pos = input_pos[0].item() torch._check_is_size(start_pos) torch._check(start_pos < self.max_seq_length) - dim_to_slice = 2 if self.transpose_cache else 1 + dim_to_slice = 2 seq_length = k_val.size(dim_to_slice) # Replace the entry in the cache for this token # The following lines are equivalent to: @@ -283,12 +277,8 @@ def update( else: k_out = self.k_cache v_out = self.v_cache - if self.transpose_cache: - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - else: - k_out[:, input_pos] = k_val - v_out[:, input_pos] = v_val + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val return k_out, v_out @@ -296,7 +286,6 @@ def update( class SDPA(nn.Module): def __init__( self, - kv_cache: KVCache, dim: int, head_dim: int, n_rep: int, @@ -304,7 +293,6 @@ def 
__init__( enable_dynamic_shape: bool, ): super().__init__() - self.kv_cache = kv_cache self.dim = dim self.head_dim = head_dim self.n_rep = n_rep @@ -314,18 +302,13 @@ def __init__( def forward( self, input_pos: torch.Tensor, - q: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim) - k: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim) - v: torch.Tensor, # (bs, seqlen, n_local_kv_heads, head_dim) + q: torch.Tensor, # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim) + k: torch.Tensor, # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim) + v: torch.Tensor, # (bs, n_local_kv_heads, seqlen, head_dim) bsz, seqlen, mask: torch.Tensor, ) -> torch.Tensor: - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - k, v = self.kv_cache.update(input_pos, k, v) if self.enable_dynamic_shape: start_pos = input_pos[-1].item() torch._check_is_size(start_pos) @@ -336,6 +319,8 @@ def forward( else: attn_mask = mask[None, None, input_pos] + # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention + # can natively support GQA now. But needs enable_gqa=True k = k.repeat_interleave(self.n_rep, dim=1) v = v.repeat_interleave(self.n_rep, dim=1) y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) @@ -383,11 +368,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): args.max_seq_len, self.n_kv_heads, self.head_dim, - not args.use_sdpa_with_kv_cache_op, # if we are using the custom op don't transpose the cache. Expect untransposed q k v args.enable_dynamic_shape, ) self.SDPA = SDPA( - kv_cache=self.kv_cache, dim=self.n_local_heads * self.head_dim, head_dim=self.head_dim, n_rep=self.n_rep, @@ -414,15 +397,16 @@ def forward( # RoPE relative positional embeddings q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) + q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + if self.use_kv_cache: assert input_pos is not None + k, v = self.kv_cache.update(input_pos, k, v) output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask) return self.wo(output) - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - # grouped multiquery attention: expand out keys and values k = k.repeat_interleave(self.n_rep, dim=1) v = v.repeat_interleave(self.n_rep, dim=1) diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index b534a98e078..5b3bfba9add 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -111,7 +111,6 @@ def __init__( self, n_heads: int, head_dim: int, - transpose_cache: bool, enable_dynamic_shape: bool, rope: RopeWithAttentionSink, window_size: int, @@ -125,7 +124,6 @@ def __init__( max_seq_length=window_size + sink_size, n_heads=n_heads, head_dim=head_dim, - transpose_cache=transpose_cache, enable_dynamic_shape=enable_dynamic_shape, dtype=dtype, ) @@ -161,28 +159,17 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int: input_pos_item + self.position_shift - self.sink_size - num_to_evict ) num_empty_space = self.window_size - num_to_keep - dim_to_slice = 2 if self.transpose_cache else 1 + dim_to_slice = 2 k_to_keep = self.k_cache.narrow( dim_to_slice, self.sink_size + 
num_to_evict, # pyre-ignore [6] num_to_keep, # pyre-ignore [6] ) - if self.transpose_cache: - k_to_keep = self.rope.rerotate_k( - k=k_to_keep.transpose(1, 2), - original_position=( # pyre-ignore [6] - self.sink_size + num_to_evict - ), - new_position=self.sink_size, - ).transpose(1, 2) - else: - k_to_keep = self.rope.rerotate_k( - k=k_to_keep, - original_position=( # pyre-ignore [6] - self.sink_size + num_to_evict - ), - new_position=self.sink_size, - ) + k_to_keep = self.rope.rerotate_k( + k=k_to_keep.transpose(1, 2), + original_position=(self.sink_size + num_to_evict), # pyre-ignore [6] + new_position=self.sink_size, + ).transpose(1, 2) self.k_cache = torch.cat( [ self.k_cache.narrow(dim_to_slice, 0, self.sink_size), @@ -278,7 +265,6 @@ def _replace_attention( kv_cache_with_attention_sink = KVCacheWithAttentionSink( n_heads=kv_cache.n_heads, head_dim=kv_cache.head_dim, - transpose_cache=kv_cache.transpose_cache, enable_dynamic_shape=kv_cache.enable_dynamic_shape, rope=rope_with_attention_sink, max_batch_size=kv_cache.max_batch_size, @@ -288,7 +274,6 @@ def _replace_attention( dtype=kv_cache.k_cache.dtype, ) child_module.kv_cache = kv_cache_with_attention_sink - child_module.SDPA.kv_cache = kv_cache_with_attention_sink child_module.forward = types.MethodType( # pyre-ignore attention_sink_forward, child_module ) diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py index d8ac99656f1..90ec9879e54 100644 --- a/examples/models/llama/source_transformation/quantized_kv_cache.py +++ b/examples/models/llama/source_transformation/quantized_kv_cache.py @@ -37,8 +37,7 @@ def __init__( n_heads, head_dim, cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric, - tranposed=False, - enable_dynamic_shape=False, + use_custom_update_cache_op: bool = False, ): super().__init__() if cache_type not in ( @@ -50,16 +49,11 @@ def __init__( ) # For now supporting int8 only + self.use_custom_update_cache_op = use_custom_update_cache_op self.quantized_cache_dtype = torch.int8 self.cache_fp_type = torch.float32 - self.is_transposed = tranposed - self.enable_dynamic_shape = enable_dynamic_shape - if self.is_transposed: - cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) - scale_shape = (max_batch_size, n_heads, max_seq_length, 1) - else: - cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) - scale_shape = (max_batch_size, max_seq_length, n_heads, 1) + cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) + scale_shape = (max_batch_size, max_seq_length, n_heads, 1) self.register_buffer( "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype) ) @@ -98,60 +92,20 @@ def _quantize(self, value): return quantized_value, scales, zero_points def update(self, input_pos, k_val, v_val): + """ + k_val, v_val: [B, H, S, D] + return: [B, H, S, D] + However the storage is [B, S, H, D] so we incur transpose in, transpose out + This shall be removed by subsequent post-export graph pass + """ + k_val = k_val.transpose(1, 2) + v_val = v_val.transpose(1, 2) # quantize current k_val and store it in the cache quantized_k_val, k_scales, k_zero_points = self._quantize(k_val) quantized_v_val, v_scales, v_zero_points = self._quantize(v_val) - if self.is_transposed: - # We cannot use update_cache op at the moment - # if the cache is transposed - # Also note that we shold not need separate paths - # for dynamic shape vs ! 
- # Only reason it is done this way is to accommodate - # for lowering pains of backends that work better - # with index_put op. - if self.enable_dynamic_shape: - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - dim_to_slice = 2 if self.is_transposed else 1 - torch._check(start_pos < self.k_cache.size(dim_to_slice)) - seq_length = k_val.size(dim_to_slice) - narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length) - narrowed_k_scales = self.k_cache_scales.narrow( - dim_to_slice, start_pos, seq_length - ) - narrowed_k_zp = self.k_cache_zero_points.narrow( - dim_to_slice, start_pos, seq_length - ) - narrowed_k.copy_(quantized_k_val) - narrowed_k_scales.copy_(k_scales) - narrowed_k_zp.copy_(k_zero_points) - narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length) - narrowed_v_scales = self.v_cache_scales.narrow( - dim_to_slice, start_pos, seq_length - ) - narrowed_v_zp = self.v_cache_zero_points.narrow( - dim_to_slice, start_pos, seq_length - ) - narrowed_v.copy_(quantized_v_val) - narrowed_v_scales.copy_(v_scales) - narrowed_v_zp.copy_(v_zero_points) - else: - self.k_cache[:, :, input_pos] = quantized_k_val - self.k_cache_scales[:, :, input_pos] = k_scales - self.k_cache_zero_points[:, :, input_pos] = k_zero_points - self.v_cache[:, :, input_pos] = quantized_v_val - self.v_cache_scales[:, :, input_pos] = v_scales - self.v_cache_zero_points[:, :, input_pos] = v_zero_points - else: - # Right now using custom ops on this path. - # In future we can update custom op to handle transposed cache - # as well. - # Note that we may have to revert this change if other ET - # backends such as QNN want to use quantized cache, with dynamic shape, - # instead of quantizing on their own. - # But until this opting for code simplicity + if self.use_custom_update_cache_op: start_pos = input_pos[0].item() _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos) _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos) @@ -163,6 +117,13 @@ def update(self, input_pos, k_val, v_val): _ = torch.ops.llama.update_cache( v_zero_points, self.v_cache_zero_points, start_pos ) + else: + self.k_cache[:, input_pos] = quantized_k_val + self.k_cache_scales[:, input_pos] = k_scales + self.k_cache_zero_points[:, input_pos] = k_zero_points + self.v_cache[:, input_pos] = quantized_v_val + self.v_cache_scales[:, input_pos] = v_scales + self.v_cache_zero_points[:, input_pos] = v_zero_points k_out = torch.ops.quantized_decomposed.dequantize_per_token( self.k_cache, @@ -183,42 +144,34 @@ def update(self, input_pos, k_val, v_val): self.cache_fp_type, ) - if self.is_transposed: - if self.enable_dynamic_shape: - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - dim_to_slice = 2 if self.is_transposed else 1 - torch._check(start_pos < self.k_cache.size(dim_to_slice)) - seq_length = k_val.size(dim_to_slice) - narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length) - narrowed_k.copy_(k_val) - narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length) - narrowed_v.copy_(v_val) - else: - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - else: - start_pos = input_pos[0].item() + start_pos = input_pos[0].item() + if self.use_custom_update_cache_op: _ = torch.ops.llama.update_cache(k_val, k_out, start_pos) _ = torch.ops.llama.update_cache(v_val, v_out, start_pos) + else: + k_out[:, input_pos] = k_val + v_out[:, input_pos] = v_val - return k_out, v_out + return k_out.transpose(1, 2), v_out.transpose(1, 2) @classmethod - def 
from_float(cls, kv_cache, cache_type: QuantizedCacheType): - cache_shape = kv_cache.k_cache.shape - if kv_cache.is_transposed: - max_batch_size, n_heads, max_seq_length, head_dim = cache_shape - else: - max_batch_size, max_seq_length, n_heads, head_dim = cache_shape + def from_float( + cls, + kv_cache, + cache_type: QuantizedCacheType, + use_custom_update_cache_op: bool = False, + ): + max_batch_size, n_heads, max_seq_length, head_dim = kv_cache.k_cache.shape + if isinstance(kv_cache, CustomKVCache): + # If replacing custom kv cache, then the shape is [B, S, H, D] + max_batch_size, max_seq_length, n_heads, head_dim = kv_cache.k_cache.shape return cls( max_batch_size, max_seq_length, n_heads, head_dim, cache_type, - kv_cache.is_transposed, - kv_cache.enable_dynamic_shape, + use_custom_update_cache_op, ) @@ -254,11 +207,15 @@ def replace_kv_cache_with_quantized_kv_cache(module): "Replacing KVCache with QuantizedKVCache. This modifies the model in place." ) for name, child in module.named_children(): - if isinstance(child, KVCache): + if isinstance(child, KVCache) or isinstance(child, CustomKVCache): setattr( module, name, - QuantizedKVCache.from_float(child, QuantizedCacheType.AffineAsymmetric), + QuantizedKVCache.from_float( + child, + QuantizedCacheType.AffineAsymmetric, + use_custom_update_cache_op=True, + ), ) else: replace_kv_cache_with_quantized_kv_cache(child) @@ -291,11 +248,16 @@ def __init__( def update( self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [S], k_val: [B, S, H, D] + # input_pos: [S], k_val: [B, H, S, D] + k_val = k_val.transpose(1, 2) + v_val = v_val.transpose(1, 2) start_pos = input_pos[0].item() _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos) _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos) - return self.k_cache, self.v_cache + return ( + self.k_cache.transpose(1, 2), + self.v_cache.transpose(1, 2), + ) def replace_kv_cache_with_custom_kv_cache(module): @@ -313,10 +275,7 @@ def replace_kv_cache_with_custom_kv_cache(module): if isinstance(child, KVCache): cache_shape = child.k_cache.shape cache_dtype = child.k_cache.dtype - assert ( - child.is_transposed is False - ), "CustomKVCache does not support transposed cache" - max_batch_size, max_seq_length, n_heads, head_dim = cache_shape + max_batch_size, n_heads, max_seq_length, head_dim = cache_shape setattr( module, name, diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index 4d4b3bf7f56..6a54d6a119f 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -9,32 +9,19 @@ # Example script for exporting Llama2 to flatbuffer import math -from typing import Tuple, Union +from typing import Tuple import torch from executorch.examples.models.llama.llama_transformer import KVCache, SDPA -from executorch.examples.models.llama.source_transformation.quantized_kv_cache import ( - QuantizedKVCache, -) class SDPACustom(torch.nn.Module): def __init__( self, - kv_cache: Union[KVCache, QuantizedKVCache], dim: int, ): super().__init__() - # Custom op only supports float32 currently. Converting to/from float32 is - # faster than not having the op. 
- self.kv_cache = kv_cache - if not isinstance(kv_cache, QuantizedKVCache): - self.kv_cache = kv_cache.to(torch.float) - else: - assert ( - kv_cache.cache_fp_type == torch.float32 - ), "Only float32 is supported for custom SDPA" self.dim = dim def forward( @@ -47,6 +34,10 @@ def forward( seqlen, mask, ): + q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + # Custom op only supports float32 currently. Converting to/from float32 is # faster than not having the op. input_dtype = q.dtype @@ -54,13 +45,10 @@ def forward( k = k.to(dtype=torch.float) v = v.to(dtype=torch.float) - k_cache = self.kv_cache.k_cache - v_cache = self.kv_cache.v_cache - k_cache, v_cache = self.kv_cache.update(input_pos, k, v) output = torch.ops.llama.custom_sdpa( q, - k_cache, - v_cache, + k, + v, input_pos[0].item(), None, # Attention mask 0, # dropout probability. Ignored by the code @@ -75,7 +63,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module): setattr( module, name, - SDPACustom(child.kv_cache, child.dim), + SDPACustom(child.dim), ) else: _replace_sdpa_with_custom_op(child) @@ -91,13 +79,11 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module: class SDPASimple(torch.nn.Module): def __init__( self, - kv_cache: KVCache, dim: int, head_dim: int, n_rep: int, ): super().__init__() - self.kv_cache = kv_cache self.dim = dim self.head_dim = head_dim self.n_rep = n_rep @@ -112,11 +98,6 @@ def forward( seqlen, mask, ): - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - k, v = self.kv_cache.update(input_pos, k, v) attn_mask = mask[None, None, input_pos] k = k.repeat_interleave(self.n_rep, dim=1) @@ -150,12 +131,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class SDPAFlex(torch.nn.Module): def __init__( self, - kv_cache: KVCache, dim: int, n_rep: int, ): super().__init__() - self.kv_cache = kv_cache self.dim = dim self.n_rep = n_rep @@ -169,9 +148,10 @@ def forward( seqlen, mask, ): - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - - k, v = self.kv_cache.update(input_pos, k, v) + """ + q: (bs, n_heads, seqlen, head_dim) + k, v: (bs, n_local_heads, seqlen, head_dim) + """ k = repeat_kv(k, self.n_rep) v = repeat_kv(v, self.n_rep) attn_mask = mask[input_pos] @@ -191,7 +171,7 @@ def replace_sdpa_with_simple_sdpa(module: torch.nn.Module): setattr( module, name, - SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep), + SDPASimple(child.dim, child.head_dim, child.n_rep), ) else: replace_sdpa_with_simple_sdpa(child) @@ -204,7 +184,7 @@ def replace_sdpa_with_flex_sdpa(module: torch.nn.Module): setattr( module, name, - SDPAFlex(child.kv_cache, child.dim, child.n_rep), + SDPAFlex(child.dim, child.n_rep), ) else: replace_sdpa_with_flex_sdpa(child) @@ -236,13 +216,11 @@ class SDPACoreML(torch.nn.Module): def __init__( self, - kv_cache: KVCache, dim: int, head_dim: int, n_rep: int, ): super().__init__() - self.kv_cache = kv_cache self.dim = dim self.head_dim = head_dim self.n_rep = n_rep @@ -257,11 +235,6 @@ def forward( seqlen, mask, ): - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - k, v = self.kv_cache.update(input_pos, k, v) attn_mask = mask[None, None, input_pos] if self.n_rep > 1: @@ -279,7 +252,7 @@ def replace_sdpa_with_coreml_sdpa(module: torch.nn.Module): setattr( module, name, - SDPACoreML(child.kv_cache, child.dim, child.head_dim, 
child.n_rep), + SDPACoreML(child.dim, child.head_dim, child.n_rep), ) else: replace_sdpa_with_coreml_sdpa(child) @@ -366,6 +339,9 @@ def __init__( def update( self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: + # can we combine this with KVCacheCoreML? + k_val = k_val.transpose(1, 2) + v_val = v_val.transpose(1, 2) k_out = torch.ops.aten.index_put_(self.past_k_caches, [None, input_pos], k_val) v_out = torch.ops.aten.index_put_(self.past_v_caches, [None, input_pos], v_val) diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py index 4ffecf1e9c3..4dd522dff2f 100644 --- a/examples/models/llama/source_transformation/test_attention_sink.py +++ b/examples/models/llama/source_transformation/test_attention_sink.py @@ -120,21 +120,18 @@ def test_rotate(self, original_position, new_position): class KVCacheWithAttentionSinkTest(unittest.TestCase): _single_evict_test_cases = [ - [False, 4, 1], - [True, 4, 1], + [4, 1], ] _batch_evict_test_cases = [ - [False, 4, 8], - [True, 4, 8], + [4, 8], ] _sliding_window_test_cases = [ - [False, 0, 1], - [True, 0, 1], + [0, 1], ] - def _init_cache(self, transpose_cache, sink_size, eviction_batch_size): + def _init_cache(self, sink_size, eviction_batch_size): self.params = ModelArgs( use_kv_cache=True, enable_dynamic_shape=True, @@ -149,7 +146,6 @@ def _init_cache(self, transpose_cache, sink_size, eviction_batch_size): self.kv_cache = KVCacheWithAttentionSink( n_heads=self.params.n_heads, head_dim=self.params.head_dim, - transpose_cache=transpose_cache, enable_dynamic_shape=self.params.enable_dynamic_shape, rope=self.rope_with_attention_sink, max_batch_size=self.max_batch_size, @@ -159,94 +155,49 @@ def _init_cache(self, transpose_cache, sink_size, eviction_batch_size): dtype=self.dtype, ) - def _rand_kv_with_length(self, transpose_cache, seq_len): + def _rand_kv_with_length(self, seq_len): size = ( - ( - self.max_batch_size, - seq_len, - self.params.n_heads, - self.params.head_dim, - ) - if not transpose_cache - else ( - self.max_batch_size, - self.params.n_heads, - seq_len, - self.params.head_dim, - ) - ) - if not transpose_cache: - k = torch.rand( - *size, - dtype=self.dtype, - ) - v = torch.rand( - *size, - dtype=self.dtype, - ) - else: - k = torch.rand( - *size, - dtype=self.dtype, - ) - v = torch.rand( - *size, - dtype=self.dtype, - ) + self.max_batch_size, + self.params.n_heads, + seq_len, + self.params.head_dim, + ) + k = torch.rand( + *size, + dtype=self.dtype, + ) + v = torch.rand( + *size, + dtype=self.dtype, + ) return k, v - def _zero_kv_with_length(self, transpose_cache, seq_len): + def _zero_kv_with_length(self, seq_len): size = ( - ( - self.max_batch_size, - seq_len, - self.params.n_heads, - self.params.head_dim, - ) - if not transpose_cache - else ( - self.max_batch_size, - self.params.n_heads, - seq_len, - self.params.head_dim, - ) - ) - if not transpose_cache: - k = torch.zeros( - *size, - dtype=self.dtype, - ) - v = torch.zeros( - *size, - dtype=self.dtype, - ) - else: - k = torch.zeros( - *size, - dtype=self.dtype, - ) - v = torch.zeros( - *size, - dtype=self.dtype, - ) + self.max_batch_size, + self.params.n_heads, + seq_len, + self.params.head_dim, + ) + k = torch.zeros( + *size, + dtype=self.dtype, + ) + v = torch.zeros( + *size, + dtype=self.dtype, + ) return k, v - def _get_dim_to_slice(self, transpose_cache): - return 2 if transpose_cache else 1 + def _get_dim_to_slice(self): + return 2 
- def _get_expected_rotated_k( - self, transpose_cache, k, original_position, new_position - ): - if transpose_cache: - return self.rope_with_attention_sink.rerotate_k( - k=k.transpose(1, 2), - original_position=original_position, - new_position=new_position, - ).transpose(1, 2) - else: - return self.rope_with_attention_sink.rerotate_k( - k=k, original_position=original_position, new_position=new_position - ) + def _get_expected_rotated_k(self, k, original_position, new_position): + return self.rope_with_attention_sink.rerotate_k( + k=k.transpose(1, 2), + original_position=original_position, + new_position=new_position, + ).transpose(1, 2) def setUp(self): torch.manual_seed(42) @@ -257,16 +208,14 @@ def setUp(self): @parameterized.expand( _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases ) - def test_evict_empty_cache(self, transpose_cache, sink_size, eviction_batch_size): - self._init_cache(transpose_cache, sink_size, eviction_batch_size) + def test_evict_empty_cache(self, sink_size, eviction_batch_size): + self._init_cache(sink_size, eviction_batch_size) # KV cache is empty, evict does nothing input_pos = torch.tensor([0], dtype=torch.int32) assert self.kv_cache.evict_tokens(input_pos, 1) == 0 - expected_k, expected_v = self._zero_kv_with_length( - transpose_cache, self.window_size + sink_size - ) + expected_k, expected_v = self._zero_kv_with_length(self.window_size + sink_size) torch.testing.assert_close(self.kv_cache.k_cache, expected_k) torch.testing.assert_close(self.kv_cache.v_cache, expected_v) @@ -274,23 +223,21 @@ def test_evict_empty_cache(self, transpose_cache, sink_size, eviction_batch_size @parameterized.expand( _single_evict_test_cases + _batch_evict_test_cases + _sliding_window_test_cases ) - def test_evict_without_shift(self, transpose_cache, sink_size, eviction_batch_size): - dimension_to_slice = self._get_dim_to_slice(transpose_cache) + def test_evict_without_shift(self, sink_size, eviction_batch_size): + dimension_to_slice = 2 - self._init_cache(transpose_cache, sink_size, eviction_batch_size) + self._init_cache(sink_size, eviction_batch_size) # KV cache has enough spaces for new tokens, no shift input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(transpose_cache, 10) + k, v = self._rand_kv_with_length(10) self.kv_cache.update(input_pos, k, v) input_pos = torch.tensor([10], dtype=torch.int32) assert self.kv_cache.evict_tokens(input_pos, 1) == 0 - zero_k, zero_v = self._zero_kv_with_length( - transpose_cache, self.window_size + sink_size - 10 - ) + zero_k, zero_v = self._zero_kv_with_length(self.window_size + sink_size - 10) expected_k = torch.cat( [ @@ -311,34 +258,30 @@ def test_evict_without_shift(self, transpose_cache, sink_size, eviction_batch_si torch.testing.assert_close(self.kv_cache.v_cache, expected_v) @parameterized.expand(_single_evict_test_cases) - def test_evict_with_some_shift( - self, transpose_cache, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice(transpose_cache) + def test_evict_with_some_shift(self, sink_size, eviction_batch_size): + dimension_to_slice = self._get_dim_to_slice() - self._init_cache(transpose_cache, sink_size, eviction_batch_size) + self._init_cache(sink_size, eviction_batch_size) # KV cache has some spaces for new tokens but not all, shift some tokens input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(transpose_cache, 5) + k, v = self._rand_kv_with_length(5) self.kv_cache.update(input_pos, k, v) input_pos = 
torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(transpose_cache, 5) + k1, v1 = self._rand_kv_with_length(5) self.kv_cache.update(input_pos, k1, v1) input_pos = torch.tensor([10], dtype=torch.int32) assert self.kv_cache.evict_tokens(input_pos, 24) == -2 - zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 24) + zero_k, zero_v = self._zero_kv_with_length(24) expected_k = torch.cat( [ k.narrow(dimension_to_slice, 0, sink_size), - self._get_expected_rotated_k( - transpose_cache, k1.narrow(dimension_to_slice, 1, 4), 6, 4 - ), + self._get_expected_rotated_k(k1.narrow(dimension_to_slice, 1, 4), 6, 4), zero_k, ], dim=dimension_to_slice, @@ -356,33 +299,31 @@ def test_evict_with_some_shift( torch.testing.assert_close(self.kv_cache.v_cache, expected_v) @parameterized.expand(_single_evict_test_cases) - def test_evict_with_all_shift( - self, transpose_cache, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice(transpose_cache) + def test_evict_with_all_shift(self, sink_size, eviction_batch_size): + dimension_to_slice = self._get_dim_to_slice() - self._init_cache(transpose_cache, sink_size, eviction_batch_size) + self._init_cache(sink_size, eviction_batch_size) # KV cache has no spaces for new tokens, shift all tokens input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(transpose_cache, 5) + k, v = self._rand_kv_with_length(5) self.kv_cache.update(input_pos, k, v) input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(transpose_cache, 27) + k1, v1 = self._rand_kv_with_length(27) self.kv_cache.update(input_pos, k1, v1) input_pos = torch.tensor([32], dtype=torch.int32) assert self.kv_cache.evict_tokens(input_pos, 6) == -6 - zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 6) + zero_k, zero_v = self._zero_kv_with_length(6) expected_k = torch.cat( [ k.narrow(dimension_to_slice, 0, sink_size), self._get_expected_rotated_k( - transpose_cache, k1.narrow(dimension_to_slice, 5, 22), 10, 4 + k1.narrow(dimension_to_slice, 5, 22), 10, 4 ), zero_k, ], @@ -402,33 +343,31 @@ def test_evict_with_all_shift( @parameterized.expand(_sliding_window_test_cases) def test_evict_with_some_shift_for_sliding_window( - self, transpose_cache, sink_size, eviction_batch_size + self, sink_size, eviction_batch_size ): - dimension_to_slice = self._get_dim_to_slice(transpose_cache) + dimension_to_slice = self._get_dim_to_slice() - self._init_cache(transpose_cache, sink_size, eviction_batch_size) + self._init_cache(sink_size, eviction_batch_size) # KV cache has some spaces for new tokens but not all, shift some tokens input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(transpose_cache, 5) + k, v = self._rand_kv_with_length(5) self.kv_cache.update(input_pos, k, v) input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(transpose_cache, 5) + k1, v1 = self._rand_kv_with_length(5) self.kv_cache.update(input_pos, k1, v1) input_pos = torch.tensor([10], dtype=torch.int32) assert self.kv_cache.evict_tokens(input_pos, 20) == -2 - zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 20) + zero_k, zero_v = self._zero_kv_with_length(20) expected_k = torch.cat( [ - self._get_expected_rotated_k( - transpose_cache, k.narrow(dimension_to_slice, 2, 3), 2, 0 - ), - self._get_expected_rotated_k(transpose_cache, k1, 5, 3), + self._get_expected_rotated_k(k.narrow(dimension_to_slice, 2, 3), 2, 0), + self._get_expected_rotated_k(k1, 5, 3), zero_k, ], 
dim=dimension_to_slice, @@ -447,31 +386,31 @@ def test_evict_with_some_shift_for_sliding_window( @parameterized.expand(_sliding_window_test_cases) def test_evict_with_all_shift_for_sliding_window( - self, transpose_cache, sink_size, eviction_batch_size + self, sink_size, eviction_batch_size ): - dimension_to_slice = self._get_dim_to_slice(transpose_cache) + dimension_to_slice = self._get_dim_to_slice() - self._init_cache(transpose_cache, sink_size, eviction_batch_size) + self._init_cache(sink_size, eviction_batch_size) # KV cache has no spaces for new tokens, shift all tokens input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(transpose_cache, 5) + k, v = self._rand_kv_with_length(5) self.kv_cache.update(input_pos, k, v) input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(transpose_cache, 23) + k1, v1 = self._rand_kv_with_length(23) self.kv_cache.update(input_pos, k1, v1) input_pos = torch.tensor([28], dtype=torch.int32) assert self.kv_cache.evict_tokens(input_pos, 6) == -6 - zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 6) + zero_k, zero_v = self._zero_kv_with_length(6) expected_k = torch.cat( [ self._get_expected_rotated_k( - transpose_cache, k1.narrow(dimension_to_slice, 1, 22), 6, 0 + k1.narrow(dimension_to_slice, 1, 22), 6, 0 ), zero_k, ], @@ -489,33 +428,31 @@ def test_evict_with_all_shift_for_sliding_window( torch.testing.assert_close(self.kv_cache.v_cache, expected_v) @parameterized.expand(_batch_evict_test_cases) - def test_batch_evict_with_seq_len( - self, transpose_cache, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice(transpose_cache) + def test_batch_evict_with_seq_len(self, sink_size, eviction_batch_size): + dimension_to_slice = self._get_dim_to_slice() - self._init_cache(transpose_cache, sink_size, eviction_batch_size) + self._init_cache(sink_size, eviction_batch_size) # KV cache has some spaces for new tokens but not all, shift some tokens input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(transpose_cache, 5) + k, v = self._rand_kv_with_length(5) self.kv_cache.update(input_pos, k, v) input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(transpose_cache, 25) + k1, v1 = self._rand_kv_with_length(25) self.kv_cache.update(input_pos, k1, v1) input_pos = torch.tensor([30], dtype=torch.int32) assert self.kv_cache.evict_tokens(input_pos, 12) == -10 - zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 12) + zero_k, zero_v = self._zero_kv_with_length(12) expected_k = torch.cat( [ k.narrow(dimension_to_slice, 0, sink_size), self._get_expected_rotated_k( - transpose_cache, k1.narrow(dimension_to_slice, 9, 16), 14, 4 + k1.narrow(dimension_to_slice, 9, 16), 14, 4 ), zero_k, ], @@ -534,33 +471,31 @@ def test_batch_evict_with_seq_len( torch.testing.assert_close(self.kv_cache.v_cache, expected_v) @parameterized.expand(_batch_evict_test_cases) - def test_batch_evict_with_batch_size( - self, transpose_cache, sink_size, eviction_batch_size - ): - dimension_to_slice = self._get_dim_to_slice(transpose_cache) + def test_batch_evict_with_batch_size(self, sink_size, eviction_batch_size): + dimension_to_slice = self._get_dim_to_slice() - self._init_cache(transpose_cache, sink_size, eviction_batch_size) + self._init_cache(sink_size, eviction_batch_size) # KV cache has no spaces for new tokens, shift all tokens input_pos = torch.tensor([0], dtype=torch.int32) - k, v = self._rand_kv_with_length(transpose_cache, 5) + 
k, v = self._rand_kv_with_length(5) self.kv_cache.update(input_pos, k, v) input_pos = torch.tensor([5], dtype=torch.int32) - k1, v1 = self._rand_kv_with_length(transpose_cache, 25) + k1, v1 = self._rand_kv_with_length(25) self.kv_cache.update(input_pos, k1, v1) input_pos = torch.tensor([30], dtype=torch.int32) assert self.kv_cache.evict_tokens(input_pos, 6) == -8 - zero_k, zero_v = self._zero_kv_with_length(transpose_cache, 10) + zero_k, zero_v = self._zero_kv_with_length(10) expected_k = torch.cat( [ k.narrow(dimension_to_slice, 0, sink_size), self._get_expected_rotated_k( - transpose_cache, k1.narrow(dimension_to_slice, 7, 18), 12, 4 + k1.narrow(dimension_to_slice, 7, 18), 12, 4 ), zero_k, ], diff --git a/examples/models/llama/source_transformation/test_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_quantized_kv_cache.py index e5ade3dd128..67ebbc7b3fe 100644 --- a/examples/models/llama/source_transformation/test_quantized_kv_cache.py +++ b/examples/models/llama/source_transformation/test_quantized_kv_cache.py @@ -17,23 +17,18 @@ class QuantizedKVCacheTest(unittest.TestCase): - def _init_cache(self): self.kv_cache = KVCache( self.max_batch_size, self.max_seq_len, self.n_kv_heads, self.head_dim, - self.transpose_kv_cache, self.enable_dynamic_shape, dtype=self.dtype, ) def _init_kv(self): - if self.transpose_kv_cache: - shape = (1, self.n_kv_heads, self.seq_len, self.head_dim) - else: - shape = (1, self.seq_len, self.n_kv_heads, self.head_dim) + shape = (1, self.n_kv_heads, self.seq_len, self.head_dim) k = torch.rand(shape, dtype=self.dtype) v = torch.rand(shape, dtype=self.dtype) return k, v @@ -45,29 +40,29 @@ def setUp(self): self.n_kv_heads = 8 self.head_dim = 17 self.enable_dynamic_shape = False - self.transpose_kv_cache = False self.dtype = torch.float32 - def _test_simple_update_fetch(self, is_transposed=False, is_dynamic_shape=False): - self.transpose_kv_cache = is_transposed + def _test_simple_update_fetch( + self, is_dynamic_shape=False, use_custom_update_cache_op=False + ): self.enable_dynamic_shape = is_dynamic_shape input_pos = torch.tensor([0, 1, 2]) self.seq_len = input_pos.size(0) self._init_cache() k, v = self._init_kv() quantized_kv_cache = QuantizedKVCache.from_float( - self.kv_cache, QuantizedCacheType.AffineAsymmetric + self.kv_cache, + QuantizedCacheType.AffineAsymmetric, + use_custom_update_cache_op, ) updated_k_cache, updated_v_cache = self.kv_cache.update(input_pos, k, v) - updated_dequantized_k_cache, updated_dequantized_v_cache = ( - quantized_kv_cache.update(input_pos, k, v) - ) + ( + updated_dequantized_k_cache, + updated_dequantized_v_cache, + ) = quantized_kv_cache.update(input_pos, k, v) def index(t, input_pos): - if self.transpose_kv_cache: - return t[:, :, input_pos, :] - else: - return t[:, input_pos, :, :] + return t[:, :, input_pos, :] sliced_k_cache = index(updated_k_cache, input_pos) sliced_v_cache = index(updated_v_cache, input_pos) @@ -93,9 +88,10 @@ def index(t, input_pos): k, v = self._init_kv() pos_to_check = torch.tensor([0, 1, 2, 3]) updated_k_cache, updated_v_cache = self.kv_cache.update(input_pos, k, v) - updated_dequantized_k_cache, updated_dequantized_v_cache = ( - quantized_kv_cache.update(input_pos, k, v) - ) + ( + updated_dequantized_k_cache, + updated_dequantized_v_cache, + ) = quantized_kv_cache.update(input_pos, k, v) sliced_k_cache = index(updated_k_cache, pos_to_check) sliced_v_cache = index(updated_v_cache, pos_to_check) @@ -115,14 +111,16 @@ def index(t, input_pos): atol=1e-02, ) - def 
test_simple_update_fetch_not_transposed(self): + def test_simple_update_fetch(self): self._test_simple_update_fetch() - def test_simple_update_fetch_not_transposed_dynamic_shape(self): - self._test_simple_update_fetch(is_dynamic_shape=True) + def test_simple_update_fetch_use_custom_op(self): + self._test_simple_update_fetch(use_custom_update_cache_op=True) - def test_simple_update_fetch_transposed(self): - self._test_simple_update_fetch(is_transposed=True) + def test_simple_update_fetch_dynamic_shape(self): + self._test_simple_update_fetch(is_dynamic_shape=True) - def test_simple_update_fetch_transposed_dynamic_shape(self): - self._test_simple_update_fetch(is_transposed=True, is_dynamic_shape=True) + def test_simple_update_fetch_dynamic_shape_use_custom_op(self): + self._test_simple_update_fetch( + is_dynamic_shape=True, use_custom_update_cache_op=True + ) diff --git a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py index 57c36dabf9b..0081c5072c9 100644 --- a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py +++ b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py @@ -26,12 +26,11 @@ def _init_cache(self): self.max_seq_len, self.n_kv_heads, self.head_dim, - False, self.enable_dynamic_shape, dtype=self.dtype, ) self.quantized_kv_cache = QuantizedKVCache.from_float( - self.kv_cache, QuantizedCacheType.AffineAsymmetric + self.kv_cache, QuantizedCacheType.AffineAsymmetric, True ) # Need this because first test actually has seq_len > 1 # and vanilla kvcache cannot handle seq_len > 1, due to @@ -48,8 +47,8 @@ def _init_cache(self): ) def _init_kv(self): - kv_shape = (1, self.seq_len, self.n_kv_heads, self.head_dim) - q_shape = (1, self.seq_len, self.n_heads, self.head_dim) + kv_shape = (1, self.n_kv_heads, self.seq_len, self.head_dim) + q_shape = (1, self.n_heads, self.seq_len, self.head_dim) q = torch.rand(q_shape, dtype=self.dtype) k = torch.rand(kv_shape, dtype=self.dtype) v = torch.rand(kv_shape, dtype=self.dtype) @@ -71,10 +70,12 @@ def test_simple(self, is_dynamic_shape=False): input_pos = torch.tensor([0], dtype=torch.int64) self.seq_len = 3 self._init_cache() - q, k, v = self._init_kv() - self.float_sdpa = SDPACustom(self.custom_kv_cache, self.dim) - self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim) + q, k_val, v_val = self._init_kv() + self.float_sdpa = SDPACustom(self.dim) + self.quantized_sdpa = SDPACustom(self.dim) + k, v = self.custom_kv_cache.update(input_pos, k_val, v_val) float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None) + k, v = self.quantized_kv_cache.update(input_pos, k_val, v_val) quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None) torch.testing.assert_close( float_out, @@ -83,8 +84,10 @@ def test_simple(self, is_dynamic_shape=False): input_pos = torch.tensor([3], dtype=torch.int64) self.seq_len = 1 - q, k, v = self._init_kv() + q, k_val, v_val = self._init_kv() + k, v = self.custom_kv_cache.update(input_pos, k_val, v_val) float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None) + k, v = self.quantized_kv_cache.update(input_pos, k_val, v_val) quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None) torch.testing.assert_close( float_out, diff --git a/examples/models/llama/tests/test_simple_sdpa.py b/examples/models/llama/tests/test_simple_sdpa.py index 6e0c3919602..4088165c71f 100644 --- 
a/examples/models/llama/tests/test_simple_sdpa.py +++ b/examples/models/llama/tests/test_simple_sdpa.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import copy import unittest import torch @@ -29,11 +28,9 @@ def test_simple_sdpa(self): max_seq_length=max_seq_length, n_heads=n_heads, head_dim=head_dim, - transpose_cache=True, enable_dynamic_shape=False, ) sdpa = SDPA( - kv_cache=copy.deepcopy(kv_cache), dim=dim, head_dim=head_dim, n_rep=n_rep, @@ -45,6 +42,11 @@ def test_simple_sdpa(self): key = torch.randn(1, 1, n_local_heads, head_dim) value = torch.randn(1, 1, n_local_heads, head_dim) mask = torch.randn(max_seq_length, max_seq_length) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + key, value = kv_cache.update(input_pos, key, value) + sdpa_output = sdpa( input_pos, query, @@ -55,9 +57,7 @@ def test_simple_sdpa(self): mask=mask, ) - simple_sdpa = SDPASimple( - kv_cache=copy.deepcopy(kv_cache), dim=dim, head_dim=head_dim, n_rep=n_rep - ) + simple_sdpa = SDPASimple(dim=dim, head_dim=head_dim, n_rep=n_rep) simple_sdpa_output = simple_sdpa( input_pos, query, key, value, bsz=bsz, seqlen=seqlen, mask=mask ) diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py index dabb07e61ce..c4834aee256 100644 --- a/examples/models/llava/export_llava.py +++ b/examples/models/llava/export_llava.py @@ -20,6 +20,9 @@ EmbeddingQuantHandler, get_quant_weight_transform, ) +from executorch.examples.models.llama.source_transformation.quantized_kv_cache import ( + replace_kv_cache_with_custom_kv_cache, +) from executorch.examples.models.llama.source_transformation.sdpa import ( replace_sdpa_with_custom_op, ) @@ -101,6 +104,7 @@ def forward(self, input_pos, embeddings): _, quantizers, _ = get_quantizer_and_quant_params(args) source_transforms = [] if llava.use_sdpa_with_kv_cache_op: + source_transforms.append(replace_kv_cache_with_custom_kv_cache) source_transforms.append(replace_sdpa_with_custom_op) source_transforms.append(quant_transform) manager = ( diff --git a/examples/models/llava/model.py b/examples/models/llava/model.py index a24249d9534..68a9e59e0ce 100644 --- a/examples/models/llava/model.py +++ b/examples/models/llava/model.py @@ -14,6 +14,9 @@ import torch from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer +from executorch.examples.models.llama.source_transformation.quantized_kv_cache import ( + replace_kv_cache_with_custom_kv_cache, +) from executorch.examples.models.llama.source_transformation.sdpa import ( replace_sdpa_with_custom_op, ) @@ -62,6 +65,7 @@ def __init__( self.text_model = Transformer(self.text_model_args) # use custom op for SDPA. 
if use_sdpa_with_kv_cache_op: + self.text_model = replace_kv_cache_with_custom_kv_cache(self.text_model) self.text_model = replace_sdpa_with_custom_op(self.text_model) # load state dict self.text_model.load_state_dict( diff --git a/extension/aten_util/make_aten_functor_from_et_functor.h b/extension/aten_util/make_aten_functor_from_et_functor.h index df991e3fd17..64a8fcc2887 100644 --- a/extension/aten_util/make_aten_functor_from_et_functor.h +++ b/extension/aten_util/make_aten_functor_from_et_functor.h @@ -106,7 +106,7 @@ struct type_convert< torch::executor::Tensor>>> final { explicit type_convert(ATensor value) - : value_(value), + : value_(value.contiguous()), converted_(from_blob( value_.mutable_data_ptr(), {value_.sizes().begin(), value_.sizes().end()}, @@ -117,7 +117,7 @@ struct type_convert< torch::executor::Tensor>>> final { } private: - ATensor value_; + typename remove_const_ref<ATensor>::type value_; TensorPtr converted_; }; diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp index b957a580787..581979afd9f 100644 --- a/extension/llm/custom_ops/op_sdpa_aot.cpp +++ b/extension/llm/custom_ops/op_sdpa_aot.cpp @@ -121,7 +121,7 @@ at::Tensor custom_sdpa_aten( const bool is_causal, // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional<double> scale) { - auto output = at::empty_like(q); + auto output = at::empty(q.sizes()); WRAP_TO_ATEN(custom_sdpa_out_no_context, 8) (q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output); return output; diff --git a/extension/llm/export/TARGETS b/extension/llm/export/TARGETS index bcfc130add2..121cef7ce4f 100644 --- a/extension/llm/export/TARGETS +++ b/extension/llm/export/TARGETS @@ -12,6 +12,7 @@ runtime.python_library( name = "export_lib", srcs = [ "builder.py", + "export_passes.py", "partitioner_lib.py", "quantizer_lib.py", ], diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 64e1bcdab31..be6977b639c 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -33,6 +33,8 @@ from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.extension.export_util.utils import export_to_edge, save_pte_program + +from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes from executorch.extension.llm.tokenizer.utils import get_tokenizer from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer import Quantizer @@ -111,6 +113,7 @@ def __init__( self.calibration_seq_length = calibration_seq_length self.calibration_data = calibration_data self.tokenizer_path = tokenizer_path + self.canonical_passes = [RemoveRedundantTransposes()] def set_output_dir(self, output_dir: str) -> "LLMEdgeManager": """ @@ -223,6 +226,17 @@ def export(self) -> "LLMEdgeManager": return self + def run_canonical_optimizations(self): + """ + Run canonical optimizations (at the moment, removing redundant transpose/permute pairs) on the model.
+ """ + assert self.pre_autograd_graph_module is not None, "Please run export() first" + for pass_instance in self.canonical_passes: + logging.info(f"Running canonical pass: {pass_instance.__class__.__name__}") + res = pass_instance(self.pre_autograd_graph_module) + assert res.graph_module is not None, "Pass returned None" + self.pre_autograd_graph_module = res.graph_module + def pt2e_calibrate( self, prepared_module, diff --git a/extension/llm/export/export_passes.py b/extension/llm/export/export_passes.py new file mode 100644 index 00000000000..942de095805 --- /dev/null +++ b/extension/llm/export/export_passes.py @@ -0,0 +1,97 @@ +import torch + +from executorch.exir.pass_base import ExportPass +from torch._subclasses import FakeTensor +from torch.fx.passes.infra.pass_base import PassResult + + +def _normalize_dims(tensor: FakeTensor, dim_0: int, dim_1: int): + """ + Normalize (possibly negative) dims of a tensor to non-negative indices. + """ + assert tensor is not None, "Tensor is None" + ndim = tensor.ndim + if dim_0 < 0: + dim_0 = ndim + dim_0 + if dim_1 < 0: + dim_1 = ndim + dim_1 + assert dim_0 < ndim and dim_1 < ndim, f"Invalid dimensions: {dim_0}, {dim_1}" + return dim_0, dim_1 + + +class RemoveRedundantTransposes(ExportPass): + """ + This pass removes redundant transpose nodes in the graph. + It checks whether the next node is also a transpose node and whether the two transpose nodes undo each other. + For example, if the graph has the following nodes: + + node1 = torch.ops.aten.transpose.int(x, 0, 1) + node2 = torch.ops.aten.transpose.int(node1, 0, 1) + + then node2's uses can be replaced by x. + + It also checks for permute nodes: + node1 = torch.ops.aten.permute(x, [0, 2, 1]) + node2 = torch.ops.aten.permute(node1, [0, 2, 1]) + + then node2's uses can likewise be replaced by x. + + NB: Does not work for inplace ops or functionalized _copy suffix ops. + """ + + def call(self, graph_module: torch.fx.GraphModule): + graph_changed = False + for node in graph_module.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.transpose.int + ): + # Check if the next node is also a transpose node + transpose_users = list(node.users.keys()) + dim_0 = node.args[1] + dim_1 = node.args[2] + dim_0, dim_1 = _normalize_dims(node.args[0].meta["val"], dim_0, dim_1) + + for user in transpose_users: + if ( + user.op == "call_function" + and user.target == torch.ops.aten.transpose.int + ): + # Get the arguments of the current and next transpose nodes + user_dim_0 = user.args[1] + user_dim_1 = user.args[2] + user_dim_0, user_dim_1 = _normalize_dims( + user.args[0].meta["val"], user_dim_0, user_dim_1 + ) + + # Check if the two transpose nodes undo each other + if dim_0 == user_dim_0 and dim_1 == user_dim_1: + graph_changed = True + user.replace_all_uses_with(node.args[0]) + + for node in graph_module.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.permute.default + ): + # Check if the next node is also a permute node + permute_users = list(node.users.keys()) + dim_list = node.args[1] + + for user in permute_users: + if ( + user.op == "call_function" + and user.target == torch.ops.aten.permute.default + ): + # Get the arguments of the current and next permute nodes + user_dim_list = user.args[1] + + # Check if the two permutes undo each other + if dim_list == user_dim_list: + graph_changed = True + user.replace_all_uses_with(node.args[0]) + + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + return PassResult(graph_module, graph_changed) diff --git 
a/extension/llm/export/test_export_passes.py b/extension/llm/export/test_export_passes.py new file mode 100644 index 00000000000..12ce18ebb79 --- /dev/null +++ b/extension/llm/export/test_export_passes.py @@ -0,0 +1,165 @@ +import unittest + +import torch + +from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes + +from torch.export import export_for_training +from torch.testing import FileCheck + + +class RemoveRedundantTransposesPassTest(unittest.TestCase): + def _export(self, model, example_inputs): + exported_module = export_for_training( + model, + example_inputs, + ) + return exported_module.module() + + def _check(self, model, example_inputs, key, before_count, after_count): + gm = self._export(model, example_inputs) + FileCheck().check_count(key, before_count, exactly=True).run(gm.code) + pass_res = RemoveRedundantTransposes()(gm) + FileCheck().check_count(key, after_count, exactly=True).run( + pass_res.graph_module.code + ) + + def test_transpose_removal(self): + class TestModule1(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.transpose(x, 1, 2) + x = torch.transpose(x, 1, 2) + return x + 1 + + class TestModule2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.transpose(x, 1, 2) + x = torch.transpose(x, 1, 2) + x = x + 1 + + x = torch.transpose(x, 2, 3) + x = torch.transpose(x, 2, 3) + + return x + 2 + + x = torch.rand((1, 2, 3, 4)) + key = "torch.ops.aten.transpose.int" + m = TestModule1() + self._check(m, (x,), key, 2, 0) + + m = TestModule2() + self._check(m, (x,), key, 4, 0) + + def test_transpose_no_removal(self): + class TestModule1(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.transpose(x, 1, 2) + x = torch.transpose(x, 1, 2) + x = x + 1 + + x = torch.transpose(x, 2, 3) + x = torch.transpose(x, 1, 2) + + return x + 2 + + x = torch.rand((1, 2, 3, 4)) + key = "torch.ops.aten.transpose.int" + + m = TestModule1() + self._check(m, (x,), key, 4, 2) + + class TestModule2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x_1 = torch.transpose(x, 1, 2) + x_2 = torch.transpose(x_1, 1, 2) + x_2 = x_2 + 1 + + x = x_1 + 2 + x = torch.transpose(x, 1, 2) + + return x + x_2 + + m = TestModule2() + self._check(m, (x,), key, 3, 2) + + def test_permute_removal(self): + class TestModule1(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.permute(x, [0, 2, 1, 3]) + x = torch.permute(x, [0, 2, 1, 3]) + return x + 1 + + class TestModule2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.permute(x, [0, 2, 1, 3]) + x = torch.permute(x, [0, 2, 1, 3]) + x = x + 1 + + x = torch.permute(x, [0, 1, 3, 2]) + x = torch.permute(x, [0, 1, 3, 2]) + + return x + 2 + + x = torch.rand((1, 2, 3, 4)) + key = "torch.ops.aten.permute.default" + m = TestModule1() + self._check(m, (x,), key, 2, 0) + + m = TestModule2() + self._check(m, (x,), key, 4, 0) + + def test_permute_no_removal(self): + class TestModule1(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.permute(x, [0, 2, 1, 3]) + x = torch.permute(x, [0, 2, 1, 3]) + x = x + 1 + + x = torch.permute(x, [0, 1, 3, 2]) + x = torch.permute(x, [0, 2, 1, 3]) + + return x + 2 + + x = torch.rand((1, 2, 3, 4)) + key = "torch.ops.aten.permute.default" + + m = TestModule1() + self._check(m, (x,), key, 4, 2) + + 
class TestModule2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x_1 = torch.permute(x, [0, 2, 1, 3]) + x_2 = torch.permute(x_1, [0, 2, 1, 3]) + x_2 = x_2 + 1 + + x = x_1 + 2 + x = torch.permute(x, [0, 2, 1, 3]) + + return x + x_2 + + m = TestModule2() + self._check(m, (x,), key, 3, 2)
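
Usage sketch (illustrative, not part of the patch above): the snippet mirrors the unit tests added in extension/llm/export/test_export_passes.py and shows the RemoveRedundantTransposes pass collapsing a pair of transposes that cancel out. The ToyModule name and the (1, 2, 3, 4) input shape are arbitrary placeholders.

import torch
from torch.export import export_for_training

from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes


class ToyModule(torch.nn.Module):
    def forward(self, x):
        # The second transpose undoes the first; the pass removes both.
        x = torch.transpose(x, 1, 2)
        x = torch.transpose(x, 1, 2)
        return x + 1


# Export to a pre-autograd graph module, then run the pass directly on it.
gm = export_for_training(ToyModule(), (torch.rand(1, 2, 3, 4),)).module()
result = RemoveRedundantTransposes()(gm)
# After the pass, the generated code should contain no aten.transpose.int calls.
print(result.graph_module.code)

This is the same PassResult-based invocation that LLMEdgeManager.run_canonical_optimizations applies to the pre-autograd graph module after export().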