Skip to content
86 changes: 34 additions & 52 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1073,8 +1073,8 @@ class LowerMatrixIntrinsics {
VisitBinaryOperator(BinOp, SI);
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
VisitUnaryOperator(UnOp, SI);
else if (CallInst *CInst = dyn_cast<CallInst>(Inst))
VisitCallInst(CInst);
else if (IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Inst))
VisitIntrinsicInst(Intr, SI);
else if (match(Inst, m_Load(m_Value(Op1))))
VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
Expand Down Expand Up @@ -1120,11 +1120,8 @@ class LowerMatrixIntrinsics {
}

/// Replace intrinsic calls.
void VisitCallInst(CallInst *Inst) {
assert(Inst->getCalledFunction() &&
Inst->getCalledFunction()->isIntrinsic());

switch (Inst->getCalledFunction()->getIntrinsicID()) {
void VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &Shape) {
switch (Inst->getIntrinsicID()) {
case Intrinsic::matrix_multiply:
LowerMultiply(Inst);
break;
Expand All @@ -1138,8 +1135,36 @@ class LowerMatrixIntrinsics {
LowerColumnMajorStore(Inst);
break;
case Intrinsic::abs:
case Intrinsic::fabs:
return VisitUniformIntrinsic(cast<IntrinsicInst>(Inst));
case Intrinsic::fabs: {
IRBuilder<> Builder(Inst);

MatrixTy Result;

MatrixTy M = getMatrix(Inst->getOperand(0), Shape, Builder);

Builder.setFastMathFlags(getFastMathFlags(Inst));

for (auto &Vector : M.vectors()) {
switch (Inst->getIntrinsicID()) {
case Intrinsic::abs:
Result.addVector(Builder.CreateBinaryIntrinsic(Intrinsic::abs, Vector,
Inst->getOperand(1)));
break;
case Intrinsic::fabs:
Result.addVector(
Builder.CreateUnaryIntrinsic(Inst->getIntrinsicID(), Vector));
break;
default:
llvm_unreachable("unexpected intrinsic");
}
}

finalizeLowering(Inst,
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors()),
Builder);
return;
}
default:
llvm_unreachable(
"only intrinsics supporting shape info should be seen here");
Expand Down Expand Up @@ -2189,49 +2214,6 @@ class LowerMatrixIntrinsics {
Builder);
}

/// Lower uniform shape intrinsics, if shape information is available.
bool VisitUniformIntrinsic(IntrinsicInst *Inst) {
auto I = ShapeMap.find(Inst);
assert(I != ShapeMap.end() &&
"must only visit instructions with shape info");

IRBuilder<> Builder(Inst);
ShapeInfo &Shape = I->second;

MatrixTy Result;

switch (Inst->getIntrinsicID()) {
case Intrinsic::abs:
case Intrinsic::fabs: {
MatrixTy M = getMatrix(Inst->getOperand(0), Shape, Builder);

Builder.setFastMathFlags(getFastMathFlags(Inst));

for (auto &Vector : M.vectors())
switch (Inst->getIntrinsicID()) {
case Intrinsic::abs:
Result.addVector(Builder.CreateBinaryIntrinsic(Intrinsic::abs, Vector,
Inst->getOperand(1)));
break;
case Intrinsic::fabs:
Result.addVector(
Builder.CreateUnaryIntrinsic(Inst->getIntrinsicID(), Vector));
break;
default:
llvm_unreachable("unexpected intrinsic");
}

finalizeLowering(Inst,
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors()),
Builder);
return true;
}
default:
llvm_unreachable("unexpected intrinsic");
}
}

/// Helper to linearize a matrix expression tree into a string. Currently
/// matrix expressions are linarized by starting at an expression leaf and
/// linearizing bottom up.
Expand Down
Loading