Skip to content

Commit 54bed12

Browse files
authored
Merge pull request #224 from NVIDIA/fuse_addmm_fix
Explicit pattern matching for add + MM fusion
2 parents 72bf74b + 4c94533 commit 54bed12

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

core/lowering/passes/fuse_addmm_branches.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,21 +44,18 @@ struct AddMMBranchFusion {
4444
auto arm2 = n->blocks()[1];
4545

4646
auto arm1_start = arm1->nodes().begin();
47-
if ((*arm1_start)->kind().toQualString() != std::string("aten::addmm") &&
48-
(*(++arm1_start))->kind() != prim::Return) {
49-
// Make sure that block0 is solely just the aten::addmm op and the return
50-
return false;
51-
}
52-
5347
auto arm2_start = arm2->nodes().begin();
54-
if ((*arm2_start)->kind().toQualString() != std::string("aten::matmul") &&
48+
49+
if ((*arm1_start)->kind().toQualString() == std::string("aten::addmm") &&
50+
(*(++arm1_start))->kind() == prim::Return &&
51+
(*arm2_start)->kind().toQualString() == std::string("aten::matmul") &&
5552
(*(++arm2_start))->kind().toQualString() != std::string("aten::add_") &&
56-
(*(++arm2_start))->kind() != prim::Return) {
57-
// Make sure that block1 is solely the return
58-
return false;
53+
(*(++arm2_start))->kind() == prim::Return) {
54+
// Make sure that block0 is solely the aten::addmm op and block1 is matmul + add
55+
return true;
5956
}
6057

61-
return true;
58+
return false;
6259
}
6360

6461
void findAddMMVariantsNodes(Block* b) {

0 commit comments

Comments
 (0)