import torch
from executorch.exir import EdgeCompileConfig, to_edge

+from executorch.exir.capture._config import ExecutorchBackendConfig
+from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
from executorch.extension.llm.modules.attention import (
    MultiHeadAttention as ETMultiHeadAttention,
)
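Note on the assertion change in the hunks below: the tests swap self.assertTrue(torch.allclose(...)) for assert_close, which is assumed here to be torch.testing.assert_close imported elsewhere in this file. A minimal, standalone sketch of what that helper checks:

import torch
from torch.testing import assert_close  # assumption: this is the assert_close used in these tests

a = torch.randn(2, 3)
b = a + 1e-7  # well within the default float32 tolerances
# assert_close verifies shape, dtype, and element-wise closeness,
# and raises an AssertionError with a mismatch summary otherwise.
assert_close(b, a)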
@@ -114,7 +116,7 @@ def test_attention_eager(self):
        et_res = self.et_mha(self.x, self.x)  # Self attention.
        tt_res = self.tt_mha(self.x, self.x)  # Self attention.

-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)
        self.et_mha.reset_cache()
        self.tt_mha.reset_cache()

@@ -125,7 +127,7 @@ def test_attention_eager(self):
            self.x, self.x, input_pos=self.input_pos
        )  # Self attention with input pos.

-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)

        # test kv cache read. Input pos can be [10, 11, ..., 19]
        next_input_pos = torch.arange(10, 20).unsqueeze(0)
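As a side note on the kv-cache read above, a small sketch (positions taken from the test's own comment) of how the position tensor for the second chunk is built:

import torch

# After a 10-token prefill at positions 0..9, the next chunk of 10 tokens
# reads and writes cache slots 10..19; unsqueeze adds the batch dimension.
next_input_pos = torch.arange(10, 20).unsqueeze(0)
print(next_input_pos.shape)   # torch.Size([1, 10])
print(next_input_pos[0, :3])  # tensor([10, 11, 12])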
@@ -187,9 +189,8 @@ def test_attention_aoti(self):

    def test_attention_executorch(self):
        # Self attention.
-        # TODO: Fix kv cache
-        # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)

        with torch.no_grad():
            et_mha_ep = torch.export.export(
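For context on the export step this hunk leads into, here is a hedged, self-contained sketch of torch.export.export; the toy module and shapes are illustrative, not taken from the test:

import torch

class Toy(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# torch.export.export traces the module into an ExportedProgram,
# which to_edge() consumes in the lowering step shown below.
ep = torch.export.export(Toy(), (torch.randn(4), torch.randn(4)))
print(ep.graph_module.graph)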
@@ -202,9 +203,15 @@ def test_attention_executorch(self):
        et_program = to_edge(
            et_mha_ep,
            compile_config=EdgeCompileConfig(
-                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
+                _check_ir_validity=False,
            ),
-        ).to_executorch()
+        ).to_executorch(
+            config=ExecutorchBackendConfig(
+                passes=[InitializedMutableBufferPass(["cache_pos"])],
+            )
+        )
+
        runtime = Runtime.get()
        program = runtime.load_program(et_program.buffer)
        method = program.load_method("forward")
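Just below this hunk the test presumably executes the loaded method. A hedged sketch of that runtime flow, assuming Runtime here is executorch.runtime.Runtime and that Method.execute takes a list of input tensors; the path and shapes are placeholders, not from the diff:

import torch
from executorch.runtime import Runtime

runtime = Runtime.get()
program = runtime.load_program("attention.pte")  # placeholder path; the test loads et_program.buffer instead
method = program.load_method("forward")
x = torch.randn(1, 10, 4096)                     # illustrative shape only
outputs = method.execute([x, x])                 # inputs mirror the (x, y) signature exported above
print(outputs[0].shape)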
@@ -219,28 +226,23 @@ def test_attention_torch_cond_eager(self):
        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

-        # mask
        mask = self.causal_mask[self.input_pos, :]
-        # First run
+        # First run.
        et_res = self.et_mha(
            self.x, self.x, mask=mask, input_pos=self.input_pos
        )  # Self attention with input pos.
        tt_res = self.tt_mha(
            self.x, self.x, mask=mask, input_pos=self.input_pos
        )  # Self attention with input pos.

-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)

        # Second run test kv cache read. Input pos is [10, 11, ..., 19]
        next_input_pos = torch.arange(10, 20).unsqueeze(0)

        empty_y = torch.full_like(self.x, torch.nan)
        mask = self.causal_mask[next_input_pos, :]
-        et_res = self.et_mha(
-            self.x, empty_y, mask=mask, input_pos=next_input_pos
-        )  # Self attention with input pos.
-        tt_res = self.tt_mha(
-            self.x, None, mask=mask, input_pos=next_input_pos
-        )  # Self attention with input pos.
+        et_res = self.et_mha(self.x, empty_y, mask=mask, input_pos=next_input_pos)
+        tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)

        assert_close(et_res, tt_res)
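One last hedged note on the rewritten call sites above: the ET module is fed a NaN-filled placeholder as its second argument where the TorchTune module accepts None for self-attention. A tiny sketch of that placeholder construction (the shape is illustrative, not the test's self.x):

import torch

x = torch.randn(1, 10, 4096)             # illustrative shape only
empty_y = torch.full_like(x, torch.nan)  # same shape/dtype as x, every element NaN
assert empty_y.shape == x.shape and torch.isnan(empty_y).all()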