-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[WebAssembly] Combine i128 to v16i8 for setcc & expand memcmp for 16 byte loads with simd128 #149461
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Edit: this is resolved.I'm trying out this PR but I think I encountered a blocker. The issue pops up with this reduced test case from the test case I'm not sure how to reconcile this? ... |
8432eea
to
45ad537
Compare
alright, with Luke's pointer from this PR #114517, I've tried a different approach: doesn't allow i128 to be legal everywhere but only on load via |
@llvm/pr-subscribers-backend-webassembly Author: Jasmine Tang (badumbatish) ChangesFixes #149230 Full diff: https://github.com/llvm/llvm-project/pull/149461.diff 3 Files Affected:
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index cd434f7a331e4..ee16f7bf9133d 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -3383,8 +3383,61 @@ static SDValue TryMatchTrue(SDNode *N, EVT VecVT, SelectionDAG &DAG) {
return DAG.getZExtOrTrunc(Ret, DL, N->getValueType(0));
}
+static SDValue
+combineVectorSizedSetCCEquality(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ const WebAssemblySubtarget *Subtarget) {
+
+ SDLoc DL(N);
+ SDValue X = N->getOperand(0);
+ SDValue Y = N->getOperand(1);
+ EVT VT = N->getValueType(0);
+ EVT OpVT = X.getValueType();
+
+ ISD::CondCode CC = cast<CondCodeSDNode>(N->getOperand(2))->get();
+ SelectionDAG &DAG = DCI.DAG;
+ // We're looking for an oversized integer equality comparison.
+ if (!OpVT.isScalarInteger() || !OpVT.isByteSized() || OpVT != MVT::i128 ||
+ !Subtarget->hasSIMD128())
+ return SDValue();
+
+ // Don't perform this combine if constructing the vector will be expensive.
+ auto IsVectorBitCastCheap = [](SDValue X) {
+ X = peekThroughBitcasts(X);
+ return isa<ConstantSDNode>(X) || X.getOpcode() == ISD::LOAD;
+ };
+
+ if (!IsVectorBitCastCheap(X) || !IsVectorBitCastCheap(Y))
+ return SDValue();
+
+ // TODO: Not sure what's the purpose of this? I'm keeping here since RISCV has
+ // it
+ if (DCI.DAG.getMachineFunction().getFunction().hasFnAttribute(
+ Attribute::NoImplicitFloat))
+ return SDValue();
+
+ unsigned OpSize = OpVT.getSizeInBits();
+ unsigned VecSize = OpSize / 8;
+
+ EVT VecVT = EVT::getVectorVT(*DCI.DAG.getContext(), MVT::i8, VecSize);
+ EVT CmpVT = EVT::getVectorVT(*DCI.DAG.getContext(), MVT::i8, VecSize);
+
+ SDValue VecX = DAG.getBitcast(VecVT, X);
+ SDValue VecY = DAG.getBitcast(VecVT, Y);
+
+ SDValue Cmp = DAG.getSetCC(DL, CmpVT, VecX, VecY, CC);
+
+ SDValue AllTrue = DAG.getZExtOrTrunc(
+ DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32,
+ {DAG.getConstant(Intrinsic::wasm_alltrue, DL, MVT::i32), Cmp}),
+ DL, MVT::i1);
+
+ return DAG.getSetCC(DL, VT, AllTrue, DAG.getConstant(0, DL, MVT::i1), CC);
+}
+
static SDValue performSETCCCombine(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI) {
+ TargetLowering::DAGCombinerInfo &DCI,
+ const WebAssemblySubtarget *Subtarget) {
if (!DCI.isBeforeLegalize())
return SDValue();
@@ -3392,6 +3445,9 @@ static SDValue performSETCCCombine(SDNode *N,
if (!VT.isScalarInteger())
return SDValue();
+ if (SDValue V = combineVectorSizedSetCCEquality(N, DCI, Subtarget))
+ return V;
+
SDValue LHS = N->getOperand(0);
if (LHS->getOpcode() != ISD::BITCAST)
return SDValue();
@@ -3532,7 +3588,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::BITCAST:
return performBitcastCombine(N, DCI);
case ISD::SETCC:
- return performSETCCCombine(N, DCI);
+ return performSETCCCombine(N, DCI, Subtarget);
case ISD::VECTOR_SHUFFLE:
return performVECTOR_SHUFFLECombine(N, DCI);
case ISD::SIGN_EXTEND:
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index 52e706514226b..08fb7586d215e 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -147,7 +147,8 @@ WebAssemblyTTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
Options.AllowOverlappingLoads = true;
- // TODO: Teach WebAssembly backend about load v128.
+ if (ST->hasSIMD128())
+ Options.LoadSizes.push_back(16);
Options.LoadSizes.append({8, 4, 2, 1});
Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
diff --git a/llvm/test/CodeGen/WebAssembly/memcmp-expand.ll b/llvm/test/CodeGen/WebAssembly/memcmp-expand.ll
index 8030438645f82..c6df6b50693fa 100644
--- a/llvm/test/CodeGen/WebAssembly/memcmp-expand.ll
+++ b/llvm/test/CodeGen/WebAssembly/memcmp-expand.ll
@@ -1,5 +1,5 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc < %s -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers | FileCheck %s
+; RUN: llc < %s -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s
target triple = "wasm32-unknown-unknown"
@@ -132,19 +132,13 @@ define i1 @memcmp_expand_16(ptr %a, ptr %b) {
; CHECK-LABEL: memcmp_expand_16:
; CHECK: .functype memcmp_expand_16 (i32, i32) -> (i32)
; CHECK-NEXT: # %bb.0:
-; CHECK-NEXT: i64.load $push7=, 0($0):p2align=0
-; CHECK-NEXT: i64.load $push6=, 0($1):p2align=0
-; CHECK-NEXT: i64.xor $push8=, $pop7, $pop6
-; CHECK-NEXT: i32.const $push0=, 8
-; CHECK-NEXT: i32.add $push3=, $0, $pop0
-; CHECK-NEXT: i64.load $push4=, 0($pop3):p2align=0
-; CHECK-NEXT: i32.const $push11=, 8
-; CHECK-NEXT: i32.add $push1=, $1, $pop11
-; CHECK-NEXT: i64.load $push2=, 0($pop1):p2align=0
-; CHECK-NEXT: i64.xor $push5=, $pop4, $pop2
-; CHECK-NEXT: i64.or $push9=, $pop8, $pop5
-; CHECK-NEXT: i64.eqz $push10=, $pop9
-; CHECK-NEXT: return $pop10
+; CHECK-NEXT: v128.load $push1=, 0($0):p2align=0
+; CHECK-NEXT: v128.load $push0=, 0($1):p2align=0
+; CHECK-NEXT: i8x16.eq $push2=, $pop1, $pop0
+; CHECK-NEXT: i8x16.all_true $push3=, $pop2
+; CHECK-NEXT: i32.const $push4=, 1
+; CHECK-NEXT: i32.xor $push5=, $pop3, $pop4
+; CHECK-NEXT: return $pop5
%cmp_16 = call i32 @memcmp(ptr %a, ptr %b, i32 16)
%res = icmp eq i32 %cmp_16, 0
ret i1 %res
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The PR title should probably be something like "Expand memcmp for 16 byte loads with simd128", since this PR also enables it in WebAsssemblyTargetTransformInfo.cpp
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Just for the PR title, I would say Combine i128 to v16i8 for setcc
since it's technically a combine, not legalization.
And make sure to flesh out the PR description with a few sentences about how ExpandMemcmp can expand larger 128 bit loads, but they're emitted as i128s and we need to combine them into v16i8 types for efficient lowering.
It looks like this change has caused a test failure on Emscripten's test suite: the first memcmp in the neon test. |
Fixes #149230
Previously, even with simd enabled via
-mattr=+simd128
, the compiler cannot utilize v128 to optimize loads and setcc of i128, instead legalizing it to consecutive i64s.This PR then adds support for setcc of i128 by converting them to v16i8's anytrue and alltrue; consequently, this benefits memcmp of 16 bytes or more (when simd128 is present).
The check for enabling this optimization is if the comparison operand is either a load or an integer in i128, with the comparison code being either
EQ | NE
, withoutNoImplicitFloat
function flag.Inspiration taken from RISCV's isel lowering.