# Example script for exporting Llama2 to flatbuffer

import math
- from typing import Tuple, Union
+ from typing import Tuple, Union, Optional

import torch

class SDPACustom(torch.nn.Module):
    def __init__(
        self,
-         kv_cache: Union[KVCache, QuantizedKVCache],
-         dim: int,
+         kv_cache: Optional[Union[KVCache, QuantizedKVCache]] = None,
+         dim: int = -1,
    ):
        super().__init__()
        # Custom op only supports float32 currently. Converting to/from float32 is
        # faster than not having the op.
        self.kv_cache = kv_cache
-         if not isinstance(kv_cache, QuantizedKVCache):
+         if kv_cache is None:
+             pass
+         elif not isinstance(kv_cache, QuantizedKVCache):
            self.kv_cache = kv_cache.to(torch.float)
        else:
            assert (
@@ -44,8 +46,8 @@ def forward(
        k: torch.Tensor,
        v: torch.Tensor,
        bsz,
-         seqlen,
-         mask,
+         seqlen=None,
+         mask=None,
    ):
        # Custom op only supports float32 currently. Converting to/from float32 is
        # faster than not having the op.
@@ -54,9 +56,20 @@ def forward(
        k = k.to(dtype=torch.float)
        v = v.to(dtype=torch.float)

-         k_cache = self.kv_cache.k_cache
-         v_cache = self.kv_cache.v_cache
-         if hasattr(self.kv_cache, "quantized_cache_dtype"):
+         k_cache = self.kv_cache.k_cache if self.kv_cache is not None else None
+         v_cache = self.kv_cache.v_cache if self.kv_cache is not None else None
+
+         if self.kv_cache is None:
+             output = torch.ops.llama.custom_sdpa(
+                 q,
+                 k,
+                 v,
+                 input_pos,
+                 None,  # Attention mask
+                 0,  # dropout probability. Ignored by the code
+                 False,  # is_causal
+             )
+         elif isinstance(self.kv_cache, QuantizedKVCache):
            # updated quantize cache, scale and zero points
            # returns dequantized kv cache
            # Not most optimal. Optimizations to follow next
@@ -99,7 +112,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):


def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
-     from executorch.extension.llm.custom_ops import custom_ops  # noqa
+     from executorch.extension.llm.custom_ops import custom_ops

    _replace_sdpa_with_custom_op(module)
    return module
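
For context, a minimal usage sketch of the new cache-less path. It is illustrative only and not part of the diff: it assumes the elided leading parameters of forward are input_pos and q, that SDPACustom from this file is in scope, that the ExecuTorch custom-ops extension is built so the import below registers torch.ops.llama.custom_sdpa, and that q/k/v use a (bsz, seqlen, n_heads, head_dim) layout; the shapes are arbitrary.

# Illustrative sketch only; see the assumptions stated above.
import torch

# As in the diff, importing this module registers torch.ops.llama.custom_sdpa.
from executorch.extension.llm.custom_ops import custom_ops  # noqa

sdpa = SDPACustom()  # kv_cache=None and dim=-1 are now valid defaults

bsz, seqlen, n_heads, head_dim = 1, 16, 8, 64
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_heads, head_dim)
v = torch.randn(bsz, seqlen, n_heads, head_dim)
input_pos = torch.tensor([0], dtype=torch.long)

# seqlen is passed explicitly here; only mask is left at its new None default.
out = sdpa(input_pos, q, k, v, bsz, seqlen)

Making kv_cache optional lets the same module cover a cache-less attention path alongside the existing float and quantized KV-cache paths, which is what the new branch in forward dispatches on.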