@@ -33,6 +33,7 @@ def setUp(self):
         self.num_kv_heads = 8
         self.head_dim = 64
         self.max_seq_len = 128
+        self.encoder_max_seq_len = 128
         self.rope_base = 500_000
         self.scale_factor = 32

@@ -86,16 +87,26 @@ def setUp(self):
             max_seq_len=self.max_seq_len,
         )
         self.et_mha.load_state_dict(self.tt_mha.state_dict())
+
         # Common inputs.
         seq_len = 10
         self.x = torch.randn(1, seq_len, self.embed_dim)
+        self.y = torch.randn(1, seq_len, self.embed_dim)
         self.input_pos = torch.arange(seq_len).unsqueeze(0)  # shape [1, seq_len]
-        seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
-        self.dynamic_shapes = (
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim},
-        )
+        self.seq_len_dim = torch.export.Dim("seq_len", min=1, max=self.max_seq_len)
+        self.dynamic_shapes = {
+            "x": {
+                0: torch.export.Dim.STATIC,
+                1: self.seq_len_dim,
+                2: torch.export.Dim.STATIC,
+            },
+            "y": {
+                0: torch.export.Dim.STATIC,
+                1: self.seq_len_dim,
+                2: torch.export.Dim.STATIC,
+            },
+            "input_pos": {0: torch.export.Dim.STATIC, 1: self.seq_len_dim},
+        }
         self.causal_mask = torch.tril(
             torch.ones(
                 size=(self.max_seq_len, self.max_seq_len),
@@ -110,8 +121,8 @@ def test_attention_eager(self):
         assert_close(et_res, tt_res)

         # test with kv cache
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

         et_res = self.et_mha(self.x, self.x)  # Self attention.
         tt_res = self.tt_mha(self.x, self.x)  # Self attention.
@@ -144,12 +155,12 @@ def test_attention_export(self):
         # Self attention.

         # test with kv cache
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         with torch.no_grad():
             et_mha_ep = torch.export.export(
                 self.et_mha,
-                (self.x, self.x),
+                (self.x, self.y),
                 kwargs={"input_pos": self.input_pos},
                 dynamic_shapes=self.dynamic_shapes,
                 strict=True,
@@ -166,8 +177,8 @@ def test_attention_aoti(self):
         # Self attention.

         # test with kv cache
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         with torch.no_grad():
             so = torch._export.aot_compile(
                 self.et_mha,
@@ -189,13 +200,13 @@ def test_attention_aoti(self):

     def test_attention_executorch(self):
         # Self attention.
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

         with torch.no_grad():
             et_mha_ep = torch.export.export(
                 self.et_mha,
-                (self.x, self.x),
+                (self.x, self.y),
                 kwargs={"input_pos": self.input_pos},
                 dynamic_shapes=self.dynamic_shapes,
                 strict=True,
@@ -222,22 +233,18 @@ def test_attention_executorch(self):

     def test_attention_torch_cond_eager(self):
         # Different from vanilla torchtune MHA, we rewrite the if condition with torch.cond. We need to make sure they are giving the same results regarding the if condition.
-        # For the first run of MHA we provide `y` (self.x) but for the second run it will be a tensor full of nan.
+        # For the first run of MHA we provide `y` but for the second run it will be a tensor full of nan.
         self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

         mask = self.causal_mask[self.input_pos, :]
         # First run.
-        et_res = self.et_mha(
-            self.x, self.x, mask=mask, input_pos=self.input_pos
-        )  # Self attention with input pos.
-        tt_res = self.tt_mha(
-            self.x, self.x, mask=mask, input_pos=self.input_pos
-        )  # Self attention with input pos.
+        et_res = self.et_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
+        tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)

         assert_close(et_res, tt_res)

-        # Second run test kv cache read. Input pos is [10, 11, ..., 19]
+        # Second run tests kv cache read. Input pos is [10, 11, ..., 19]
         next_input_pos = torch.arange(10, 20).unsqueeze(0)

         empty_y = torch.full_like(self.x, torch.nan)
@@ -246,3 +253,101 @@ def test_attention_torch_cond_eager(self):
         tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)

         assert_close(et_res, tt_res)
+
+    def test_attention_torch_cond_export(self):
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        mask = self.causal_mask[self.input_pos, :]
+        dynamic_shapes = {
+            **self.dynamic_shapes,
+            **{
+                "mask": {
+                    0: torch.export.Dim.STATIC,
+                    1: self.seq_len_dim,
+                    2: torch.export.Dim.STATIC,
+                }
+            },
+        }
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.y),
+                kwargs={
+                    "mask": mask,
+                    "input_pos": self.input_pos,
+                },
+                dynamic_shapes=dynamic_shapes,
+                strict=True,
+            )
+
+        # First run.
+        et_res = et_mha_ep.module()(self.x, self.y, mask=mask, input_pos=self.input_pos)
+        tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
+
+        assert_close(et_res, tt_res)
+
+        # Second run tests kv cache read. Input pos is [10, 11, ..., 19]
+        next_input_pos = torch.arange(10, 20).unsqueeze(0)
+        empty_y = torch.full_like(self.y, torch.nan)
+        mask = self.causal_mask[next_input_pos, :]
+        et_res = et_mha_ep.module()(
+            self.x, empty_y, mask=mask, input_pos=next_input_pos
+        )
+        tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)
+
+        assert_close(et_res, tt_res)
+
+    def test_attention_torch_cond_executorch(self):
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        mask = self.causal_mask[self.input_pos, :]
+        dynamic_shapes = {
+            **self.dynamic_shapes,
+            **{
+                "mask": {
+                    0: torch.export.Dim.STATIC,
+                    1: self.seq_len_dim,
+                    2: torch.export.Dim.STATIC,
+                }
+            },
+        }
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.y),
+                kwargs={
+                    "mask": mask,
+                    "input_pos": self.input_pos,
+                },
+                dynamic_shapes=dynamic_shapes,
+                strict=True,
+            )
+            et_program = to_edge(
+                et_mha_ep,
+                compile_config=EdgeCompileConfig(
+                    _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
+                    _check_ir_validity=False,
+                ),
+            ).to_executorch(
+                config=ExecutorchBackendConfig(
+                    passes=[InitializedMutableBufferPass(["cache_pos"])],
+                )
+            )
+
+        # First run.
+        runtime = Runtime.get()
+        program = runtime.load_program(et_program.buffer)
+        method = program.load_method("forward")
+        et_res = method.execute((self.x, self.y, mask, self.input_pos))
+        tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
+
+        assert_close(et_res[0], tt_res)
+
+        # Second run tests kv cache read. Input pos is [10, 11, ..., 19]
+        next_input_pos = torch.arange(10, 20).unsqueeze(0)
+        empty_y = torch.full_like(self.y, torch.nan)
+        mask = self.causal_mask[next_input_pos, :]
+        et_res = method.execute((self.x, empty_y, mask, next_input_pos))
+        tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)
+
+        assert_close(et_res[0], tt_res)
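
Note on the pattern the new tests exercise: the attention module under test replaces the data-dependent `if` on the encoder input `y` with `torch.cond`, and a nan-filled `y` stands in for `None` because export requires tensor inputs. The sketch below is a minimal eager-mode illustration of that idea, not the code in this PR; `select_kv` and its branch bodies are hypothetical stand-ins, and shapes are kept identical across branches for simplicity.

    import torch

    def select_kv(y, cache_k, cache_v):
        # `if y is None:` cannot survive torch.export, so "no new encoder
        # input" is encoded as an all-nan y and the branch is taken via
        # torch.cond instead of a Python conditional.
        use_cache = torch.isnan(y).all()

        def read_cache(y, ck, cv):
            # Reuse previously computed k/v.
            return ck.clone(), cv.clone()

        def recompute(y, ck, cv):
            # Stand-in for the real k/v projections of y.
            return y.clone(), y.clone()

        return torch.cond(use_cache, read_cache, recompute, (y, cache_k, cache_v))

    # Usage: an all-nan y selects the cache branch.
    y = torch.full((1, 8, 16), float("nan"))
    ck, cv = torch.randn(1, 8, 16), torch.randn(1, 8, 16)
    k, v = select_kv(y, ck, cv)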