diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 026f2fa96146a..f96cad20f9487 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -1134,26 +1134,28 @@ class LowerMatrixIntrinsics { if (FusedInsts.count(Inst)) continue; - IRBuilder<> Builder(Inst); - const ShapeInfo &SI = ShapeMap.at(Inst); Value *Op1; Value *Op2; + MatrixTy Result; if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) - VisitBinaryOperator(BinOp, SI); + Result = VisitBinaryOperator(BinOp, SI); else if (auto *Cast = dyn_cast<CastInst>(Inst)) - VisitCastInstruction(Cast, SI); + Result = VisitCastInstruction(Cast, SI); else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) - VisitUnaryOperator(UnOp, SI); - else if (IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Inst)) - VisitIntrinsicInst(Intr, SI); + Result = VisitUnaryOperator(UnOp, SI); + else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst)) + Result = VisitIntrinsicInst(Intr, SI); else if (match(Inst, m_Load(m_Value(Op1)))) - VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder); + Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1); else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) - VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder); + Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2); else continue; + + IRBuilder<> Builder(Inst); + finalizeLowering(Inst, Result, Builder); Changed = true; } @@ -1193,25 +1195,24 @@ class LowerMatrixIntrinsics { } /// Replace intrinsic calls. 
- void VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &Shape) { - switch (Inst->getIntrinsicID()) { + MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI) { + assert(Inst->getCalledFunction() && + Inst->getCalledFunction()->isIntrinsic()); + + switch (Inst->getCalledFunction()->getIntrinsicID()) { case Intrinsic::matrix_multiply: - LowerMultiply(Inst); - return; + return LowerMultiply(Inst); case Intrinsic::matrix_transpose: - LowerTranspose(Inst); - return; + return LowerTranspose(Inst); case Intrinsic::matrix_column_major_load: - LowerColumnMajorLoad(Inst); - return; + return LowerColumnMajorLoad(Inst); case Intrinsic::matrix_column_major_store: - LowerColumnMajorStore(Inst); - return; + return LowerColumnMajorStore(Inst); case Intrinsic::abs: case Intrinsic::fabs: { IRBuilder<> Builder(Inst); MatrixTy Result; - MatrixTy M = getMatrix(Inst->getOperand(0), Shape, Builder); + MatrixTy M = getMatrix(Inst->getOperand(0), SI, Builder); Builder.setFastMathFlags(getFastMathFlags(Inst)); for (auto &Vector : M.vectors()) { @@ -1229,16 +1230,14 @@ class LowerMatrixIntrinsics { } } - finalizeLowering(Inst, - Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * - Result.getNumVectors()), - Builder); - return; + return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * + Result.getNumVectors()); } default: - llvm_unreachable( - "only intrinsics supporting shape info should be seen here"); + break; } + llvm_unreachable( + "only intrinsics supporting shape info should be seen here"); } /// Compute the alignment for a column/row \p Idx with \p Stride between them. @@ -1304,26 +1303,24 @@ class LowerMatrixIntrinsics { } /// Lower a load instruction with shape information. 
- void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, - bool IsVolatile, ShapeInfo Shape) { + MatrixTy LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, + Value *Stride, bool IsVolatile, ShapeInfo Shape) { IRBuilder<> Builder(Inst); - finalizeLowering(Inst, - loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, - Shape, Builder), - Builder); + return loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, Shape, + Builder); } /// Lowers llvm.matrix.column.major.load. /// /// The intrinsic loads a matrix from memory using a stride between columns. - void LowerColumnMajorLoad(CallInst *Inst) { + MatrixTy LowerColumnMajorLoad(CallInst *Inst) { assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && "Intrinsic only supports column-major layout!"); Value *Ptr = Inst->getArgOperand(0); Value *Stride = Inst->getArgOperand(1); - LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, - cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), - {Inst->getArgOperand(3), Inst->getArgOperand(4)}); + return LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, + cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), + {Inst->getArgOperand(3), Inst->getArgOperand(4)}); } /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p @@ -1366,28 +1363,27 @@ class LowerMatrixIntrinsics { } /// Lower a store instruction with shape information. - void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A, - Value *Stride, bool IsVolatile, ShapeInfo Shape) { + MatrixTy LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, + MaybeAlign A, Value *Stride, bool IsVolatile, + ShapeInfo Shape) { IRBuilder<> Builder(Inst); auto StoreVal = getMatrix(Matrix, Shape, Builder); - finalizeLowering(Inst, - storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, - IsVolatile, Builder), - Builder); + return storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, IsVolatile, + Builder); } /// Lowers llvm.matrix.column.major.store. 
/// /// The intrinsic store a matrix back memory using a stride between columns. - void LowerColumnMajorStore(CallInst *Inst) { + MatrixTy LowerColumnMajorStore(CallInst *Inst) { assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && "Intrinsic only supports column-major layout!"); Value *Matrix = Inst->getArgOperand(0); Value *Ptr = Inst->getArgOperand(1); Value *Stride = Inst->getArgOperand(2); - LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, - cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), - {Inst->getArgOperand(4), Inst->getArgOperand(5)}); + return LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, + cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), + {Inst->getArgOperand(4), Inst->getArgOperand(5)}); } // Set elements I..I+NumElts-1 to Block @@ -2162,7 +2158,7 @@ class LowerMatrixIntrinsics { } /// Lowers llvm.matrix.multiply. - void LowerMultiply(CallInst *MatMul) { + MatrixTy LowerMultiply(CallInst *MatMul) { IRBuilder<> Builder(MatMul); auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType(); ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); @@ -2184,11 +2180,11 @@ class LowerMatrixIntrinsics { emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false, getFastMathFlags(MatMul)); - finalizeLowering(MatMul, Result, Builder); + return Result; } /// Lowers llvm.matrix.transpose. - void LowerTranspose(CallInst *Inst) { + MatrixTy LowerTranspose(CallInst *Inst) { MatrixTy Result; IRBuilder<> Builder(Inst); Value *InputVal = Inst->getArgOperand(0); @@ -2218,28 +2214,26 @@ class LowerMatrixIntrinsics { // TODO: Improve estimate of operations needed for transposes. Currently we // just count the insertelement/extractelement instructions, but do not // account for later simplifications/combines. 
- finalizeLowering( - Inst, - Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns) - .addNumExposedTransposes(1), - Builder); + return Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns) + .addNumExposedTransposes(1); } /// Lower load instructions. - void VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr, - IRBuilder<> &Builder) { - LowerLoad(Inst, Ptr, Inst->getAlign(), Builder.getInt64(SI.getStride()), - Inst->isVolatile(), SI); + MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr) { + IRBuilder<> Builder(Inst); + return LowerLoad(Inst, Ptr, Inst->getAlign(), + Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI); } - void VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal, - Value *Ptr, IRBuilder<> &Builder) { - LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), - Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI); + MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal, + Value *Ptr) { + IRBuilder<> Builder(Inst); + return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), + Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI); } /// Lower binary operators. - void VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) { + MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) { Value *Lhs = Inst->getOperand(0); Value *Rhs = Inst->getOperand(1); @@ -2258,14 +2252,12 @@ class LowerMatrixIntrinsics { Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I), B.getVector(I))); - finalizeLowering(Inst, - Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * - Result.getNumVectors()), - Builder); + return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * + Result.getNumVectors()); } /// Lower unary operators. 
- void VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) { + MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) { Value *Op = Inst->getOperand(0); IRBuilder<> Builder(Inst); @@ -2288,14 +2280,12 @@ class LowerMatrixIntrinsics { for (unsigned I = 0; I < SI.getNumVectors(); ++I) Result.addVector(BuildVectorOp(M.getVector(I))); - finalizeLowering(Inst, - Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * - Result.getNumVectors()), - Builder); + return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * + Result.getNumVectors()); } /// Lower cast instructions. - void VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape) { + MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape) { Value *Op = Inst->getOperand(0); IRBuilder<> Builder(Inst); @@ -2312,10 +2302,8 @@ class LowerMatrixIntrinsics { for (auto &Vector : M.vectors()) Result.addVector(Builder.CreateCast(Inst->getOpcode(), Vector, NewVTy)); - finalizeLowering(Inst, - Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * - Result.getNumVectors()), - Builder); + return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * + Result.getNumVectors()); } /// Helper to linearize a matrix expression tree into a string. Currently