diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 38f92561a917d..20279bf69dd59 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -1056,19 +1056,20 @@ class LowerMatrixIntrinsics { IRBuilder<> Builder(Inst); - if (CallInst *CInst = dyn_cast(Inst)) - Changed |= VisitCallInst(CInst); + const ShapeInfo &SI = ShapeMap.at(Inst); Value *Op1; Value *Op2; if (auto *BinOp = dyn_cast(Inst)) - VisitBinaryOperator(BinOp); + VisitBinaryOperator(BinOp, SI); else if (auto *UnOp = dyn_cast(Inst)) - VisitUnaryOperator(UnOp); + VisitUnaryOperator(UnOp, SI); + else if (CallInst *CInst = dyn_cast(Inst)) + VisitCallInst(CInst); else if (match(Inst, m_Load(m_Value(Op1)))) - VisitLoad(cast(Inst), Op1, Builder); + VisitLoad(cast(Inst), SI, Op1, Builder); else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) - VisitStore(cast(Inst), Op1, Op2, Builder); + VisitStore(cast(Inst), SI, Op1, Op2, Builder); else continue; Changed = true; @@ -1109,10 +1110,10 @@ class LowerMatrixIntrinsics { return Changed; } - /// Replace intrinsic calls - bool VisitCallInst(CallInst *Inst) { - if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) - return false; + /// Replace intrinsic calls. + void VisitCallInst(CallInst *Inst) { + assert(Inst->getCalledFunction() && + Inst->getCalledFunction()->isIntrinsic()); switch (Inst->getCalledFunction()->getIntrinsicID()) { case Intrinsic::matrix_multiply: @@ -1128,9 +1129,9 @@ class LowerMatrixIntrinsics { LowerColumnMajorStore(Inst); break; default: - return false; + llvm_unreachable( + "only intrinsics supporting shape info should be seen here"); } - return true; } /// Compute the alignment for a column/row \p Idx with \p Stride between them. @@ -2107,48 +2108,36 @@ class LowerMatrixIntrinsics { Builder); } - /// Lower load instructions, if shape information is available. - void VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) { - auto I = ShapeMap.find(Inst); - assert(I != ShapeMap.end() && - "must only visit instructions with shape info"); - LowerLoad(Inst, Ptr, Inst->getAlign(), - Builder.getInt64(I->second.getStride()), Inst->isVolatile(), - I->second); + /// Lower load instructions. + void VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr, + IRBuilder<> &Builder) { + LowerLoad(Inst, Ptr, Inst->getAlign(), Builder.getInt64(SI.getStride()), + Inst->isVolatile(), SI); } - void VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr, - IRBuilder<> &Builder) { - auto I = ShapeMap.find(Inst); - assert(I != ShapeMap.end() && - "must only visit instructions with shape info"); + void VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal, + Value *Ptr, IRBuilder<> &Builder) { LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), - Builder.getInt64(I->second.getStride()), Inst->isVolatile(), - I->second); + Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI); } - /// Lower binary operators, if shape information is available. - void VisitBinaryOperator(BinaryOperator *Inst) { - auto I = ShapeMap.find(Inst); - assert(I != ShapeMap.end() && - "must only visit instructions with shape info"); - + /// Lower binary operators. + void VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) { Value *Lhs = Inst->getOperand(0); Value *Rhs = Inst->getOperand(1); IRBuilder<> Builder(Inst); - ShapeInfo &Shape = I->second; MatrixTy Result; - MatrixTy A = getMatrix(Lhs, Shape, Builder); - MatrixTy B = getMatrix(Rhs, Shape, Builder); + MatrixTy A = getMatrix(Lhs, SI, Builder); + MatrixTy B = getMatrix(Rhs, SI, Builder); assert(A.isColumnMajor() == B.isColumnMajor() && Result.isColumnMajor() == A.isColumnMajor() && "operands must agree on matrix layout"); Builder.setFastMathFlags(getFastMathFlags(Inst)); - for (unsigned I = 0; I < Shape.getNumVectors(); ++I) + for (unsigned I = 0; I < SI.getNumVectors(); ++I) Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I), B.getVector(I))); @@ -2158,19 +2147,14 @@ class LowerMatrixIntrinsics { Builder); } - /// Lower unary operators, if shape information is available. - void VisitUnaryOperator(UnaryOperator *Inst) { - auto I = ShapeMap.find(Inst); - assert(I != ShapeMap.end() && - "must only visit instructions with shape info"); - + /// Lower unary operators. + void VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) { Value *Op = Inst->getOperand(0); IRBuilder<> Builder(Inst); - ShapeInfo &Shape = I->second; MatrixTy Result; - MatrixTy M = getMatrix(Op, Shape, Builder); + MatrixTy M = getMatrix(Op, SI, Builder); Builder.setFastMathFlags(getFastMathFlags(Inst)); @@ -2184,7 +2168,7 @@ class LowerMatrixIntrinsics { } }; - for (unsigned I = 0; I < Shape.getNumVectors(); ++I) + for (unsigned I = 0; I < SI.getNumVectors(); ++I) Result.addVector(BuildVectorOp(M.getVector(I))); finalizeLowering(Inst,