Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 153 additions & 1 deletion llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
Expand All @@ -41,6 +42,7 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
Expand Down Expand Up @@ -325,6 +327,25 @@ computeShapeInfoForInst(Instruction *I,
return OpShape->second;
}

if (auto *II = dyn_cast<IntrinsicInst>(I)) {
switch (II->getIntrinsicID()) {
case Intrinsic::vector_reduce_fadd:
case Intrinsic::vector_reduce_fmul:
case Intrinsic::vector_reduce_fmax:
case Intrinsic::vector_reduce_fmaximum:
case Intrinsic::vector_reduce_fmin:
case Intrinsic::vector_reduce_fminimum:
case Intrinsic::vector_reduce_add:
case Intrinsic::vector_reduce_and:
case Intrinsic::vector_reduce_mul:
case Intrinsic::vector_reduce_or:
case Intrinsic::vector_reduce_xor:
return ShapeInfo(1, 1);
default:
break;
}
}

if (isUniformShape(I) || isa<SelectInst>(I)) {
auto Ops = I->operands();
auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
Expand Down Expand Up @@ -468,7 +489,7 @@ class LowerMatrixIntrinsics {
return make_range(Vectors.begin(), Vectors.end());
}

iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
iterator_range<SmallVector<Value *, 8>::const_iterator> vectors() const {
return make_range(Vectors.begin(), Vectors.end());
}

Expand Down Expand Up @@ -701,7 +722,31 @@ class LowerMatrixIntrinsics {
case Intrinsic::matrix_transpose:
case Intrinsic::matrix_column_major_load:
case Intrinsic::matrix_column_major_store:
case Intrinsic::vector_reduce_fmax:
case Intrinsic::vector_reduce_fmaximum:
case Intrinsic::vector_reduce_fmin:
case Intrinsic::vector_reduce_fminimum:
case Intrinsic::vector_reduce_add:
case Intrinsic::vector_reduce_and:
case Intrinsic::vector_reduce_mul:
case Intrinsic::vector_reduce_or:
case Intrinsic::vector_reduce_xor:
return true;
case Intrinsic::vector_reduce_fadd:
case Intrinsic::vector_reduce_fmul: {
FastMathFlags FMF = getFastMathFlags(Inst);
if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
return false;

if (match(Inst, m_Intrinsic<Intrinsic::vector_reduce_fadd>(
m_Unless(m_AnyZeroFP()), m_Value())))
return false;

if (match(Inst, m_Intrinsic<Intrinsic::vector_reduce_fmul>(
m_Unless(m_FPOne()), m_Value())))
return false;
return true;
}
default:
return isUniformShape(II);
}
Expand Down Expand Up @@ -1268,6 +1313,113 @@ class LowerMatrixIntrinsics {
return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors());
}
case Intrinsic::vector_reduce_fadd: {
Builder.setFastMathFlags(getFastMathFlags(Inst));
auto *I = Inst2ColumnMatrix.find(Inst->getOperand(1));
assert(I != Inst2ColumnMatrix.end());
const MatrixTy &M = I->second;

Value *Start = Inst->getOperand(0);
Value *ResultV = Builder.CreateVectorSplat(
ElementCount::getFixed(M.getStride()), Start);
for (auto &Vector : M.vectors())
ResultV = Builder.CreateFAdd(ResultV, Vector);

Value *Result = Builder.CreateFAddReduce(Start, ResultV);
Inst->replaceAllUsesWith(Result);
Result->takeName(Inst);
return MatrixTy{Result};
} break;
case Intrinsic::vector_reduce_fmul: {
Builder.setFastMathFlags(getFastMathFlags(Inst));
auto *I = Inst2ColumnMatrix.find(Inst->getOperand(1));
assert(I != Inst2ColumnMatrix.end());
const MatrixTy &M = I->second;

Value *Start = Inst->getOperand(0);
Value *ResultV = Builder.CreateVectorSplat(
ElementCount::getFixed(M.getStride()), Start);
for (auto &Vector : M.vectors())
ResultV = Builder.CreateFMul(ResultV, Vector);

Value *Result = Builder.CreateFMulReduce(Start, ResultV);
Inst->replaceAllUsesWith(Result);
Result->takeName(Inst);
return MatrixTy{Result};
} break;
case Intrinsic::vector_reduce_fmax:
case Intrinsic::vector_reduce_fmaximum:
case Intrinsic::vector_reduce_fmin:
case Intrinsic::vector_reduce_fminimum:
case Intrinsic::vector_reduce_add:
case Intrinsic::vector_reduce_and:
case Intrinsic::vector_reduce_mul:
case Intrinsic::vector_reduce_or:
case Intrinsic::vector_reduce_xor: {
Builder.setFastMathFlags(getFastMathFlags(Inst));
auto *I = Inst2ColumnMatrix.find(Inst->getOperand(0));
assert(I != Inst2ColumnMatrix.end());
const MatrixTy &M = I->second;

auto CreateVReduce = [&](Value *LHS, Value *RHS) {
switch (Inst->getIntrinsicID()) {
case Intrinsic::vector_reduce_add:
return Builder.CreateAdd(LHS, RHS);
case Intrinsic::vector_reduce_and:
return Builder.CreateAnd(LHS, RHS);
case Intrinsic::vector_reduce_fmax:
return Builder.CreateMaximum(LHS, RHS);
case Intrinsic::vector_reduce_fmaximum:
return Builder.CreateMaximumNum(LHS, RHS);
case Intrinsic::vector_reduce_fmin:
return Builder.CreateMinimum(LHS, RHS);
case Intrinsic::vector_reduce_fminimum:
return Builder.CreateMinimumNum(LHS, RHS);
case Intrinsic::vector_reduce_mul:
return Builder.CreateMul(LHS, RHS);
case Intrinsic::vector_reduce_or:
return Builder.CreateOr(LHS, RHS);
case Intrinsic::vector_reduce_xor:
return Builder.CreateXor(LHS, RHS);
default:
llvm_unreachable("unexpected intrinsic");
}
};

Value *ResultV = M.getVector(0);
for (auto &Vector : drop_begin(M.vectors()))
ResultV = CreateVReduce(ResultV, Vector);

auto CreateHReduce = [&](Value *V) {
switch (Inst->getIntrinsicID()) {
case Intrinsic::vector_reduce_add:
return Builder.CreateAddReduce(V);
case Intrinsic::vector_reduce_and:
return Builder.CreateAndReduce(V);
case Intrinsic::vector_reduce_fmax:
return Builder.CreateFPMaxReduce(V);
case Intrinsic::vector_reduce_fmaximum:
return Builder.CreateFPMaximumReduce(V);
case Intrinsic::vector_reduce_fmin:
return Builder.CreateFPMinReduce(V);
case Intrinsic::vector_reduce_fminimum:
return Builder.CreateFPMinimumReduce(V);
case Intrinsic::vector_reduce_mul:
return Builder.CreateMulReduce(V);
case Intrinsic::vector_reduce_or:
return Builder.CreateOrReduce(V);
case Intrinsic::vector_reduce_xor:
return Builder.CreateXorReduce(V);
default:
llvm_unreachable("unexpected intrinsic");
}
};

Value *Result = CreateHReduce(ResultV);
Inst->replaceAllUsesWith(Result);
Result->takeName(Inst);
return MatrixTy{Result};
} break;
default:
break;
}
Expand Down
Loading
Loading