@@ -27,7 +27,7 @@ def setUp(self):
         torch.manual_seed(0)
         # Constants
         self.embed_dim = 2048
-        self.num_heads = 32
+        self.num_heads = 8
         self.num_kv_heads = 8
         self.head_dim = 64
         self.max_seq_len = 128
@@ -46,7 +46,9 @@ def setUp(self):
             self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
         )
         self.v_proj.weight.requires_grad = False
-        self.output_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.output_proj = torch.nn.Linear(
+            self.num_heads * self.head_dim, self.embed_dim, bias=False
+        )
         self.pos_embeddings = Llama3ScaledRoPE(
             dim=self.head_dim,
             max_seq_len=self.max_seq_len,
@@ -92,6 +94,12 @@ def setUp(self):
             {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
             {0: torch.export.Dim.STATIC, 1: seq_len_dim},
         )
+        self.causal_mask = torch.tril(
+            torch.ones(
+                size=(self.max_seq_len, self.max_seq_len),
+                dtype=torch.bool,
+            )
+        )

     def test_attention_eager(self):
         et_res = self.et_mha(self.x, self.x)  # Self attention.
@@ -197,3 +205,35 @@ def test_attention_executorch(self):
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

         assert_close(et_res[0], tt_res)
+
+    def test_attention_torch_cond_eager(self):
+        # Unlike vanilla torchtune MHA, we rewrite the `if` condition with torch.cond, so verify both give the same results.
+        # The first run of MHA provides `y` (self.x); for the second run `y` is a tensor full of nan.
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+
+        # mask
+        mask = self.causal_mask[self.input_pos, :]
+        # First run
+        et_res = self.et_mha(
+            self.x, self.x, mask=mask, input_pos=self.input_pos
+        )  # Self attention with input pos.
+        tt_res = self.tt_mha(
+            self.x, self.x, mask=mask, input_pos=self.input_pos
+        )  # Self attention with input pos.
+
+        self.assertTrue(torch.allclose(et_res, tt_res))
+
+        # Second run tests the kv cache read. Input pos is [10, 11, ..., 19].
+        next_input_pos = torch.arange(10, 20).unsqueeze(0)
+
+        empty_y = torch.full_like(self.x, torch.nan)
+        mask = self.causal_mask[next_input_pos, :]
+        et_res = self.et_mha(
+            self.x, empty_y, mask=mask, input_pos=next_input_pos
+        )  # Self attention with input pos.
+        tt_res = self.tt_mha(
+            self.x, None, mask=mask, input_pos=next_input_pos
+        )  # Self attention with input pos.
+
+        assert_close(et_res, tt_res)
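
The new test exercises the torch.cond rewrite of torchtune's `if y is None` branch in the ExecuTorch MHA: since torch.export cannot trace a Python-level None check, the caller passes a nan-filled tensor in place of None and the branch becomes a tensor-predicated torch.cond. Below is a minimal sketch of that pattern only; it is not the ExecuTorch MultiHeadAttention code, and the nan-placeholder predicate and the cached_fn/compute_fn names are assumptions made for illustration.

# Minimal, self-contained sketch of the torch.cond pattern (illustrative only;
# not the ExecuTorch MultiHeadAttention implementation).
import torch


def cached_fn(y: torch.Tensor) -> torch.Tensor:
    # "y is a nan placeholder" branch: stand-in for reading k/v from the cache.
    return torch.zeros_like(y)


def compute_fn(y: torch.Tensor) -> torch.Tensor:
    # "y is real input" branch: stand-in for projecting y and updating the cache.
    return y * 2.0


def forward(y: torch.Tensor) -> torch.Tensor:
    # Vanilla eager code would write `if y is None: ...`, which torch.export cannot trace.
    # Instead, a nan-filled tensor stands in for None and the branch is a tensor predicate.
    use_cache = torch.isnan(y).all()
    return torch.cond(use_cache, cached_fn, compute_fn, (y,))


x = torch.randn(1, 4)
print(forward(x))                               # takes compute_fn
print(forward(torch.full_like(x, torch.nan)))   # takes cached_fn

Both branches must return tensors with the same shape and dtype, which is why the test can feed a nan-filled `empty_y` to the ExecuTorch module and `None` to the torchtune module and still compare the outputs.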