@@ -92,7 +92,10 @@ def test_attention_eager(self):
         et_res = self.et_mha(self.x, self.x)  # Self attention.
         tt_res = self.tt_mha(self.x, self.x)  # Self attention.

-        self.assertTrue(torch.allclose(et_res, tt_res))
+        self.assertTrue(
+            torch.allclose(et_res, tt_res),
+            msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
+        )

         # TODO: KV cache.
         # self.et_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20)
@@ -113,7 +116,10 @@ def test_attention_export(self):
         )
         et_res = et_mha_ep.module()(self.x, self.x)
         tt_res = self.tt_mha(self.x, self.x)
-        self.assertTrue(torch.allclose(et_res, tt_res))
+        self.assertTrue(
+            torch.allclose(et_res, tt_res),
+            msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
+        )

         # TODO: KV cache.

@@ -139,6 +145,9 @@ def test_attention_executorch(self):
         et_res = method.execute((self.x, self.x))
         tt_res = self.tt_mha(self.x, self.x)

-        self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06))
+        self.assertTrue(
+            torch.allclose(et_res[0], tt_res, atol=1e-05),
+            msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res[0]}",
+        )

         # TODO: KV cache.
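For context on the atol loosening in the last hunk (1e-06 to 1e-05): torch.allclose(a, b, rtol=1e-05, atol=1e-08) checks |a - b| <= atol + rtol * |b| elementwise, and since the test overrides only atol, rtol keeps its 1e-05 default. A minimal sketch of that semantics, with illustrative tensor values not taken from the test:

# Illustrative only: how the atol argument shifts the allclose threshold.
# With rtol left at its 1e-05 default, the pass/fail boundary for values
# near 1.0 moves from ~1.1e-05 (atol=1e-06) to ~2.0e-05 (atol=1e-05).
import torch

a = torch.tensor([1.0])
b = torch.tensor([1.000015])  # absolute difference of 1.5e-05

print(torch.allclose(a, b, atol=1e-06))  # False: 1.5e-05 > 1e-06 + 1e-05 * |b| ~= 1.1e-05
print(torch.allclose(a, b, atol=1e-05))  # True:  1.5e-05 <= 1e-05 + 1e-05 * |b| ~= 2.0e-05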