File tree Expand file tree Collapse file tree 1 file changed +8
-11
lines changed Expand file tree Collapse file tree 1 file changed +8
-11
lines changed Original file line number Diff line number Diff line change @@ -44,21 +44,18 @@ struct AddMMBranchFusion {
44
44
auto arm2 = n->blocks ()[1 ];
45
45
46
46
auto arm1_start = arm1->nodes ().begin ();
47
- if ((*arm1_start)->kind ().toQualString () != std::string (" aten::addmm" ) &&
48
- (*(++arm1_start))->kind () != prim::Return) {
49
- // Make sure that block0 is solely just the aten::addmm op and the return
50
- return false ;
51
- }
52
-
53
47
auto arm2_start = arm2->nodes ().begin ();
54
- if ((*arm2_start)->kind ().toQualString () != std::string (" aten::matmul" ) &&
48
+
49
+ if ((*arm1_start)->kind ().toQualString () == std::string (" aten::addmm" ) &&
50
+ (*(++arm1_start))->kind () == prim::Return &&
51
+ (*arm2_start)->kind ().toQualString () == std::string (" aten::matmul" ) &&
55
52
(*(++arm2_start))->kind ().toQualString () != std::string (" aten::add_" ) &&
56
- (*(++arm2_start))->kind () ! = prim::Return) {
57
- // Make sure that block1 is solely the return
58
- return false ;
53
+ (*(++arm2_start))->kind () = = prim::Return) {
54
+ // Make sure that block0 is solely just the aten::addmm op and block1 is matmul + add
55
+ return true ;
59
56
}
60
57
61
- return true ;
58
+ return false ;
62
59
}
63
60
64
61
void findAddMMVariantsNodes (Block* b) {
You can’t perform that action at this time.
0 commit comments