Skip to content
Merged
57 changes: 56 additions & 1 deletion llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,15 @@ static bool isUniformShape(Value *V) {
if (I->isBinaryOp())
return true;

if (auto *II = dyn_cast<IntrinsicInst>(V))
switch (II->getIntrinsicID()) {
case Intrinsic::abs:
case Intrinsic::fabs:
return true;
default:
return false;
}

switch (I->getOpcode()) {
case Instruction::FNeg:
return true;
Expand Down Expand Up @@ -618,7 +627,7 @@ class LowerMatrixIntrinsics {
case Intrinsic::matrix_column_major_store:
return true;
default:
return false;
return isUniformShape(II);
}
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
}
Expand Down Expand Up @@ -1124,6 +1133,9 @@ class LowerMatrixIntrinsics {
case Intrinsic::matrix_column_major_store:
LowerColumnMajorStore(Inst);
break;
case Intrinsic::abs:
case Intrinsic::fabs:
return VisitUniformIntrinsic(cast<IntrinsicInst>(Inst));
default:
return false;
}
Expand Down Expand Up @@ -2194,6 +2206,49 @@ class LowerMatrixIntrinsics {
return true;
}

/// Lower uniform shape intrinsics, if shape information is available.
bool VisitUniformIntrinsic(IntrinsicInst *Inst) {
auto I = ShapeMap.find(Inst);
assert(I != ShapeMap.end() &&
"must only visit instructions with shape info");

IRBuilder<> Builder(Inst);
ShapeInfo &Shape = I->second;

MatrixTy Result;

switch (Inst->getIntrinsicID()) {
case Intrinsic::abs:
case Intrinsic::fabs: {
MatrixTy M = getMatrix(Inst->getOperand(0), Shape, Builder);

Builder.setFastMathFlags(getFastMathFlags(Inst));

for (auto &Vector : M.vectors())
switch (Inst->getIntrinsicID()) {
case Intrinsic::abs:
Result.addVector(Builder.CreateBinaryIntrinsic(Intrinsic::abs, Vector,
Inst->getOperand(1)));
break;
case Intrinsic::fabs:
Result.addVector(
Builder.CreateUnaryIntrinsic(Inst->getIntrinsicID(), Vector));
break;
default:
llvm_unreachable("unexpected intrinsic");
}

finalizeLowering(Inst,
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors()),
Builder);
return true;
}
default:
llvm_unreachable("unexpected intrinsic");
}
}

/// Helper to linearize a matrix expression tree into a string. Currently
/// matrix expressions are linarized by starting at an expression leaf and
/// linearizing bottom up.
Expand Down
Loading
Loading