Skip to content

Commit 9c6bb18

Browse files
authored
[WebAssembly] Constant fold wasm.dot (#149619)
Constant-fold calls to `wasm.dot` whose operands are constant vectors or splats. A test case is added in `llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll`. Related to #55933.
1 parent 34aed0e commit 9c6bb18

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

llvm/lib/Analysis/ConstantFolding.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,6 +1659,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
16591659
case Intrinsic::aarch64_sve_convert_from_svbool:
16601660
case Intrinsic::wasm_alltrue:
16611661
case Intrinsic::wasm_anytrue:
1662+
case Intrinsic::wasm_dot:
16621663
// WebAssembly float semantics are always known
16631664
case Intrinsic::wasm_trunc_signed:
16641665
case Intrinsic::wasm_trunc_unsigned:
@@ -3989,6 +3990,30 @@ static Constant *ConstantFoldFixedVectorCall(
39893990
}
39903991
return ConstantVector::get(Result);
39913992
}
3993+
case Intrinsic::wasm_dot: {
3994+
unsigned NumElements =
3995+
cast<FixedVectorType>(Operands[0]->getType())->getNumElements();
3996+
3997+
assert(NumElements == 8 && Result.size() == 4 &&
3998+
"wasm dot takes i16x8 and produces i32x4");
3999+
assert(Ty->isIntegerTy());
4000+
int32_t MulVector[8];
4001+
4002+
for (unsigned I = 0; I < NumElements; ++I) {
4003+
ConstantInt *Elt0 =
4004+
cast<ConstantInt>(Operands[0]->getAggregateElement(I));
4005+
ConstantInt *Elt1 =
4006+
cast<ConstantInt>(Operands[1]->getAggregateElement(I));
4007+
4008+
MulVector[I] = Elt0->getSExtValue() * Elt1->getSExtValue();
4009+
}
4010+
for (unsigned I = 0; I < Result.size(); I++) {
4011+
int32_t IAdd = MulVector[I * 2] + MulVector[I * 2 + 1];
4012+
Result[I] = ConstantInt::get(Ty, IAdd);
4013+
}
4014+
4015+
return ConstantVector::get(Result);
4016+
}
39924017
default:
39934018
break;
39944019
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
3+
; RUN: opt -passes=instsimplify -S < %s | FileCheck %s
4+
5+
; Test that calls to the wasm dot intrinsic are constant folded
6+
7+
target triple = "wasm32-unknown-unknown"
8+
9+
10+
define <4 x i32> @dot_zero() {
11+
; CHECK-LABEL: define <4 x i32> @dot_zero() {
12+
; CHECK-NEXT: ret <4 x i32> zeroinitializer
13+
;
14+
%res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer)
15+
ret <4 x i32> %res
16+
}
17+
18+
; a = 1 2 3 4 5 6 7 8
19+
; b = 1 2 3 4 5 6 7 8
20+
; k1|k2 = a * b = 1 4 9 16 25 36 49 64
21+
; k1 + k2 = (1+4) | (9 + 16) | (25 + 36) | (49 + 64)
22+
; result = 5 | 25 | 61 | 113
23+
define <4 x i32> @dot_nonzero() {
24+
; CHECK-LABEL: define <4 x i32> @dot_nonzero() {
25+
; CHECK-NEXT: ret <4 x i32> <i32 5, i32 25, i32 61, i32 113>
26+
;
27+
%res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>, <8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>)
28+
ret <4 x i32> %res
29+
}
30+
31+
define <4 x i32> @dot_doubly_negative() {
32+
; CHECK-LABEL: define <4 x i32> @dot_doubly_negative() {
33+
; CHECK-NEXT: ret <4 x i32> splat (i32 2)
34+
;
35+
%res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>, <8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>)
36+
ret <4 x i32> %res
37+
}
38+
39+
; Tests that the sum of products of i16 max signed values (2 * 32767^2 = 2147352578) fits in i32
40+
define <4 x i32> @dot_follow_modulo_spec_1() {
41+
; CHECK-LABEL: define <4 x i32> @dot_follow_modulo_spec_1() {
42+
; CHECK-NEXT: ret <4 x i32> <i32 2147352578, i32 0, i32 0, i32 0>
43+
;
44+
%res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 32767, i16 32767, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>, <8 x i16> <i16 32767, i16 32767, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>)
45+
ret <4 x i32> %res
46+
}
47+
48+
; Tests that the sum of products of i16 min signed values (2 * (-32768)^2 = 2^31) does not fit in i32 and wraps modulo 2^32 to -2147483648
49+
define <4 x i32> @dot_follow_modulo_spec_2() {
50+
; CHECK-LABEL: define <4 x i32> @dot_follow_modulo_spec_2() {
51+
; CHECK-NEXT: ret <4 x i32> <i32 -2147483648, i32 0, i32 0, i32 0>
52+
;
53+
%res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 -32768, i16 -32768, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>, <8 x i16> <i16 -32768, i16 -32768, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>)
54+
ret <4 x i32> %res
55+
}
56+

0 commit comments

Comments
 (0)