[WebAssembly] Add extra pattern for dot #151775
@llvm/pr-subscribers-backend-webassembly
Author: Jasmine Tang (badumbatish)
Changes
Fixes #50154
Full diff: https://github.com/llvm/llvm-project/pull/151775.diff
2 Files Affected:
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index cd434f7a331e4..648e3b6b2b440 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -192,6 +192,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// Combine wide-vector muls, with extend inputs, to extmul_half.
setTargetDAGCombine(ISD::MUL);
+ // Combine add with vector shuffle of muls to dots
+ setTargetDAGCombine(ISD::ADD);
+
// Combine vector mask reductions into alltrue/anytrue
setTargetDAGCombine(ISD::SETCC);
@@ -3436,6 +3439,52 @@ static SDValue performSETCCCombine(SDNode *N,
return SDValue();
}
+static SDValue performAddCombine(SDNode *N, SelectionDAG &DAG) {
+ assert(N->getOpcode() == ISD::ADD);
+ EVT VT = N->getValueType(0);
+ SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
+
+ if (VT != MVT::v4i32)
+ return SDValue();
+
+ auto IsShuffleWithMask = [](SDValue V, ArrayRef<int> ShuffleValue) {
+ if (V.getOpcode() != ISD::VECTOR_SHUFFLE)
+ return SDValue();
+ if (cast<ShuffleVectorSDNode>(V)->getMask() != ShuffleValue)
+ return SDValue();
+ return V;
+ };
+ auto ShuffleA = IsShuffleWithMask(N0, {0, 2, 4, 6});
+ auto ShuffleB = IsShuffleWithMask(N1, {1, 3, 5, 7});
+ // Both add operands must be shuffles with the expected even/odd masks.
+ if (!ShuffleA || !ShuffleB)
+ return SDValue();
+
+ if (ShuffleA.getOperand(0) != ShuffleB.getOperand(0) ||
+ ShuffleA.getOperand(1) != ShuffleB.getOperand(1))
+ return SDValue();
+
+ auto IsMulExtend =
+ [](SDValue V, WebAssemblyISD::NodeType I) -> std::pair<SDValue, SDValue> {
+ if (V.getOpcode() != ISD::MUL)
+ return {};
+
+ auto V0 = V.getOperand(0), V1 = V.getOperand(1);
+ if (V0.getOpcode() != I || V1.getOpcode() != I)
+ return {};
+ return {V0.getOperand(0), V1.getOperand(0)};
+ };
+
+ auto [LowA, LowB] =
+ IsMulExtend(ShuffleA.getOperand(0), WebAssemblyISD::EXTEND_LOW_S);
+ auto [HighA, HighB] =
+ IsMulExtend(ShuffleA.getOperand(1), WebAssemblyISD::EXTEND_HIGH_S);
+
+ if (!LowA || !LowB || !HighA || !HighB || LowA != HighA || LowB != HighB)
+ return SDValue();
+
+ return DAG.getNode(WebAssemblyISD::DOT, SDLoc(N), MVT::v4i32, LowA, LowB);
+}
static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
assert(N->getOpcode() == ISD::MUL);
EVT VT = N->getValueType(0);
@@ -3558,5 +3607,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
}
case ISD::MUL:
return performMulCombine(N, DCI.DAG);
+ case ISD::ADD:
+ return performAddCombine(N, DCI.DAG);
}
}
diff --git a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
new file mode 100644
index 0000000000000..7ac49794491a1
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
@@ -0,0 +1,21 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mattr=+simd128 | FileCheck %s
+
+target triple = "wasm32-unknown-unknown"
+define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dot:
+; CHECK: .functype dot (v128, v128) -> (v128)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i32x4.dot_i16x8_s
+; CHECK-NEXT: # fallthrough-return
+ %sext1 = sext <8 x i16> %a to <8 x i32>
+ %sext2 = sext <8 x i16> %b to <8 x i32>
+ %mul = mul nsw <8 x i32> %sext1, %sext2
+ %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+ %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+ %res = add <4 x i32> %shuffle1, %shuffle2
+ ret <4 x i32> %res
+}
+
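To make the equivalence the test relies on concrete, here is a minimal scalar sketch in C++. It is not part of the PR: the names `V8i16`, `V4i32`, `wrappingAdd`, `dotReference`, `dotViaShuffles` and `checkEquivalent` are hypothetical, and the model assumes that `i32x4.dot_i16x8_s` multiplies adjacent sign-extended 16-bit lanes and adds each pair with wrapping 32-bit arithmetic. The second function mirrors the IR in the test: sign-extend, multiply, pick the even and odd lanes with the `{0, 2, 4, 6}` and `{1, 3, 5, 7}` masks, then add.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using V8i16 = std::array<int16_t, 8>;
using V4i32 = std::array<int32_t, 4>;

// Two's-complement wrapping add. The sum is computed in unsigned arithmetic;
// the conversion back to int32_t wraps on mainstream compilers (and is
// guaranteed to since C++20).
int32_t wrappingAdd(int32_t X, int32_t Y) {
  return static_cast<int32_t>(static_cast<uint32_t>(X) +
                              static_cast<uint32_t>(Y));
}

// Assumed scalar model of i32x4.dot_i16x8_s: output lane I is
// A[2I]*B[2I] + A[2I+1]*B[2I+1], computed on sign-extended inputs.
V4i32 dotReference(const V8i16 &A, const V8i16 &B) {
  V4i32 R{};
  for (int I = 0; I < 4; ++I) {
    int32_t Even = int32_t(A[2 * I]) * int32_t(B[2 * I]);
    int32_t Odd = int32_t(A[2 * I + 1]) * int32_t(B[2 * I + 1]);
    R[I] = wrappingAdd(Even, Odd);
  }
  return R;
}

// The IR pattern from the test: sext, mul, two shuffles, add.
V4i32 dotViaShuffles(const V8i16 &A, const V8i16 &B) {
  std::array<int32_t, 8> Mul{};
  for (int I = 0; I < 8; ++I)
    Mul[I] = int32_t(A[I]) * int32_t(B[I]); // an i16 x i16 product fits in i32
  const int EvenMask[4] = {0, 2, 4, 6};
  const int OddMask[4] = {1, 3, 5, 7};
  V4i32 R{};
  for (int I = 0; I < 4; ++I)
    R[I] = wrappingAdd(Mul[EvenMask[I]], Mul[OddMask[I]]);
  return R;
}

// Asserts that both formulations agree on one pair of inputs.
void checkEquivalent(const V8i16 &A, const V8i16 &B) {
  assert(dotReference(A, B) == dotViaShuffles(A, B));
}
```

Calling `checkEquivalent` on a few inputs, including the all `-32768` case where the sum wraps, is a quick way to convince yourself the fold preserves semantics under the assumptions above.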
Out of curiosity, is it possible to do this as a tablegen pattern? Not that it's necessarily the right thing to do, just wondering if it's easy to do or not!
I actually didn't think that tablegen could handle identical arguments in a pattern. I just tried it out and I think I might be able to make it work.
I'm assuming the 'illegal' types will make this more difficult in tablegen? This approach looks good to me.
The PR title needs to be updated to reflect that this isn't a combine but adds a fold pattern.
EDIT: Typo, sorry!
Good point. It looks like the old combine only operated on MVT::v8i16s feeding into MVT::v4i32 adds, but I presume we also want to handle i8s that are sext'd to i32, e.g.:
define <4 x i32> @f(<8 x i8> %a, <8 x i8> %b) {
%sext1 = sext <8 x i8> %a to <8 x i32>
%sext2 = sext <8 x i8> %b to <8 x i32>
%mul = mul nsw <8 x i32> %sext1, %sext2
%shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
%shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
%res = add <4 x i32> %shuffle1, %shuffle2
ret <4 x i32> %res
}
And from a quick check it looks like the multiply is hoisted into the narrower type:
We could add a second tablegen pattern for
Or we could try and do this again as a dagcombine, but it looks like the multiply is already hoisted by the time the types are legalized, so we would have to handle the two different patterns anyway:
Godbolt for this: https://godbolt.org/z/Y1rrnW5h3. The IR at the isel phase looks a bit different. Would you want me to try that in this PR as well?
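For context on why the i8 case looks different at isel, here is a small illustrative C++ sketch. It is not taken from the PR, and it assumes that "hoisted into the narrower type" means the multiply is performed at i16 and only the product is extended to i32. The names `mulWide`, `mulNarrow` and `narrowMulIsExact` are hypothetical. The point is that the product of two sign-extended i8 values always fits in i16 (its magnitude is at most 128 * 128 = 16384), so multiplying in the narrower type loses nothing.

```cpp
#include <cassert>
#include <cstdint>

// Wide form: sign-extend both i8 inputs to i32, then multiply.
int32_t mulWide(int8_t A, int8_t B) { return int32_t(A) * int32_t(B); }

// Narrow form: multiply at i16 width, then sign-extend the product to i32.
int32_t mulNarrow(int8_t A, int8_t B) {
  int Product = int16_t(A) * int16_t(B); // operands promote to int in C++
  // The product of two i8 values lies in [-16256, 16384], inside i16 range.
  assert(Product >= INT16_MIN && Product <= INT16_MAX);
  return int32_t(static_cast<int16_t>(Product));
}

// Exhaustively compares the two forms over all 65536 input pairs.
bool narrowMulIsExact() {
  for (int A = -128; A <= 127; ++A)
    for (int B = -128; B <= 127; ++B)
      if (mulWide(int8_t(A), int8_t(B)) != mulNarrow(int8_t(A), int8_t(B)))
        return false;
  return true;
}
```

Because both forms agree on every input, a pattern that only looks for a multiply of extended operands would presumably not see the i8 case in that shape, which is consistent with the comment above that two different patterns would need handling.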
LGTM with the PR title updated
I think we can handle the zext from i8 case in a separate PR if the pattern ends up being somewhat different anyway, but I'll defer to @sparker-arm on this! (I don't have a strong opinion on whether we do this as a combine or a tablegen pattern.)
Hi @sparker-arm, I'm looking to add a similar pattern for relaxed SIMD dot. Can you confirm whether this approach is good to merge and to reuse in the subsequent PR?
Sorry, I didn't know you were waiting on me, please feel free to shout sooner next time!
No worries, I'll keep that in mind next time, ty for the reviews!
Fixes #50154