# Example script for exporting Llama2 to flatbuffer

import math
-from typing import Tuple, Union
+from typing import Tuple, Union, Optional

import torch

class SDPACustom(torch.nn.Module):
    def __init__(
        self,
-        kv_cache: Union[KVCache, QuantizedKVCache],
-        dim: int,
+        kv_cache: Optional[Union[KVCache, QuantizedKVCache]] = None,
+        dim: int = -1,
+        is_causal=True,
    ):
        super().__init__()
        # Custom op only supports float32 currently. Converting to/from float32 is
        # faster than not having the op.
        self.kv_cache = kv_cache
-        if not isinstance(kv_cache, QuantizedKVCache):
+        if kv_cache is None:
+            pass
+        elif not isinstance(kv_cache, QuantizedKVCache):
            self.kv_cache = kv_cache.to(torch.float)
        else:
            assert (
                kv_cache.cache_fp_type == torch.float32
            ), "Only float32 is supported for custom SDPA"
        self.dim = dim
+        self.is_causal = is_causal

    def forward(
        self,
@@ -44,8 +48,8 @@ def forward(
        k: torch.Tensor,
        v: torch.Tensor,
        bsz,
-        seqlen,
-        mask,
+        seqlen=None,
+        mask=None,
    ):
        # Custom op only supports float32 currently. Converting to/from float32 is
        # faster than not having the op.
@@ -54,9 +58,20 @@ def forward(
        k = k.to(dtype=torch.float)
        v = v.to(dtype=torch.float)

-        k_cache = self.kv_cache.k_cache
-        v_cache = self.kv_cache.v_cache
-        if hasattr(self.kv_cache, "quantized_cache_dtype"):
+        k_cache = self.kv_cache.k_cache if self.kv_cache is not None else None
+        v_cache = self.kv_cache.v_cache if self.kv_cache is not None else None
+
+        if self.kv_cache is None:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos,
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                self.is_causal,  # is_causal
+            )
+        elif isinstance(self.kv_cache, QuantizedKVCache):
            # Updates the quantized cache, scales, and zero points, and
            # returns the dequantized kv cache.
            # Not the most optimal; optimizations to follow.
@@ -68,7 +83,7 @@ def forward(
                input_pos[0].item(),
                None,  # Attention mask
                0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
            )
        else:
            output = torch.ops.llama.sdpa_with_kv_cache(
@@ -81,7 +96,7 @@ def forward(
                seqlen,
                None,  # Attention mask
                0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
            )
        return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)

@@ -99,7 +114,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):


def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
-    from executorch.extension.llm.custom_ops import custom_ops  # noqa
+    from executorch.extension.llm.custom_ops import custom_ops
    _replace_sdpa_with_custom_op(module)
    return module

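For orientation, a minimal, hypothetical usage sketch of the transform this file exposes. It is not part of the commit: the import path is the assumed location of this file in the executorch tree, and build_llama2_model is a placeholder for however the eager model is constructed.

# Hypothetical driver, not part of this commit. The import path below is
# assumed, not taken from the diff.
from executorch.examples.models.llama2.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)

# Placeholder: any nn.Module whose attention blocks hold the SDPA submodules
# that _replace_sdpa_with_custom_op() swaps out.
model = build_llama2_model()  # assumed helper, not defined in this file

# The local import inside replace_sdpa_with_custom_op() registers
# torch.ops.llama.custom_sdpa and torch.ops.llama.sdpa_with_kv_cache as a
# side effect, so the swapped-in SDPACustom modules can dispatch to them.
model = replace_sdpa_with_custom_op(model)

# With this change, SDPACustom can also be built without a KV cache,
# e.g. SDPACustom(kv_cache=None, dim=-1, is_causal=True); forward() then
# calls torch.ops.llama.custom_sdpa directly, relying on is_causal instead
# of an explicit attention mask.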