-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[RISCV] Add disjoint or patterns for vwadd[u].vv #136716
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
DAGCombiner::hoistLogicOpWithSameOpcodeHands will hoist (or disjoint (ext a), (ext b)) -> (ext (or disjoint a, b)) So this adds a pattern to match vwadd[u].vv in this case. We have to teach the combine to preserve the disjoint flag, and add a generic PatFrag for a disjoint or. This is meant to be a follow up to llvm#136677 which would allow us to remove the target hook added there.
|
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-risc-v Author: Luke Lau (lukel97) ChangesDAGCombiner::hoistLogicOpWithSameOpcodeHands will hoist (or disjoint (ext a), (ext b)) -> (ext (or disjoint a, b)) So this adds a pattern to match vwadd[u].vv in this case. We have to teach the combine to preserve the disjoint flag, and add a generic PatFrag for a disjoint or. This is meant to be a follow up to #136677 which would allow us to remove the target hook added there. Full diff: https://github.com/llvm/llvm-project/pull/136716.diff 4 Files Affected:
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 9c241b6c4df0f..20ef517426cf8 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -1113,6 +1113,10 @@ def not : PatFrag<(ops node:$in), (xor node:$in, -1)>;
def vnot : PatFrag<(ops node:$in), (xor node:$in, immAllOnesV)>;
def ineg : PatFrag<(ops node:$in), (sub 0, node:$in)>;
+def or_disjoint : PatFrag<(ops node:$x, node:$y), (or node:$x, node:$y), [{
+ return N->getFlags().hasDisjoint();
+}]>;
+
def zanyext : PatFrags<(ops node:$op),
[(zext node:$op),
(anyext node:$op)]>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b175e35385ec6..8cfcd2be8c61c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5982,7 +5982,9 @@ SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
LegalTypes && !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
return SDValue();
// logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
- SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
+ SDNodeFlags LogicFlags;
+ LogicFlags.setDisjoint(N->getFlags().hasDisjoint());
+ SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y, LogicFlags);
if (HandOpcode == ISD::SIGN_EXTEND_INREG)
return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
return DAG.getNode(HandOpcode, DL, VT, Logic);
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index b2c5261ae6c2d..71893e85bcb91 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -912,6 +912,25 @@ defm : VPatWidenBinarySDNode_VV_VX_WV_WX<add, sext_oneuse, "PseudoVWADD">;
defm : VPatWidenBinarySDNode_VV_VX_WV_WX<add, zext_oneuse, "PseudoVWADDU">;
defm : VPatWidenBinarySDNode_VV_VX_WV_WX<add, anyext_oneuse, "PseudoVWADDU">;
+// DAGCombiner::hoistLogicOpWithSameOpcodeHands may hoist disjoint ors
+// to (ext (or disjoint (a, b)))
+multiclass VPatWidenOrDisjoint_VV<SDNode extop, string instruction_name> {
+ foreach vtiToWti = AllWidenableIntVectors in {
+ defvar vti = vtiToWti.Vti;
+ defvar wti = vtiToWti.Wti;
+ let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+ GetVTypePredicates<wti>.Predicates) in {
+ def : Pat<(wti.Vector (extop (vti.Vector (or_disjoint vti.RegClass:$rs2, vti.RegClass:$rs1)))),
+ (!cast<Instruction>(instruction_name#"_VV_"#vti.LMul.MX)
+ (wti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs2,
+ vti.RegClass:$rs1, vti.AVL, vti.Log2SEW, TA_MA)>;
+ }
+ }
+}
+defm : VPatWidenOrDisjoint_VV<sext, "PseudoVWADD">;
+defm : VPatWidenOrDisjoint_VV<zext, "PseudoVWADDU">;
+defm : VPatWidenOrDisjoint_VV<anyext, "PseudoVWADDU">;
+
defm : VPatWidenBinarySDNode_VV_VX_WV_WX<sub, sext_oneuse, "PseudoVWSUB">;
defm : VPatWidenBinarySDNode_VV_VX_WV_WX<sub, zext_oneuse, "PseudoVWSUBU">;
defm : VPatWidenBinarySDNode_VV_VX_WV_WX<sub, anyext_oneuse, "PseudoVWSUBU">;
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index 3f5d42f89337b..149950484c477 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -1417,15 +1417,12 @@ define <vscale x 2 x i32> @vwaddu_vv_disjoint_or_add(<vscale x 2 x i8> %x.i8, <v
ret <vscale x 2 x i32> %add
}
-; TODO: We could select vwaddu.vv, but when both arms of the or are the same
-; DAGCombiner::hoistLogicOpWithSameOpcodeHands moves the zext above the or.
define <vscale x 2 x i32> @vwaddu_vv_disjoint_or(<vscale x 2 x i16> %x.i16, <vscale x 2 x i16> %y.i16) {
; CHECK-LABEL: vwaddu_vv_disjoint_or:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma
-; CHECK-NEXT: vor.vv v9, v8, v9
-; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT: vzext.vf2 v8, v9
+; CHECK-NEXT: vwaddu.vv v10, v8, v9
+; CHECK-NEXT: vmv1r.v v8, v10
; CHECK-NEXT: ret
%x.i32 = zext <vscale x 2 x i16> %x.i16 to <vscale x 2 x i32>
%y.i32 = zext <vscale x 2 x i16> %y.i16 to <vscale x 2 x i32>
@@ -1433,15 +1430,12 @@ define <vscale x 2 x i32> @vwaddu_vv_disjoint_or(<vscale x 2 x i16> %x.i16, <vsc
ret <vscale x 2 x i32> %or
}
-; TODO: We could select vwadd.vv, but when both arms of the or are the same
-; DAGCombiner::hoistLogicOpWithSameOpcodeHands moves the zext above the or.
define <vscale x 2 x i32> @vwadd_vv_disjoint_or(<vscale x 2 x i16> %x.i16, <vscale x 2 x i16> %y.i16) {
; CHECK-LABEL: vwadd_vv_disjoint_or:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma
-; CHECK-NEXT: vor.vv v9, v8, v9
-; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT: vsext.vf2 v8, v9
+; CHECK-NEXT: vwadd.vv v10, v8, v9
+; CHECK-NEXT: vmv1r.v v8, v10
; CHECK-NEXT: ret
%x.i32 = sext <vscale x 2 x i16> %x.i16 to <vscale x 2 x i32>
%y.i32 = sext <vscale x 2 x i16> %y.i16 to <vscale x 2 x i32>
|
| def vnot : PatFrag<(ops node:$in), (xor node:$in, immAllOnesV)>; | ||
| def ineg : PatFrag<(ops node:$in), (sub 0, node:$in)>; | ||
|
|
||
| def or_disjoint : PatFrag<(ops node:$x, node:$y), (or node:$x, node:$y), [{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please see the existing or_is_add and add_like in RISCVInstrInfo.td. We can move into generic, but let's do that separately.
| // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y) | ||
| SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y); | ||
| SDNodeFlags LogicFlags; | ||
| LogicFlags.setDisjoint(N->getFlags().hasDisjoint()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to make sure I understand this right. We've got (or disjoint (ext a), (ext b)) and we're turning that into (ext (or disjoint a, b)). Right?
For zext, this is fine. For sext, this is fine. For the in_reg variants, I am not sure if this is fine or not. What if a and b had the same high bit set and we're doing a sext_in_reg which clears that bit? I think that violates the disjoint and introduces UB which didn't previously exist right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agh good catch, alive2 does indeed confirm this: https://alive2.llvm.org/ce/z/W9F5b8. Will fix
| defvar wti = vtiToWti.Wti; | ||
| let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates, | ||
| GetVTypePredicates<wti>.Predicates) in { | ||
| def : Pat<(wti.Vector (extop (vti.Vector (or_disjoint vti.RegClass:$rs2, vti.RegClass:$rs1)))), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not include the splatpat one?
I think you can remove your custom multiclass, and do:
defm : VPatWidenBinarySDNode_VV_VX_WV_WX<add_like, sext_oneuse, "PseudoVWADD">;
defm : VPatWidenBinarySDNode_VV_VX_WV_WX<add_like, zext_oneuse, "PseudoVWADDU">;
defm : VPatWidenBinarySDNode_VV_VX_WV_WX<add_like, anyext_oneuse, "PseudoVWADDU">;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, I've added a .vx pattern too now.
Unless I'm missing something I don't think I can reuse the VPatWidenBinarySDNode_VV_VX_WV_WX multiclass because it's a different pattern. VPatWidenOrDisjoint_VV matches ext (or a b),VPatWidenBinarySDNode_VV_VX_WV_WX matches or (ext a), (ext b)
preames
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
| defm : VPatWidenBinarySDNode_VV_VX_WV_WX<add, zext_oneuse, "PseudoVWADDU">; | ||
| defm : VPatWidenBinarySDNode_VV_VX_WV_WX<add, anyext_oneuse, "PseudoVWADDU">; | ||
|
|
||
| // DAGCombiner::hoistLogicOpWithSameOpcodeHands may hoist disjoint ors |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have saying this - but remember this only covers scalable types, and that you probably need a follow up patch for the fixed length variant.
tclin914
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This handles combining fixed-length disjoint ors to vwadd[u].wv, as was done for scalable vectors in llvm#86929. vwadd[u].vv patterns need to be handled separately with a pattern in a separate patch due to the extends being sunk, see llvm#136716.
…u].v{v,x}
This is the fixed-length equivalent of llvm#136716.
The pattern we need to match is ({s,z}ext_vl (or_vl disjoint a, b)). This only allows or_vls with an undef passthru, which allows us to ignore its mask and vl and just take it from the {s,z}ext_vl.
A riscv_or_vl_is_add_oneuse PatFrag is added to mirror or_is_add in RISCVInstrInfo.td.
…u].v{v,x} (#136824)
This is the fixed-length equivalent of #136716.
The pattern we need to match is ({s,z}ext_vl (or_vl disjoint a, b)).
This only allows or_vls with an undef passthru, which allows us to
ignore its mask and vl and just take it from the {s,z}ext_vl.
A riscv_or_vl_is_add_oneuse PatFrag is added to mirror or_is_add in
RISCVInstrInfo.td.
DAGCombiner::hoistLogicOpWithSameOpcodeHands will hoist
(or disjoint (ext a), (ext b)) -> (ext (or disjoint a, b))
So this adds patterns to match vwadd[u].v{v,x} in this case.
We have to teach the combine to preserve the disjoint flag.
…vm#136820) This handles combining fixed-length disjoint ors to vwadd[u].wv, as was done for scalable vectors in llvm#86929. vwadd[u].vv patterns need to be handled separately with a pattern in a separate patch due to the extends being sunk, see llvm#136716.
…u].v{v,x} (llvm#136824)
This is the fixed-length equivalent of llvm#136716.
The pattern we need to match is ({s,z}ext_vl (or_vl disjoint a, b)).
This only allows or_vls with an undef passthru, which allows us to
ignore its mask and vl and just take it from the {s,z}ext_vl.
A riscv_or_vl_is_add_oneuse PatFrag is added to mirror or_is_add in
RISCVInstrInfo.td.
DAGCombiner::hoistLogicOpWithSameOpcodeHands will hoist
(or disjoint (ext a), (ext b)) -> (ext (or disjoint a, b))
So this adds a pattern to match vwadd[u].vv in this case.
We have to teach the combine to preserve the disjoint flag, and add a generic PatFrag for a disjoint or.
This is meant to be a follow up to #136677 which would allow us to remove the target hook added there.