Skip to content
Merged
58 changes: 46 additions & 12 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,15 @@ static bool isUniformShape(Value *V) {
if (I->isBinaryOp())
return true;

if (auto *II = dyn_cast<IntrinsicInst>(V))
switch (II->getIntrinsicID()) {
case Intrinsic::abs:
case Intrinsic::fabs:
return true;
default:
return false;
}

switch (I->getOpcode()) {
case Instruction::FNeg:
return true;
Expand Down Expand Up @@ -618,7 +627,7 @@ class LowerMatrixIntrinsics {
case Intrinsic::matrix_column_major_store:
return true;
default:
return false;
return isUniformShape(II);
}
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
}
Expand Down Expand Up @@ -1064,8 +1073,8 @@ class LowerMatrixIntrinsics {
VisitBinaryOperator(BinOp, SI);
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
VisitUnaryOperator(UnOp, SI);
else if (CallInst *CInst = dyn_cast<CallInst>(Inst))
VisitCallInst(CInst);
else if (IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Inst))
VisitIntrinsicInst(Intr, SI);
else if (match(Inst, m_Load(m_Value(Op1))))
VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
Expand Down Expand Up @@ -1111,23 +1120,48 @@ class LowerMatrixIntrinsics {
}

/// Replace intrinsic calls.
void VisitCallInst(CallInst *Inst) {
assert(Inst->getCalledFunction() &&
Inst->getCalledFunction()->isIntrinsic());

switch (Inst->getCalledFunction()->getIntrinsicID()) {
void VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &Shape) {
switch (Inst->getIntrinsicID()) {
case Intrinsic::matrix_multiply:
LowerMultiply(Inst);
break;
return;
case Intrinsic::matrix_transpose:
LowerTranspose(Inst);
break;
return;
case Intrinsic::matrix_column_major_load:
LowerColumnMajorLoad(Inst);
break;
return;
case Intrinsic::matrix_column_major_store:
LowerColumnMajorStore(Inst);
break;
return;
case Intrinsic::abs:
case Intrinsic::fabs: {
IRBuilder<> Builder(Inst);
MatrixTy Result;
MatrixTy M = getMatrix(Inst->getOperand(0), Shape, Builder);
Builder.setFastMathFlags(getFastMathFlags(Inst));

for (auto &Vector : M.vectors()) {
switch (Inst->getIntrinsicID()) {
case Intrinsic::abs:
Result.addVector(Builder.CreateBinaryIntrinsic(Intrinsic::abs, Vector,
Inst->getOperand(1)));
continue;
case Intrinsic::fabs:
Result.addVector(
Builder.CreateUnaryIntrinsic(Inst->getIntrinsicID(), Vector));
continue;
default:
llvm_unreachable("unexpected intrinsic");
}
}

finalizeLowering(Inst,
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors()),
Builder);
return;
}
default:
llvm_unreachable(
"only intrinsics supporting shape info should be seen here");
Expand Down
Loading
Loading