Skip to content

Commit 45ad537

Browse files
committed
[WebAssembly] Add simd support for memcmp
1 parent 1381ad4 commit 45ad537

File tree

3 files changed

+68
-17
lines changed

3 files changed

+68
-17
lines changed

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3383,15 +3383,71 @@ static SDValue TryMatchTrue(SDNode *N, EVT VecVT, SelectionDAG &DAG) {
33833383
return DAG.getZExtOrTrunc(Ret, DL, N->getValueType(0));
33843384
}
33853385

3386+
/// Try to turn a scalar i128 equality comparison into one 128-bit SIMD
/// byte-wise comparison followed by a boolean reduction.
///
/// This is the shape produced when memcmp/bcmp is expanded with 16-byte loads:
///   (setcc i128 X, Y, eq)  -->  (alltrue (setcc v16i8 X', Y', eq)) != 0
///   (setcc i128 X, Y, ne)  -->  (anytrue (setcc v16i8 X', Y', ne)) != 0
/// where X' and Y' are X and Y bitcast to v16i8.
///
/// \param N         The ISD::SETCC node being combined.
/// \param DCI       Combiner state; supplies the SelectionDAG.
/// \param Subtarget Used to query whether SIMD128 is available.
/// \returns The replacement value, or an empty SDValue when the pattern does
///          not match or the transform would not be profitable/allowed.
static SDValue
combineVectorSizedSetCCEquality(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                                const WebAssemblySubtarget *Subtarget) {
  SDValue X = N->getOperand(0);
  SDValue Y = N->getOperand(1);
  EVT VT = N->getValueType(0);
  EVT OpVT = X.getValueType();
  ISD::CondCode CC = cast<CondCodeSDNode>(N->getOperand(2))->get();

  // Only (in)equality can be decided lane-by-lane. Ordered comparisons such
  // as SETULT on i128 depend on byte significance and must not be rewritten.
  if (CC != ISD::SETEQ && CC != ISD::SETNE)
    return SDValue();

  // We're looking for an oversized (exactly vector-register-sized) integer
  // comparison, and we need SIMD128 to have a v128 register to put it in.
  if (OpVT != MVT::i128 || !Subtarget->hasSIMD128())
    return SDValue();

  SelectionDAG &DAG = DCI.DAG;

  // noimplicitfloat forbids the compiler from introducing vector/FP code for
  // operations that are nominally scalar integer, which is exactly what this
  // combine would do.
  if (DAG.getMachineFunction().getFunction().hasFnAttribute(
          Attribute::NoImplicitFloat))
    return SDValue();

  // Don't perform this combine if constructing the vector will be expensive:
  // only constants and loads can be re-typed as v128 for free.
  auto IsVectorBitCastCheap = [](SDValue V) {
    V = peekThroughBitcasts(V);
    return isa<ConstantSDNode>(V) || V.getOpcode() == ISD::LOAD;
  };

  if (!IsVectorBitCastCheap(X) || !IsVectorBitCastCheap(Y))
    return SDValue();

  SDLoc DL(N);

  // OpVT is known to be i128 here, so the byte vector is always v16i8.
  SDValue VecX = DAG.getBitcast(MVT::v16i8, X);
  SDValue VecY = DAG.getBitcast(MVT::v16i8, Y);
  SDValue Cmp = DAG.getSetCC(DL, MVT::v16i8, VecX, VecY, CC);

  // X == Y  <=>  every byte lane compared equal      -> all_true(eq lanes).
  // X != Y  <=>  at least one byte lane compared ne  -> any_true(ne lanes).
  Intrinsic::ID ReduceID =
      CC == ISD::SETEQ ? Intrinsic::wasm_alltrue : Intrinsic::wasm_anytrue;
  SDValue Reduced =
      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32,
                  {DAG.getConstant(ReduceID, DL, MVT::i32), Cmp});

  // The reduction yields 0 or 1 as an i32; converting it to the original
  // setcc result type is a plain "is non-zero" test regardless of CC.
  return DAG.getSetCC(DL, VT, Reduced, DAG.getConstant(0, DL, MVT::i32),
                      ISD::SETNE);
}
3437+
33863438
static SDValue performSETCCCombine(SDNode *N,
3387-
TargetLowering::DAGCombinerInfo &DCI) {
3439+
TargetLowering::DAGCombinerInfo &DCI,
3440+
const WebAssemblySubtarget *Subtarget) {
33883441
if (!DCI.isBeforeLegalize())
33893442
return SDValue();
33903443

33913444
EVT VT = N->getValueType(0);
33923445
if (!VT.isScalarInteger())
33933446
return SDValue();
33943447

3448+
if (SDValue V = combineVectorSizedSetCCEquality(N, DCI, Subtarget))
3449+
return V;
3450+
33953451
SDValue LHS = N->getOperand(0);
33963452
if (LHS->getOpcode() != ISD::BITCAST)
33973453
return SDValue();
@@ -3532,7 +3588,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
35323588
case ISD::BITCAST:
35333589
return performBitcastCombine(N, DCI);
35343590
case ISD::SETCC:
3535-
return performSETCCCombine(N, DCI);
3591+
return performSETCCCombine(N, DCI, Subtarget);
35363592
case ISD::VECTOR_SHUFFLE:
35373593
return performVECTOR_SHUFFLECombine(N, DCI);
35383594
case ISD::SIGN_EXTEND:

llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,8 @@ WebAssemblyTTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
147147

148148
Options.AllowOverlappingLoads = true;
149149

150-
// TODO: Teach WebAssembly backend about load v128.
150+
if (ST->hasSIMD128())
151+
Options.LoadSizes.push_back(16);
151152

152153
Options.LoadSizes.append({8, 4, 2, 1});
153154
Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);

llvm/test/CodeGen/WebAssembly/memcmp-expand.ll

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2-
; RUN: llc < %s -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers | FileCheck %s
2+
; RUN: llc < %s -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s
33

44
target triple = "wasm32-unknown-unknown"
55

@@ -132,19 +132,13 @@ define i1 @memcmp_expand_16(ptr %a, ptr %b) {
132132
; CHECK-LABEL: memcmp_expand_16:
133133
; CHECK: .functype memcmp_expand_16 (i32, i32) -> (i32)
134134
; CHECK-NEXT: # %bb.0:
135-
; CHECK-NEXT: i64.load $push7=, 0($0):p2align=0
136-
; CHECK-NEXT: i64.load $push6=, 0($1):p2align=0
137-
; CHECK-NEXT: i64.xor $push8=, $pop7, $pop6
138-
; CHECK-NEXT: i32.const $push0=, 8
139-
; CHECK-NEXT: i32.add $push3=, $0, $pop0
140-
; CHECK-NEXT: i64.load $push4=, 0($pop3):p2align=0
141-
; CHECK-NEXT: i32.const $push11=, 8
142-
; CHECK-NEXT: i32.add $push1=, $1, $pop11
143-
; CHECK-NEXT: i64.load $push2=, 0($pop1):p2align=0
144-
; CHECK-NEXT: i64.xor $push5=, $pop4, $pop2
145-
; CHECK-NEXT: i64.or $push9=, $pop8, $pop5
146-
; CHECK-NEXT: i64.eqz $push10=, $pop9
147-
; CHECK-NEXT: return $pop10
135+
; CHECK-NEXT: v128.load $push1=, 0($0):p2align=0
136+
; CHECK-NEXT: v128.load $push0=, 0($1):p2align=0
137+
; CHECK-NEXT: i8x16.eq $push2=, $pop1, $pop0
138+
; CHECK-NEXT: i8x16.all_true $push3=, $pop2
139+
; CHECK-NEXT: i32.const $push4=, 1
140+
; CHECK-NEXT: i32.xor $push5=, $pop3, $pop4
141+
; CHECK-NEXT: return $pop5
148142
%cmp_16 = call i32 @memcmp(ptr %a, ptr %b, i32 16)
149143
%res = icmp eq i32 %cmp_16, 0
150144
ret i1 %res

0 commit comments

Comments
 (0)