diff --git a/backends/arm/test/ops/test_bmm.py b/backends/arm/test/ops/test_bmm.py index bd2c9338275..6b66abbda01 100644 --- a/backends/arm/test/ops/test_bmm.py +++ b/backends/arm/test/ops/test_bmm.py @@ -44,16 +44,6 @@ def forward(self, x, y): return torch.bmm(x, y) -class MatMul(torch.nn.Module): - test_data_generators = { - "rand_3d": lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)), - "rand_4d": lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)), - } - - def forward(self, x, y): - return torch.matmul(x, y) - - class BMMSingleInput(torch.nn.Module): test_data_generators = { "rand_3d_1": lambda: (torch.rand(20, 3, 3),), @@ -81,26 +71,14 @@ def test_bmm_tosa_MI_single_input(test_data: input_t1): pipeline.run() -@common.parametrize("test_data", MatMul.test_data_generators) -def test_mm_tosa_MI(test_data: input_t1): - pipeline = TosaPipelineMI[input_t1](MatMul(), test_data(), aten_op_mm, exir_op_mm) - pipeline.run() - - -@common.parametrize("test_data", MatMul.test_data_generators) -def test_mm_tosa_BI(test_data: input_t1): - pipeline = TosaPipelineBI[input_t1](MatMul(), test_data(), aten_op_mm, exir_op_mm) - pipeline.run() - - -@pytest.mark.flaky(reruns=5) # TODO: Investigate flakyness (MLETORCH-534) @common.parametrize("test_data", BMM.test_data_generators) def test_bmm_tosa_BI(test_data: input_t1): - pipeline = TosaPipelineBI[input_t1](BMM(), test_data(), aten_op_bmm, exir_op_bmm) + pipeline = TosaPipelineBI[input_t1]( + BMM(), test_data(), aten_op_bmm, exir_op_bmm, qtol=1 + ) pipeline.run() -@pytest.mark.flaky(reruns=5) # TODO: Investigate flakyness (MLETORCH-534) @common.parametrize("test_data", BMMSingleInput.test_data_generators) def test_bmm_tosa_BI_single_input(test_data: input_t1): pipeline = TosaPipelineBI[input_t1]( diff --git a/backends/arm/test/ops/test_mm.py b/backends/arm/test/ops/test_mm.py index a5a3b4b98b9..9c3ce443bfd 100644 --- a/backends/arm/test/ops/test_mm.py +++ b/backends/arm/test/ops/test_mm.py @@ -41,7 +41,7 @@ def test_mm_tosa_MI(test_data: Tuple): @common.parametrize("test_data", MM.test_data_generators) def test_mm_tosa_BI(test_data: Tuple): - TosaPipelineBI[test_t](MM(), test_data(), MM.aten_op, MM.exir_op).run() + TosaPipelineBI[test_t](MM(), test_data(), MM.aten_op, MM.exir_op, qtol=1).run() @common.parametrize("test_data", MM.test_data_generators)