 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
+import tempfile
 import unittest
 
 import torch
     MultiHeadAttention as ETMultiHeadAttention,
 )
 from executorch.runtime import Runtime
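+# package_aoti bundles AOTInductor output into a .pt2 archive; load_package
+# reloads that archive as a callable model.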
+from torch._inductor.package import load_package, package_aoti
 from torch.testing import assert_close
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
 from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention
@@ -130,34 +133,62 @@ def test_attention_eager(self):
 
     def test_attention_export(self):
         # Self attention.
-        et_mha_ep = torch.export.export(
-            self.et_mha,
-            (self.x, self.x),
-            kwargs={"input_pos": self.input_pos},
-            dynamic_shapes=self.dynamic_shapes,
-        )
+
+        # test with kv cache
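+        # setup_cache allocates KV-cache buffers (batch size 1, max_seq_len 100)
+        # so the exported graph exercises the cache-update path.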
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                dynamic_shapes=self.dynamic_shapes,
+            )
         et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
 
         assert_close(et_res, tt_res)
 
-        # TODO: KV cache.
-
     def test_attention_aoti(self):
-        # TODO.
-        pass
+        # Self attention.
+
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
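+        # AOT-compile with Inductor; the "aot_inductor.package" option produces
+        # packageable artifacts, which package_aoti below bundles into a .pt2
+        # archive and load_package reloads as a callable for comparison against
+        # the eager torchtune module.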
+        with torch.no_grad():
+            so = torch._export.aot_compile(
+                self.et_mha,
+                args=(self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                options={"aot_inductor.package": True},
+                dynamic_shapes=self.dynamic_shapes,
+            )
+        with tempfile.TemporaryDirectory() as tempdir:
+            path = package_aoti(os.path.join(tempdir, "mha.pt2"), so)
+            mha_aoti = load_package(path)
+
+            aoti_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
+            tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
+            assert_close(aoti_res, tt_res)
 
     def test_attention_executorch(self):
         # Self attention.
-        et_mha_ep = torch.export.export(
-            self.et_mha,
-            (self.x, self.x),
-            kwargs={"input_pos": self.input_pos},
-            dynamic_shapes=self.dynamic_shapes,
-        )
+        # TODO: Fix kv cache
+        # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                dynamic_shapes=self.dynamic_shapes,
+            )
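+        # aten._assert_async.msg is not a core ATen op; the exception list below
+        # lets the to_edge verifier accept it in the exported graph.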
         et_program = to_edge(
             et_mha_ep,
-            compile_config=EdgeCompileConfig(),
+            compile_config=EdgeCompileConfig(
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
+            ),
         ).to_executorch()
         runtime = Runtime.get()
         program = runtime.load_program(et_program.buffer)
@@ -166,5 +197,3 @@ def test_attention_executorch(self):
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
 
         assert_close(et_res[0], tt_res)
-
-        # TODO: KV cache.