diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h index 62c0806dada2e..df6ce0fe1b037 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -1850,9 +1850,11 @@ class SelectionDAG { /// Get the specified node if it's already available, or else return NULL. LLVM_ABI SDNode *getNodeIfExists(unsigned Opcode, SDVTList VTList, ArrayRef Ops, - const SDNodeFlags Flags); + const SDNodeFlags Flags, + bool AllowCommute = false); LLVM_ABI SDNode *getNodeIfExists(unsigned Opcode, SDVTList VTList, - ArrayRef Ops); + ArrayRef Ops, + bool AllowCommute = false); /// Check if a node exists without modifying its flags. LLVM_ABI bool doesNodeExist(unsigned Opcode, SDVTList VTList, diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 6ea2e2708c162..538f318451a17 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -11848,25 +11848,37 @@ SDValue SelectionDAG::getTargetInsertSubreg(int SRIdx, const SDLoc &DL, EVT VT, /// getNodeIfExists - Get the specified node if it's already available, or /// else return NULL. SDNode *SelectionDAG::getNodeIfExists(unsigned Opcode, SDVTList VTList, - ArrayRef Ops) { + ArrayRef Ops, + bool AllowCommute) { SDNodeFlags Flags; if (Inserter) Flags = Inserter->getFlags(); - return getNodeIfExists(Opcode, VTList, Ops, Flags); + return getNodeIfExists(Opcode, VTList, Ops, Flags, AllowCommute); } SDNode *SelectionDAG::getNodeIfExists(unsigned Opcode, SDVTList VTList, ArrayRef Ops, - const SDNodeFlags Flags) { - if (VTList.VTs[VTList.NumVTs - 1] != MVT::Glue) { - FoldingSetNodeID ID; - AddNodeIDNode(ID, Opcode, VTList, Ops); - void *IP = nullptr; - if (SDNode *E = FindNodeOrInsertPos(ID, SDLoc(), IP)) { - E->intersectFlagsWith(Flags); - return E; + const SDNodeFlags Flags, + bool AllowCommute) { + auto Lookup = [&](ArrayRef LookupOps) -> SDNode * { + if (VTList.VTs[VTList.NumVTs - 1] != MVT::Glue) { + FoldingSetNodeID ID; + AddNodeIDNode(ID, Opcode, VTList, LookupOps); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, SDLoc(), IP)) { + E->intersectFlagsWith(Flags); + return E; + } } - } + return nullptr; + }; + + if (SDNode *Existing = Lookup(Ops)) + return Existing; + + if (AllowCommute && TLI->isCommutativeBinOp(Opcode)) + return Lookup({Ops[1], Ops[0]}); + return nullptr; } diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index dc8e7c84f5e2c..00ae4edca2fa4 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -26192,9 +26192,10 @@ static SDValue performFlagSettingCombine(SDNode *N, return DCI.CombineTo(N, Res, SDValue(N, 1)); } - // Combine identical generic nodes into this node, re-using the result. + // Combine equivalent generic nodes into this node, re-using the result. if (SDNode *Generic = DCI.DAG.getNodeIfExists( - GenericOpcode, DCI.DAG.getVTList(VT), {LHS, RHS})) + GenericOpcode, DCI.DAG.getVTList(VT), {LHS, RHS}, + /*AllowCommute=*/true)) DCI.CombineTo(Generic, SDValue(N, 0)); return SDValue(); diff --git a/llvm/test/CodeGen/AArch64/adds_cmn.ll b/llvm/test/CodeGen/AArch64/adds_cmn.ll index aa070b7886ba5..9b456a5419d61 100644 --- a/llvm/test/CodeGen/AArch64/adds_cmn.ll +++ b/llvm/test/CodeGen/AArch64/adds_cmn.ll @@ -22,10 +22,8 @@ entry: define { i32, i32 } @adds_cmn_c(i32 noundef %x, i32 noundef %y) { ; CHECK-LABEL: adds_cmn_c: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: cmn w0, w1 -; CHECK-NEXT: add w1, w1, w0 -; CHECK-NEXT: cset w8, lo -; CHECK-NEXT: mov w0, w8 +; CHECK-NEXT: adds w1, w0, w1 +; CHECK-NEXT: cset w0, lo ; CHECK-NEXT: ret entry: %0 = tail call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 %x, i32 %y) diff --git a/llvm/test/CodeGen/AArch64/sat-add.ll b/llvm/test/CodeGen/AArch64/sat-add.ll index ecd48d6b7c65b..149b4c4fd26c9 100644 --- a/llvm/test/CodeGen/AArch64/sat-add.ll +++ b/llvm/test/CodeGen/AArch64/sat-add.ll @@ -290,8 +290,7 @@ define i32 @unsigned_sat_variable_i32_using_cmp_sum(i32 %x, i32 %y) { define i32 @unsigned_sat_variable_i32_using_cmp_notval(i32 %x, i32 %y) { ; CHECK-LABEL: unsigned_sat_variable_i32_using_cmp_notval: ; CHECK: // %bb.0: -; CHECK-NEXT: add w8, w0, w1 -; CHECK-NEXT: cmn w1, w0 +; CHECK-NEXT: adds w8, w1, w0 ; CHECK-NEXT: csinv w0, w8, wzr, lo ; CHECK-NEXT: ret %noty = xor i32 %y, -1 @@ -331,8 +330,7 @@ define i64 @unsigned_sat_variable_i64_using_cmp_sum(i64 %x, i64 %y) { define i64 @unsigned_sat_variable_i64_using_cmp_notval(i64 %x, i64 %y) { ; CHECK-LABEL: unsigned_sat_variable_i64_using_cmp_notval: ; CHECK: // %bb.0: -; CHECK-NEXT: add x8, x0, x1 -; CHECK-NEXT: cmn x1, x0 +; CHECK-NEXT: adds x8, x1, x0 ; CHECK-NEXT: csinv x0, x8, xzr, lo ; CHECK-NEXT: ret %noty = xor i64 %y, -1