@@ -40,7 +40,29 @@ def check_op_counts(
40
40
self .assertTrue (op_counts_match (graph_module , expected_op_counts ))
41
41
42
42
43
- class TestFusionPasses (TestFusionPassesBase ):
43
+ class TestFuseMMWithAddPass (TestFusionPassesBase ):
44
+ def test_no_fuse_for_3d_bias (self ) -> None :
45
+ builder = GraphBuilder ()
46
+ x = builder .placeholder ("x" , torch .randn (4 , 3 , dtype = torch .float32 ))
47
+ y = builder .placeholder ("y" , torch .randn (3 , 5 , dtype = torch .float32 ))
48
+ z = builder .placeholder ("z" , torch .randn (1 , 4 , 5 , dtype = torch .float32 ))
49
+ mm = builder .call_operator (
50
+ op = exir_ops .edge .aten .mm .default ,
51
+ args = (x , y ),
52
+ )
53
+ output = builder .call_operator (op = exir_ops .edge .aten .add .Tensor , args = (mm , z ))
54
+ builder .output ([output ])
55
+ original_graph = builder .get_graph_module ()
56
+
57
+ p = FuseMMWithAdd ()
58
+ converted_graph = cast (PassResult , p (original_graph )).graph_module
59
+ converted_graph .graph .eliminate_dead_code ()
60
+ self .assertEqual (
61
+ count_node (converted_graph , exir_ops .edge .aten .addmm .default ), 0
62
+ )
63
+ self .assertEqual (count_node (converted_graph , exir_ops .edge .aten .mm .default ), 1 )
64
+ self .assertEqual (count_node (converted_graph , exir_ops .edge .aten .add .Tensor ), 1 )
65
+
44
66
def test_fuse_mm_with_add (self ) -> None :
45
67
builder = GraphBuilder ()
46
68
x = builder .placeholder ("x" , torch .randn (3 , 5 , dtype = torch .float32 ))
@@ -176,6 +198,8 @@ def test_keep_mm_add_with_multiple_users(self) -> None:
176
198
self .assertEqual (count_node (converted_graph , exir_ops .edge .aten .mm .default ), 1 )
177
199
self .assertEqual (count_node (converted_graph , exir_ops .edge .aten .add .Tensor ), 3 )
178
200
201
+
202
+ class TestFusionPasses (TestFusionPassesBase ):
179
203
def test_permute_transpose_fusion (self ) -> None :
180
204
builder = GraphBuilder ()
181
205
x = builder .placeholder ("x" , torch .randn (3 , 1 , 3 , 1 , 4 , dtype = torch .float32 ))
0 commit comments