@@ -44,16 +44,6 @@ def forward(self, x, y):
4444        return  torch .bmm (x , y )
4545
4646
47- class  MatMul (torch .nn .Module ):
48-     test_data_generators  =  {
49-         "rand_3d" : lambda : (torch .rand (2 , 3 , 5 ), torch .rand (2 , 5 , 2 )),
50-         "rand_4d" : lambda : (torch .rand (1 , 2 , 3 , 5 ), torch .rand (1 , 2 , 5 , 2 )),
51-     }
52- 
53-     def  forward (self , x , y ):
54-         return  torch .matmul (x , y )
55- 
56- 
5747class  BMMSingleInput (torch .nn .Module ):
5848    test_data_generators  =  {
5949        "rand_3d_1" : lambda : (torch .rand (20 , 3 , 3 ),),
@@ -81,26 +71,14 @@ def test_bmm_tosa_MI_single_input(test_data: input_t1):
8171    pipeline .run ()
8272
8373
84- @common .parametrize ("test_data" , MatMul .test_data_generators ) 
85- def  test_mm_tosa_MI (test_data : input_t1 ):
86-     pipeline  =  TosaPipelineMI [input_t1 ](MatMul (), test_data (), aten_op_mm , exir_op_mm )
87-     pipeline .run ()
88- 
89- 
90- @common .parametrize ("test_data" , MatMul .test_data_generators ) 
91- def  test_mm_tosa_BI (test_data : input_t1 ):
92-     pipeline  =  TosaPipelineBI [input_t1 ](MatMul (), test_data (), aten_op_mm , exir_op_mm )
93-     pipeline .run ()
94- 
95- 
96- @pytest .mark .flaky (reruns = 5 )  # TODO: Investigate flakyness (MLETORCH-534)  
9774@common .parametrize ("test_data" , BMM .test_data_generators ) 
9875def  test_bmm_tosa_BI (test_data : input_t1 ):
99-     pipeline  =  TosaPipelineBI [input_t1 ](BMM (), test_data (), aten_op_bmm , exir_op_bmm )
76+     pipeline  =  TosaPipelineBI [input_t1 ](
77+         BMM (), test_data (), aten_op_bmm , exir_op_bmm , qtol = 1 
78+     )
10079    pipeline .run ()
10180
10281
103- @pytest .mark .flaky (reruns = 5 )  # TODO: Investigate flakyness (MLETORCH-534)  
10482@common .parametrize ("test_data" , BMMSingleInput .test_data_generators ) 
10583def  test_bmm_tosa_BI_single_input (test_data : input_t1 ):
10684    pipeline  =  TosaPipelineBI [input_t1 ](
0 commit comments