@@ -1134,26 +1134,28 @@ class LowerMatrixIntrinsics {
11341134 if (FusedInsts.count (Inst))
11351135 continue ;
11361136
1137- IRBuilder<> Builder (Inst);
1138-
11391137 const ShapeInfo &SI = ShapeMap.at (Inst);
11401138
11411139 Value *Op1;
11421140 Value *Op2;
1141+ MatrixTy Result;
11431142 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1144- VisitBinaryOperator (BinOp, SI);
1143+ Result = VisitBinaryOperator (BinOp, SI);
11451144 else if (auto *Cast = dyn_cast<CastInst>(Inst))
1146- VisitCastInstruction (Cast, SI);
1145+ Result = VisitCastInstruction (Cast, SI);
11471146 else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1148- VisitUnaryOperator (UnOp, SI);
1149- else if (IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Inst))
1150- VisitIntrinsicInst (Intr, SI);
1147+ Result = VisitUnaryOperator (UnOp, SI);
1148+ else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst))
1149+ Result = VisitIntrinsicInst (Intr, SI);
11511150 else if (match (Inst, m_Load (m_Value (Op1))))
1152- VisitLoad (cast<LoadInst>(Inst), SI, Op1, Builder );
1151+ Result = VisitLoad (cast<LoadInst>(Inst), SI, Op1);
11531152 else if (match (Inst, m_Store (m_Value (Op1), m_Value (Op2))))
1154- VisitStore (cast<StoreInst>(Inst), SI, Op1, Op2, Builder );
1153+ Result = VisitStore (cast<StoreInst>(Inst), SI, Op1, Op2);
11551154 else
11561155 continue ;
1156+
1157+ IRBuilder<> Builder (Inst);
1158+ finalizeLowering (Inst, Result, Builder);
11571159 Changed = true ;
11581160 }
11591161
@@ -1193,25 +1195,24 @@ class LowerMatrixIntrinsics {
11931195 }
11941196
11951197 // / Replace intrinsic calls.
1196- void VisitIntrinsicInst (IntrinsicInst *Inst, const ShapeInfo &Shape) {
1197- switch (Inst->getIntrinsicID ()) {
1198+ MatrixTy VisitIntrinsicInst (IntrinsicInst *Inst, const ShapeInfo &SI) {
1199+ assert (Inst->getCalledFunction () &&
1200+ Inst->getCalledFunction ()->isIntrinsic ());
1201+
1202+ switch (Inst->getCalledFunction ()->getIntrinsicID ()) {
11981203 case Intrinsic::matrix_multiply:
1199- LowerMultiply (Inst);
1200- return ;
1204+ return LowerMultiply (Inst);
12011205 case Intrinsic::matrix_transpose:
1202- LowerTranspose (Inst);
1203- return ;
1206+ return LowerTranspose (Inst);
12041207 case Intrinsic::matrix_column_major_load:
1205- LowerColumnMajorLoad (Inst);
1206- return ;
1208+ return LowerColumnMajorLoad (Inst);
12071209 case Intrinsic::matrix_column_major_store:
1208- LowerColumnMajorStore (Inst);
1209- return ;
1210+ return LowerColumnMajorStore (Inst);
12101211 case Intrinsic::abs:
12111212 case Intrinsic::fabs: {
12121213 IRBuilder<> Builder (Inst);
12131214 MatrixTy Result;
1214- MatrixTy M = getMatrix (Inst->getOperand (0 ), Shape , Builder);
1215+ MatrixTy M = getMatrix (Inst->getOperand (0 ), SI , Builder);
12151216 Builder.setFastMathFlags (getFastMathFlags (Inst));
12161217
12171218 for (auto &Vector : M.vectors ()) {
@@ -1229,16 +1230,14 @@ class LowerMatrixIntrinsics {
12291230 }
12301231 }
12311232
1232- finalizeLowering (Inst,
1233- Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
1234- Result.getNumVectors ()),
1235- Builder);
1236- return ;
1233+ return Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
1234+ Result.getNumVectors ());
12371235 }
12381236 default :
1239- llvm_unreachable (
1240- " only intrinsics supporting shape info should be seen here" );
1237+ break ;
12411238 }
1239+ llvm_unreachable (
1240+ " only intrinsics supporting shape info should be seen here" );
12421241 }
12431242
12441243 // / Compute the alignment for a column/row \p Idx with \p Stride between them.
@@ -1304,26 +1303,24 @@ class LowerMatrixIntrinsics {
13041303 }
13051304
13061305 // / Lower a load instruction with shape information.
1307- void LowerLoad (Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride ,
1308- bool IsVolatile, ShapeInfo Shape) {
1306+ MatrixTy LowerLoad (Instruction *Inst, Value *Ptr, MaybeAlign Align,
1307+ Value *Stride, bool IsVolatile, ShapeInfo Shape) {
13091308 IRBuilder<> Builder (Inst);
1310- finalizeLowering (Inst,
1311- loadMatrix (Inst->getType (), Ptr, Align, Stride, IsVolatile,
1312- Shape, Builder),
1313- Builder);
1309+ return loadMatrix (Inst->getType (), Ptr, Align, Stride, IsVolatile, Shape,
1310+ Builder);
13141311 }
13151312
13161313 // / Lowers llvm.matrix.column.major.load.
13171314 // /
13181315 // / The intrinsic loads a matrix from memory using a stride between columns.
1319- void LowerColumnMajorLoad (CallInst *Inst) {
1316+ MatrixTy LowerColumnMajorLoad (CallInst *Inst) {
13201317 assert (MatrixLayout == MatrixLayoutTy::ColumnMajor &&
13211318 " Intrinsic only supports column-major layout!" );
13221319 Value *Ptr = Inst->getArgOperand (0 );
13231320 Value *Stride = Inst->getArgOperand (1 );
1324- LowerLoad (Inst, Ptr, Inst->getParamAlign (0 ), Stride,
1325- cast<ConstantInt>(Inst->getArgOperand (2 ))->isOne (),
1326- {Inst->getArgOperand (3 ), Inst->getArgOperand (4 )});
1321+ return LowerLoad (Inst, Ptr, Inst->getParamAlign (0 ), Stride,
1322+ cast<ConstantInt>(Inst->getArgOperand (2 ))->isOne (),
1323+ {Inst->getArgOperand (3 ), Inst->getArgOperand (4 )});
13271324 }
13281325
13291326 // / Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
@@ -1366,28 +1363,27 @@ class LowerMatrixIntrinsics {
13661363 }
13671364
13681365 // / Lower a store instruction with shape information.
1369- void LowerStore (Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1370- Value *Stride, bool IsVolatile, ShapeInfo Shape) {
1366+ MatrixTy LowerStore (Instruction *Inst, Value *Matrix, Value *Ptr,
1367+ MaybeAlign A, Value *Stride, bool IsVolatile,
1368+ ShapeInfo Shape) {
13711369 IRBuilder<> Builder (Inst);
13721370 auto StoreVal = getMatrix (Matrix, Shape, Builder);
1373- finalizeLowering (Inst,
1374- storeMatrix (Matrix->getType (), StoreVal, Ptr, A, Stride,
1375- IsVolatile, Builder),
1376- Builder);
1371+ return storeMatrix (Matrix->getType (), StoreVal, Ptr, A, Stride, IsVolatile,
1372+ Builder);
13771373 }
13781374
13791375 // / Lowers llvm.matrix.column.major.store.
13801376 // /
13811377 // / The intrinsic store a matrix back memory using a stride between columns.
1382- void LowerColumnMajorStore (CallInst *Inst) {
1378+ MatrixTy LowerColumnMajorStore (CallInst *Inst) {
13831379 assert (MatrixLayout == MatrixLayoutTy::ColumnMajor &&
13841380 " Intrinsic only supports column-major layout!" );
13851381 Value *Matrix = Inst->getArgOperand (0 );
13861382 Value *Ptr = Inst->getArgOperand (1 );
13871383 Value *Stride = Inst->getArgOperand (2 );
1388- LowerStore (Inst, Matrix, Ptr, Inst->getParamAlign (1 ), Stride,
1389- cast<ConstantInt>(Inst->getArgOperand (3 ))->isOne (),
1390- {Inst->getArgOperand (4 ), Inst->getArgOperand (5 )});
1384+ return LowerStore (Inst, Matrix, Ptr, Inst->getParamAlign (1 ), Stride,
1385+ cast<ConstantInt>(Inst->getArgOperand (3 ))->isOne (),
1386+ {Inst->getArgOperand (4 ), Inst->getArgOperand (5 )});
13911387 }
13921388
13931389 // Set elements I..I+NumElts-1 to Block
@@ -2162,7 +2158,7 @@ class LowerMatrixIntrinsics {
21622158 }
21632159
21642160 // / Lowers llvm.matrix.multiply.
2165- void LowerMultiply (CallInst *MatMul) {
2161+ MatrixTy LowerMultiply (CallInst *MatMul) {
21662162 IRBuilder<> Builder (MatMul);
21672163 auto *EltType = cast<FixedVectorType>(MatMul->getType ())->getElementType ();
21682164 ShapeInfo LShape (MatMul->getArgOperand (2 ), MatMul->getArgOperand (3 ));
@@ -2184,11 +2180,11 @@ class LowerMatrixIntrinsics {
21842180
21852181 emitMatrixMultiply (Result, Lhs, Rhs, Builder, false , false ,
21862182 getFastMathFlags (MatMul));
2187- finalizeLowering (MatMul, Result, Builder) ;
2183+ return Result;
21882184 }
21892185
21902186 // / Lowers llvm.matrix.transpose.
2191- void LowerTranspose (CallInst *Inst) {
2187+ MatrixTy LowerTranspose (CallInst *Inst) {
21922188 MatrixTy Result;
21932189 IRBuilder<> Builder (Inst);
21942190 Value *InputVal = Inst->getArgOperand (0 );
@@ -2218,28 +2214,26 @@ class LowerMatrixIntrinsics {
22182214 // TODO: Improve estimate of operations needed for transposes. Currently we
22192215 // just count the insertelement/extractelement instructions, but do not
22202216 // account for later simplifications/combines.
2221- finalizeLowering (
2222- Inst,
2223- Result.addNumComputeOps (2 * ArgShape.NumRows * ArgShape.NumColumns )
2224- .addNumExposedTransposes (1 ),
2225- Builder);
2217+ return Result.addNumComputeOps (2 * ArgShape.NumRows * ArgShape.NumColumns )
2218+ .addNumExposedTransposes (1 );
22262219 }
22272220
22282221 // / Lower load instructions.
2229- void VisitLoad (LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
2230- IRBuilder<> & Builder) {
2231- LowerLoad (Inst, Ptr, Inst->getAlign (), Builder. getInt64 (SI. getStride () ),
2232- Inst->isVolatile (), SI);
2222+ MatrixTy VisitLoad (LoadInst *Inst, const ShapeInfo &SI, Value *Ptr) {
2223+ IRBuilder<> Builder (Inst);
2224+ return LowerLoad (Inst, Ptr, Inst->getAlign (),
2225+ Builder. getInt64 (SI. getStride ()), Inst->isVolatile (), SI);
22332226 }
22342227
2235- void VisitStore (StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
2236- Value *Ptr, IRBuilder<> &Builder) {
2237- LowerStore (Inst, StoredVal, Ptr, Inst->getAlign (),
2238- Builder.getInt64 (SI.getStride ()), Inst->isVolatile (), SI);
2228+ MatrixTy VisitStore (StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
2229+ Value *Ptr) {
2230+ IRBuilder<> Builder (Inst);
2231+ return LowerStore (Inst, StoredVal, Ptr, Inst->getAlign (),
2232+ Builder.getInt64 (SI.getStride ()), Inst->isVolatile (), SI);
22392233 }
22402234
22412235 // / Lower binary operators.
2242- void VisitBinaryOperator (BinaryOperator *Inst, const ShapeInfo &SI) {
2236+ MatrixTy VisitBinaryOperator (BinaryOperator *Inst, const ShapeInfo &SI) {
22432237 Value *Lhs = Inst->getOperand (0 );
22442238 Value *Rhs = Inst->getOperand (1 );
22452239
@@ -2258,14 +2252,12 @@ class LowerMatrixIntrinsics {
22582252 Result.addVector (Builder.CreateBinOp (Inst->getOpcode (), A.getVector (I),
22592253 B.getVector (I)));
22602254
2261- finalizeLowering (Inst,
2262- Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2263- Result.getNumVectors ()),
2264- Builder);
2255+ return Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2256+ Result.getNumVectors ());
22652257 }
22662258
22672259 // / Lower unary operators.
2268- void VisitUnaryOperator (UnaryOperator *Inst, const ShapeInfo &SI) {
2260+ MatrixTy VisitUnaryOperator (UnaryOperator *Inst, const ShapeInfo &SI) {
22692261 Value *Op = Inst->getOperand (0 );
22702262
22712263 IRBuilder<> Builder (Inst);
@@ -2288,14 +2280,12 @@ class LowerMatrixIntrinsics {
22882280 for (unsigned I = 0 ; I < SI.getNumVectors (); ++I)
22892281 Result.addVector (BuildVectorOp (M.getVector (I)));
22902282
2291- finalizeLowering (Inst,
2292- Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2293- Result.getNumVectors ()),
2294- Builder);
2283+ return Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2284+ Result.getNumVectors ());
22952285 }
22962286
22972287 // / Lower cast instructions.
2298- void VisitCastInstruction (CastInst *Inst, const ShapeInfo &Shape) {
2288+ MatrixTy VisitCastInstruction (CastInst *Inst, const ShapeInfo &Shape) {
22992289 Value *Op = Inst->getOperand (0 );
23002290
23012291 IRBuilder<> Builder (Inst);
@@ -2312,10 +2302,8 @@ class LowerMatrixIntrinsics {
23122302 for (auto &Vector : M.vectors ())
23132303 Result.addVector (Builder.CreateCast (Inst->getOpcode (), Vector, NewVTy));
23142304
2315- finalizeLowering (Inst,
2316- Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2317- Result.getNumVectors ()),
2318- Builder);
2305+ return Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2306+ Result.getNumVectors ());
23192307 }
23202308
23212309 // / Helper to linearize a matrix expression tree into a string. Currently
0 commit comments