Skip to content
76 changes: 30 additions & 46 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1056,19 +1056,20 @@ class LowerMatrixIntrinsics {

IRBuilder<> Builder(Inst);

if (CallInst *CInst = dyn_cast<CallInst>(Inst))
Changed |= VisitCallInst(CInst);
const ShapeInfo &SI = ShapeMap.at(Inst);

Value *Op1;
Value *Op2;
if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
VisitBinaryOperator(BinOp);
VisitBinaryOperator(BinOp, SI);
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
VisitUnaryOperator(UnOp);
VisitUnaryOperator(UnOp, SI);
else if (CallInst *CInst = dyn_cast<CallInst>(Inst))
VisitCallInst(CInst);
else if (match(Inst, m_Load(m_Value(Op1))))
VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
else
continue;
Changed = true;
Expand Down Expand Up @@ -1109,10 +1110,10 @@ class LowerMatrixIntrinsics {
return Changed;
}

/// Replace intrinsic calls
bool VisitCallInst(CallInst *Inst) {
if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
return false;
/// Replace intrinsic calls.
void VisitCallInst(CallInst *Inst) {
assert(Inst->getCalledFunction() &&
Inst->getCalledFunction()->isIntrinsic());

switch (Inst->getCalledFunction()->getIntrinsicID()) {
case Intrinsic::matrix_multiply:
Expand All @@ -1128,9 +1129,9 @@ class LowerMatrixIntrinsics {
LowerColumnMajorStore(Inst);
break;
default:
return false;
llvm_unreachable(
"only intrinsics supporting shape info should be seen here");
}
return true;
}

/// Compute the alignment for a column/row \p Idx with \p Stride between them.
Expand Down Expand Up @@ -2107,48 +2108,36 @@ class LowerMatrixIntrinsics {
Builder);
}

/// Lower load instructions, if shape information is available.
void VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
auto I = ShapeMap.find(Inst);
assert(I != ShapeMap.end() &&
"must only visit instructions with shape info");
LowerLoad(Inst, Ptr, Inst->getAlign(),
Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
I->second);
/// Lower load instructions.
void VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
IRBuilder<> &Builder) {
LowerLoad(Inst, Ptr, Inst->getAlign(), Builder.getInt64(SI.getStride()),
Inst->isVolatile(), SI);
}

void VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
IRBuilder<> &Builder) {
auto I = ShapeMap.find(Inst);
assert(I != ShapeMap.end() &&
"must only visit instructions with shape info");
void VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
Value *Ptr, IRBuilder<> &Builder) {
LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
I->second);
Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
}

/// Lower binary operators, if shape information is available.
void VisitBinaryOperator(BinaryOperator *Inst) {
auto I = ShapeMap.find(Inst);
assert(I != ShapeMap.end() &&
"must only visit instructions with shape info");

/// Lower binary operators.
void VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
Value *Lhs = Inst->getOperand(0);
Value *Rhs = Inst->getOperand(1);

IRBuilder<> Builder(Inst);
ShapeInfo &Shape = I->second;

MatrixTy Result;
MatrixTy A = getMatrix(Lhs, Shape, Builder);
MatrixTy B = getMatrix(Rhs, Shape, Builder);
MatrixTy A = getMatrix(Lhs, SI, Builder);
MatrixTy B = getMatrix(Rhs, SI, Builder);
assert(A.isColumnMajor() == B.isColumnMajor() &&
Result.isColumnMajor() == A.isColumnMajor() &&
"operands must agree on matrix layout");

Builder.setFastMathFlags(getFastMathFlags(Inst));

for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
for (unsigned I = 0; I < SI.getNumVectors(); ++I)
Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I),
B.getVector(I)));

Expand All @@ -2158,19 +2147,14 @@ class LowerMatrixIntrinsics {
Builder);
}

/// Lower unary operators, if shape information is available.
void VisitUnaryOperator(UnaryOperator *Inst) {
auto I = ShapeMap.find(Inst);
assert(I != ShapeMap.end() &&
"must only visit instructions with shape info");

/// Lower unary operators.
void VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
Value *Op = Inst->getOperand(0);

IRBuilder<> Builder(Inst);
ShapeInfo &Shape = I->second;

MatrixTy Result;
MatrixTy M = getMatrix(Op, Shape, Builder);
MatrixTy M = getMatrix(Op, SI, Builder);

Builder.setFastMathFlags(getFastMathFlags(Inst));

Expand All @@ -2184,7 +2168,7 @@ class LowerMatrixIntrinsics {
}
};

for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
for (unsigned I = 0; I < SI.getNumVectors(); ++I)
Result.addVector(BuildVectorOp(M.getVector(I)));

finalizeLowering(Inst,
Expand Down
Loading