@@ -94,7 +94,10 @@ def test_attention_eager(self):
         et_res = self.et_mha(self.x, self.x)  # Self attention.
         tt_res = self.tt_mha(self.x, self.x)  # Self attention.

-        self.assertTrue(torch.allclose(et_res, tt_res))
+        self.assertTrue(
+            torch.allclose(et_res, tt_res),
+            msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
+        )

         # test with kv cache
         self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
@@ -136,7 +139,10 @@ def test_attention_export(self):
         )
         et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
-        self.assertTrue(torch.allclose(et_res, tt_res))
+        self.assertTrue(
+            torch.allclose(et_res, tt_res),
+            msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
+        )

         # TODO: KV cache.

@@ -162,6 +168,9 @@ def test_attention_executorch(self):
         et_res = method.execute((self.x, self.x, self.input_pos))
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

-        self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06))
+        self.assertTrue(
+            torch.allclose(et_res[0], tt_res, atol=1e-05),
+            msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res[0]}",
+        )

         # TODO: KV cache.
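A possible follow-up, sketched here rather than part of this change: torch.testing.assert_close builds a detailed mismatch report on failure (greatest absolute and relative difference, count of mismatched elements), so the hand-rolled msg strings above could be dropped. The helper name below is hypothetical; the tensor names and the 1e-05 tolerance mirror the tests in this diff, and assert_close requires rtol to be given alongside atol.

    import torch

    # Hypothetical helper for these tests: assert_close raises an
    # AssertionError carrying a full mismatch report, replacing the
    # manually formatted f-string message.
    def assert_outputs_close(et_res: torch.Tensor, tt_res: torch.Tensor) -> None:
        torch.testing.assert_close(et_res, tt_res, atol=1e-05, rtol=1.3e-06)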