[Matrix] Lower vector reductions using shape info #142055
Conversation
@llvm/pr-subscribers-llvm-transforms

Author: Jon Roelofs (jroelofs)

Changes

When possible, this avoids a bunch of shuffles in & out of the flattened layout.

Patch is 21.99 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/142055.diff

3 Files Affected:
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..4f997f2133527 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -30,16 +30,19 @@
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
+#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -1101,6 +1104,7 @@ class LowerMatrixIntrinsics {
if (!PoisonedInsts.empty()) {
// If we didn't remove all poisoned instructions, it's a hard error.
dbgs() << "Poisoned but present instructions:\n";
+ Func.dump();
for (auto *I : PoisonedInsts)
dbgs() << *I << "\n";
llvm_unreachable("Poisoned but instruction not removed");
@@ -1337,6 +1341,155 @@ class LowerMatrixIntrinsics {
return Builder.CreateAdd(Sum, Mul);
}
+ bool VisitExtractElt(ExtractElementInst *Inst, uint64_t Index) {
+ Value *Op0 = Inst->getOperand(0);
+ auto *VTy = cast<VectorType>(Op0->getType());
+
+ if (VTy->getElementCount().getKnownMinValue() < Index) {
+ Inst->replaceAllUsesWith(PoisonValue::get(VTy->getElementType()));
+ Inst->eraseFromParent();
+ return true;
+ }
+
+ auto *I = Inst2ColumnMatrix.find(Op0);
+ if (I == Inst2ColumnMatrix.end())
+ return false;
+
+ const MatrixTy &M = I->second;
+
+ IRBuilder<> Builder(Inst);
+ Inst->setOperand(0, M.getVector(Index / M.getStride()));
+ Inst->setOperand(1, Builder.getInt32(Index % M.getStride()));
+ return true;
+ }
+
+ bool VisitReduce(IntrinsicInst *Inst) {
+ FastMathFlags FMF = getFastMathFlags(Inst);
+
+ if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
+ return false;
+
+ Value *Start = nullptr;
+ Value *Op = nullptr;
+ switch (Inst->getIntrinsicID()) {
+ case Intrinsic::vector_reduce_fadd:
+ case Intrinsic::vector_reduce_fmul:
+ Start = Inst->getOperand(0);
+ Op = Inst->getOperand(1);
+ break;
+ case Intrinsic::vector_reduce_fmax:
+ case Intrinsic::vector_reduce_fmaximum:
+ case Intrinsic::vector_reduce_fmin:
+ case Intrinsic::vector_reduce_fminimum:
+ case Intrinsic::vector_reduce_add:
+ case Intrinsic::vector_reduce_and:
+ case Intrinsic::vector_reduce_mul:
+ case Intrinsic::vector_reduce_or:
+ case Intrinsic::vector_reduce_xor:
+ Op = Inst->getOperand(0);
+ break;
+ default:
+ llvm_unreachable("unexpected intrinsic");
+ }
+
+ switch (Inst->getIntrinsicID()) {
+ case Intrinsic::vector_reduce_fadd: {
+ if (!match(Start, m_AnyZeroFP()))
+ return false;
+ } break;
+ case Intrinsic::vector_reduce_fmul: {
+ if (!match(Start, m_FPOne()))
+ return false;
+ } break;
+ default:
+ break;
+ }
+
+ auto *I = Inst2ColumnMatrix.find(Op);
+ if (I == Inst2ColumnMatrix.end())
+ return false;
+
+ IRBuilder<> Builder(Inst);
+
+ const MatrixTy &M = I->second;
+
+ auto CreateVReduce = [&](Value *LHS, Value *RHS) {
+ switch (Inst->getIntrinsicID()) {
+ case Intrinsic::vector_reduce_add:
+ return Builder.CreateAdd(LHS, RHS);
+ case Intrinsic::vector_reduce_and:
+ return Builder.CreateAnd(LHS, RHS);
+ case Intrinsic::vector_reduce_fadd:
+ return Builder.CreateFAdd(LHS, RHS);
+ case Intrinsic::vector_reduce_fmax:
+ return Builder.CreateMaximum(LHS, RHS); // FIXME: is this correct re: nans?
+ case Intrinsic::vector_reduce_fmaximum:
+ return Builder.CreateMaximumNum(LHS, RHS); // FIXME: is this correct re: nans?
+ case Intrinsic::vector_reduce_fmin:
+ return Builder.CreateMinimum(LHS, RHS); // FIXME: is this correct re: nans?
+ case Intrinsic::vector_reduce_fminimum:
+ return Builder.CreateMinimumNum(LHS, RHS); // FIXME: is this correct re: nans?
+ case Intrinsic::vector_reduce_fmul:
+ return Builder.CreateFMul(LHS, RHS);
+ case Intrinsic::vector_reduce_mul:
+ return Builder.CreateMul(LHS, RHS);
+ case Intrinsic::vector_reduce_or:
+ return Builder.CreateOr(LHS, RHS);
+ case Intrinsic::vector_reduce_xor:
+ return Builder.CreateXor(LHS, RHS);
+ default:
+ llvm_unreachable("unexpected intrinsic");
+ }
+ };
+
+ Value *ResultV;
+ if (Inst->getIntrinsicID() == Intrinsic::vector_reduce_fadd ||
+ Inst->getIntrinsicID() == Intrinsic::vector_reduce_fmul) {
+ ResultV = Builder.CreateVectorSplat(ElementCount::getFixed(M.getStride()), Start);
+ for (unsigned VI = 0, VE = M.getNumVectors(); VI != VE; VI++)
+ ResultV = CreateVReduce(ResultV, M.getVector(VI));
+ } else {
+ ResultV = M.getVector(0);
+ for (unsigned VI = 1, VE = M.getNumVectors(); VI != VE; VI++)
+ ResultV = CreateVReduce(ResultV, M.getVector(VI));
+ }
+
+ auto CreateHReduce = [&](Value *V) {
+ switch (Inst->getIntrinsicID()) {
+ case Intrinsic::vector_reduce_add:
+ return Builder.CreateAddReduce(V);
+ case Intrinsic::vector_reduce_and:
+ return Builder.CreateAndReduce(V);
+ case Intrinsic::vector_reduce_fadd:
+ return Builder.CreateFAddReduce(Start, V);
+ case Intrinsic::vector_reduce_fmax:
+ return Builder.CreateFPMaxReduce(V);
+ case Intrinsic::vector_reduce_fmaximum:
+ return Builder.CreateFPMaximumReduce(V);
+ case Intrinsic::vector_reduce_fmin:
+ return Builder.CreateFPMinReduce(V);
+ case Intrinsic::vector_reduce_fminimum:
+ return Builder.CreateFPMinimumReduce(V);
+ case Intrinsic::vector_reduce_fmul:
+ return Builder.CreateFMulReduce(Start, V);
+ case Intrinsic::vector_reduce_mul:
+ return Builder.CreateMulReduce(V);
+ case Intrinsic::vector_reduce_or:
+ return Builder.CreateOrReduce(V);
+ case Intrinsic::vector_reduce_xor:
+ return Builder.CreateXorReduce(V);
+ default:
+ llvm_unreachable("unexpected intrinsic");
+ }
+ };
+
+ Value *Result = CreateHReduce(ResultV);
+ Inst->replaceAllUsesWith(Result);
+ Result->takeName(Inst);
+ Inst->eraseFromParent();
+ return true;
+ }
+
/// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
/// users with shape information, there's nothing to do: they will use the
/// cached value when they are lowered. For other users, \p Matrix is
@@ -1351,11 +1504,39 @@ class LowerMatrixIntrinsics {
ToRemove.push_back(Inst);
Value *Flattened = nullptr;
for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
- if (!ShapeMap.contains(U.getUser())) {
- if (!Flattened)
- Flattened = Matrix.embedInVector(Builder);
- U.set(Flattened);
+ if (ShapeMap.contains(U.getUser()))
+ continue;
+
+ Value *Op1;
+ uint64_t Index;
+ if (match(U.getUser(), m_ExtractElt(m_Value(Op1), m_ConstantInt(Index))))
+ if (VisitExtractElt(cast<ExtractElementInst>(U.getUser()), Index))
+ continue;
+
+ if (auto *Intr = dyn_cast<IntrinsicInst>(U.getUser())) {
+ switch (Intr->getIntrinsicID()) {
+ case Intrinsic::vector_reduce_add:
+ case Intrinsic::vector_reduce_and:
+ case Intrinsic::vector_reduce_fadd:
+ case Intrinsic::vector_reduce_fmax:
+ case Intrinsic::vector_reduce_fmaximum:
+ case Intrinsic::vector_reduce_fmin:
+ case Intrinsic::vector_reduce_fminimum:
+ case Intrinsic::vector_reduce_fmul:
+ case Intrinsic::vector_reduce_mul:
+ case Intrinsic::vector_reduce_or:
+ case Intrinsic::vector_reduce_xor:
+ if (VisitReduce(Intr))
+ continue;
+ break;
+ default:
+ break;
+ }
}
+
+ if (!Flattened)
+ Flattened = Matrix.embedInVector(Builder);
+ U.set(Flattened);
}
}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
new file mode 100644
index 0000000000000..db5444ca036ae
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
@@ -0,0 +1,41 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define float @extract_static(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_static(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <2 x float> [[COL_LOAD1]], i32 1
+; CHECK-NEXT: ret float [[EXTRACT]]
+;
+ %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2)
+ %extract = extractelement <4 x float> %inv, i32 3
+ ret float %extract
+}
+
+define float @extract_static_outofbounds(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_static_outofbounds(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: ret float poison
+;
+ %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2)
+ %extract = extractelement <4 x float> %inv, i32 5
+ ret float %extract
+}
+
+define float @extract_dynamic(ptr %in, i32 %idx, ptr %out) {
+; CHECK-LABEL: @extract_dynamic(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <4 x float> [[TMP1]], i32 [[IDX:%.*]]
+; CHECK-NEXT: ret float [[EXTRACT]]
+;
+ %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2)
+ %extract = extractelement <4 x float> %inv, i32 %idx
+ ret float %extract
+}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
new file mode 100644
index 0000000000000..41f65e01fec79
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
@@ -0,0 +1,200 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define i32 @reduce_add(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_add(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = add <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT: [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT: ret i32 [[REDUCE]]
+;
+ %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call i32 @llvm.vector.reduce.add(<8 x i32> %inv)
+ ret i32 %reduce
+}
+
+define i32 @reduce_and(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_and(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT: [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT: ret i32 [[REDUCE]]
+;
+ %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call i32 @llvm.vector.reduce.and(<8 x i32> %inv)
+ ret i32 %reduce
+}
+
+define i32 @reduce_or(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_or(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = or <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT: [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT: ret i32 [[REDUCE]]
+;
+ %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call i32 @llvm.vector.reduce.or(<8 x i32> %inv)
+ ret i32 %reduce
+}
+
+define i32 @reduce_mul(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_mul(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = mul <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT: [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.mul.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT: ret i32 [[REDUCE]]
+;
+ %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call i32 @llvm.vector.reduce.mul(<8 x i32> %inv)
+ ret i32 %reduce
+}
+
+define i32 @reduce_xor(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_xor(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = xor <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT: [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.xor.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT: ret i32 [[REDUCE]]
+;
+ %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call i32 @llvm.vector.reduce.xor(<8 x i32> %inv)
+ ret i32 %reduce
+}
+
+define float @reduce_fadd(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[REDUCE:%.*]] = call float @llvm.vector.reduce.fadd.v8f32(float 0.000000e+00, <8 x float> [[TMP1]])
+; CHECK-NEXT: ret float [[REDUCE]]
+;
+ %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call float @llvm.vector.reduce.fadd(float 0., <8 x float> %inv)
+ ret float %reduce
+}
+
+define float @reduce_fadd_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd_reassoc(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = fadd <4 x float> zeroinitializer, [[COL_LOAD]]
+; CHECK-NEXT: [[TMP2:%.*]] = fadd <4 x float> [[TMP1]], [[COL_LOAD1]]
+; CHECK-NEXT: [[REDUCE:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT: ret float [[REDUCE]]
+;
+ %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call reassoc float @llvm.vector.reduce.fadd(float 0., <8 x float> %inv)
+ ret float %reduce
+}
+
+define float @reduce_fadd_weirdstart(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd_weirdstart(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fadd.v8f32(float 1.000000e+00, <8 x float> [[TMP1]])
+; CHECK-NEXT: ret float [[REDUCE]]
+;
+ %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call reassoc float @llvm.vector.reduce.fadd(float 1., <8 x float> %inv)
+ ret float %reduce
+}
+
+define float @reduce_fmul_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmul_reassoc(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = fmul <4 x float> splat (float 1.000000e+00), [[COL_LOAD]]
+; CHECK-NEXT: [[TMP2:%.*]] = fmul <4 x float> [[TMP1]], [[COL_LOAD1]]
+; CHECK-NEXT: [[REDUCE:%.*]] = call float @llvm.vector.reduce.fmul.v4f32(float 1.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT: ret float [[REDUCE]]
+;
+ %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call reassoc float @llvm.vector.reduce.fmul(float 1., <8 x float> %inv)
+ ret float %reduce
+}
+
+define float @reduce_fmul_weirdstart(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmul_weirdstart(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fmul.v8f32(float 0.000000e+00, <8 x float> [[TMP1]])
+; CHECK-NEXT: ret float [[REDUCE]]
+;
+ %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call reassoc float @llvm.vector.reduce.fmul(float 0., <8 x float> %inv)
+ ret float %reduce
+}
+
+define float @reduce_fmax_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmax_reassoc(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x float> @llvm.maximum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT: [[REDUCE:%.*]] = call float @llvm.vector.reduce.fmax.v4f32(<4 x float> [[TMP1]])
+; CHECK-NEXT: ret float [[REDUCE]]
+;
+ %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call reassoc float @llvm.vector.reduce.fmax(<8 x float> %inv)
+ ret float %reduce
+}
+
+define float @reduce_fmaximum_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmaximum_reassoc(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x float> @llvm.maximumnu...
[truncated]
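As an aside on the VisitExtractElt hunk above: the patch maps a flat element index into the column-major layout via Index / getStride() (which column vector) and Index % getStride() (which lane). Below is a minimal standalone sketch of that mapping; the helper name and the main driver are illustrative only, not code from the patch.

// Standalone illustration only -- not code from the patch. It mirrors the
// index arithmetic in VisitExtractElt: a column-major Rows x Cols matrix is
// kept as Cols column vectors of length Rows, getStride() is Rows, so a flat
// element index picks column Index / Rows and lane Index % Rows.
#include <cassert>
#include <cstdint>
#include <utility>

// Returns {column vector, lane within that vector} for a flat index into a
// column-major matrix whose stride (rows per column) is Rows.
std::pair<uint64_t, uint64_t> mapFlatIndex(uint64_t Index, uint64_t Rows) {
  assert(Rows != 0 && "stride must be non-zero");
  return {Index / Rows, Index % Rows};
}

int main() {
  // The extract_static test above loads a 2x2 matrix and extracts flat
  // index 3: that is column 1, lane 1, i.e. element 1 of COL_LOAD1.
  auto [Col, Lane] = mapFlatIndex(3, 2);
  assert(Col == 1 && Lane == 1);
  return 0;
}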
Force-pushed from e1368d5 to 897768b
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from 897768b to 51e0487
ping
Does this need to be driven by a cost function? The new lowering seems strictly worse?
I've rebased the patch stack to pseudo pre-land the test, so the diff is more obvious.
aff9068#diff-6a09d32782efca5b9899b5ed357c5befd70d65f461a2ae64284f2504975f8948
Hm, yeah, that is bad: https://llvm.godbolt.org/z/7EbfGer3E
(It's not completely new; we already get similarly bad behavior for other ops, but with the reductions it's probably even worse, so it might be worth tackling before going further down that road.)
actually, wait, it's better, at least on cyclone (the only thing we have a scheduling model for at the moment)
old: https://llvm.godbolt.org/z/foash7aMq
new: https://llvm.godbolt.org/z/99aTd41Mr
any ideas on what the cost function would be, if we had one?
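One possible shape for such a gate, sketched here as an assumption rather than anything this patch implements: ask TargetTransformInfo whether the split lowering (per-column element-wise ops plus one narrow reduction) beats a single wide reduction over the flattened vector. The helper name and the ColumnTy/FlatTy/NumCols parameters are placeholders for values the pass already has via MatrixTy; Opcode is the element-wise binop (e.g. Instruction::Add).

// Hypothetical sketch only -- not part of this patch. Compares the cost of a
// single reduction over the flattened vector against combining the columns
// element-wise and reducing one column-sized vector.
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/FMF.h"

bool splitReduceLooksProfitable(const llvm::TargetTransformInfo &TTI,
                                unsigned Opcode, llvm::VectorType *ColumnTy,
                                llvm::VectorType *FlatTy, unsigned NumCols,
                                llvm::FastMathFlags FMF) {
  auto Kind = llvm::TargetTransformInfo::TCK_RecipThroughput;

  // Old lowering: one reduction over the embedded (flattened) vector. A real
  // gate would also add the shuffle cost of embedInVector to this side.
  llvm::InstructionCost Flat =
      TTI.getArithmeticReductionCost(Opcode, FlatTy, FMF, Kind);

  // New lowering: NumCols - 1 element-wise ops to combine the columns, then
  // one narrow reduction.
  llvm::InstructionCost Split =
      TTI.getArithmeticReductionCost(Opcode, ColumnTy, FMF, Kind);
  for (unsigned I = 1; I < NumCols; ++I)
    Split += TTI.getArithmeticInstrCost(Opcode, ColumnTy, Kind);

  return Split <= Flat;
}

Whether TCK_RecipThroughput is even the right cost kind here (versus latency, since the per-column ops form a dependency chain) would be part of the same open question.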
Force-pushed from e1cb5a1 to aff9068
When possible, this avoids a bunch of shuffles in & out of the flattened
layout.