 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
+import tempfile
 import unittest
 
 import torch
     MultiHeadAttention as ETMultiHeadAttention,
 )
 from executorch.runtime import Runtime
+from torch._inductor.package import load_package, package_aoti
 from torch.testing import assert_close
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
 from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention
@@ -130,34 +133,62 @@ def test_attention_eager(self):
 
     def test_attention_export(self):
         # Self attention.
-        et_mha_ep = torch.export.export(
-            self.et_mha,
-            (self.x, self.x),
-            kwargs={"input_pos": self.input_pos},
-            dynamic_shapes=self.dynamic_shapes,
-        )
+
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                dynamic_shapes=self.dynamic_shapes,
+            )
         et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
 
         assert_close(et_res, tt_res)
 
-        # TODO: KV cache.
-
     def test_attention_aoti(self):
-        # TODO.
-        pass
+        # Self attention.
+
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        with torch.no_grad():
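+            # With "aot_inductor.package": True, aot_compile emits artifacts meant to be bundled by package_aoti below.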
+            so = torch._export.aot_compile(
+                self.et_mha,
+                args=(self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                options={"aot_inductor.package": True},
+                dynamic_shapes=self.dynamic_shapes,
+            )
+        with tempfile.TemporaryDirectory() as tempdir:
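+            # package_aoti bundles the compiled artifacts into a .pt2 file; load_package loads it back as a callable.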
+            path = package_aoti(os.path.join(tempdir, "mha.pt2"), so)
+            mha_aoti = load_package(path)
+
+            aoti_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
+            tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
+            assert_close(aoti_res, tt_res)
 
     def test_attention_executorch(self):
         # Self attention.
-        et_mha_ep = torch.export.export(
-            self.et_mha,
-            (self.x, self.x),
-            kwargs={"input_pos": self.input_pos},
-            dynamic_shapes=self.dynamic_shapes,
-        )
+        # TODO: Fix kv cache
+        # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                dynamic_shapes=self.dynamic_shapes,
+            )
         et_program = to_edge(
             et_mha_ep,
-            compile_config=EdgeCompileConfig(),
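+            # The exception list exempts the non-core _assert_async op from the core-ATen check in to_edge.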
+            compile_config=EdgeCompileConfig(
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
+            ),
         ).to_executorch()
         runtime = Runtime.get()
         program = runtime.load_program(et_program.buffer)
@@ -166,5 +197,3 @@ def test_attention_executorch(self):
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
 
         assert_close(et_res[0], tt_res)
-
-        # TODO: KV cache.