 import torch
 from executorch.exir import EdgeCompileConfig, to_edge

+from executorch.exir.capture._config import ExecutorchBackendConfig
+from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
 from executorch.extension.llm.modules.attention import (
     MultiHeadAttention as ETMultiHeadAttention,
 )
@@ -114,7 +116,7 @@ def test_attention_eager(self):
         et_res = self.et_mha(self.x, self.x)  # Self attention.
         tt_res = self.tt_mha(self.x, self.x)  # Self attention.

-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)
         self.et_mha.reset_cache()
         self.tt_mha.reset_cache()

@@ -125,7 +127,7 @@ def test_attention_eager(self):
             self.x, self.x, input_pos=self.input_pos
         )  # Self attention with input pos.

-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)

         # test kv cache read. Input pos can be [10, 11, ..., 19]
         next_input_pos = torch.arange(10, 20).unsqueeze(0)
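The two hunks above replace unittest-style `self.assertTrue(torch.allclose(...))` checks with `assert_close`. Its import is outside the lines shown here; presumably it is `torch.testing.assert_close`, which raises an `AssertionError` describing the greatest absolute and relative differences instead of reducing the comparison to a bare boolean. A minimal sketch under that assumption:

import torch
from torch.testing import assert_close  # assumed source of assert_close in this test file

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1e-7  # within default float32 tolerances

# Old pattern: the assertion only sees a boolean, so a failure reports no detail.
assert torch.allclose(a, b)

# New pattern: on mismatch, assert_close raises with the largest absolute and
# relative differences, which makes test failures easier to diagnose.
assert_close(a, b)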
@@ -187,9 +189,8 @@ def test_attention_aoti(self):

     def test_attention_executorch(self):
         # Self attention.
-        # TODO: Fix kv cache
-        # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)

         with torch.no_grad():
             et_mha_ep = torch.export.export(
@@ -202,9 +203,15 @@ def test_attention_executorch(self):
         et_program = to_edge(
             et_mha_ep,
             compile_config=EdgeCompileConfig(
-                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
+                _check_ir_validity=False,
             ),
-        ).to_executorch()
+        ).to_executorch(
+            config=ExecutorchBackendConfig(
+                passes=[InitializedMutableBufferPass(["cache_pos"])],
+            )
+        )
+
         runtime = Runtime.get()
         program = runtime.load_program(et_program.buffer)
         method = program.load_method("forward")
@@ -219,28 +226,23 @@ def test_attention_torch_cond_eager(self):
         self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

-        # mask
         mask = self.causal_mask[self.input_pos, :]
-        # First run
+        # First run.
         et_res = self.et_mha(
             self.x, self.x, mask=mask, input_pos=self.input_pos
         )  # Self attention with input pos.
         tt_res = self.tt_mha(
             self.x, self.x, mask=mask, input_pos=self.input_pos
         )  # Self attention with input pos.

-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)

         # Second run test kv cache read. Input pos is [10, 11, ..., 19]
         next_input_pos = torch.arange(10, 20).unsqueeze(0)

         empty_y = torch.full_like(self.x, torch.nan)
         mask = self.causal_mask[next_input_pos, :]
-        et_res = self.et_mha(
-            self.x, empty_y, mask=mask, input_pos=next_input_pos
-        )  # Self attention with input pos.
-        tt_res = self.tt_mha(
-            self.x, None, mask=mask, input_pos=next_input_pos
-        )  # Self attention with input pos.
+        et_res = self.et_mha(self.x, empty_y, mask=mask, input_pos=next_input_pos)
+        tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)

         assert_close(et_res, tt_res)
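For reference, a simplified sketch of the export → to_edge → to_executorch → runtime flow that `test_attention_executorch` now exercises, assembled from the hunks above. This is a sketch, not the test itself: `mha_module` and `x` are placeholders for the test's `self.et_mha` and `self.x`, the `Runtime` import path and the final `method.execute` call are not shown in this diff and are assumed from the ExecuTorch Python runtime API, and the `torch.export.export` arguments are reduced to the minimum.

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
from executorch.runtime import Runtime  # assumed import path for Runtime

mha_module = ...  # placeholder: the ET MultiHeadAttention module with its kv cache set up
x = ...           # placeholder: the test's input tensor

# Export, lower to the edge dialect, then emit an ExecuTorch program.
exported = torch.export.export(mha_module, (x, x))
et_program = to_edge(
    exported,
    compile_config=EdgeCompileConfig(
        # Tolerate the assert op left in the graph and skip IR validation,
        # as in the hunk above.
        _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
        _check_ir_validity=False,
    ),
).to_executorch(
    config=ExecutorchBackendConfig(
        # The pass added by this change; per its name, it treats mutable buffers
        # matching "cache_pos" as initialized.
        passes=[InitializedMutableBufferPass(["cache_pos"])],
    )
)

# Load and run the lowered program through the Python runtime bindings.
runtime = Runtime.get()
program = runtime.load_program(et_program.buffer)
method = program.load_method("forward")
outputs = method.execute([x, x])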