@@ -1054,24 +1054,26 @@ class LowerMatrixIntrinsics {
10541054 if (FusedInsts.count (Inst))
10551055 continue ;
10561056
1057- IRBuilder<> Builder (Inst);
1058-
10591057 const ShapeInfo &SI = ShapeMap.at (Inst);
10601058
10611059 Value *Op1;
10621060 Value *Op2;
1061+ MatrixTy Result;
10631062 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1064- VisitBinaryOperator (BinOp, SI);
1063+ Result = VisitBinaryOperator (BinOp, SI);
10651064 else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1066- VisitUnaryOperator (UnOp, SI);
1067- else if (CallInst *CInst = dyn_cast<CallInst >(Inst))
1068- VisitCallInst (CInst );
1065+ Result = VisitUnaryOperator (UnOp, SI);
1066+ else if (auto *Intr = dyn_cast<IntrinsicInst >(Inst))
1067+ Result = VisitIntrinsicInst (Intr, SI );
10691068 else if (match (Inst, m_Load (m_Value (Op1))))
1070- VisitLoad (cast<LoadInst>(Inst), SI, Op1, Builder );
1069+ Result = VisitLoad (cast<LoadInst>(Inst), SI, Op1);
10711070 else if (match (Inst, m_Store (m_Value (Op1), m_Value (Op2))))
1072- VisitStore (cast<StoreInst>(Inst), SI, Op1, Op2, Builder );
1071+ Result = VisitStore (cast<StoreInst>(Inst), SI, Op1, Op2);
10731072 else
10741073 continue ;
1074+
1075+ IRBuilder<> Builder (Inst);
1076+ finalizeLowering (Inst, Result, Builder);
10751077 Changed = true ;
10761078 }
10771079
@@ -1111,27 +1113,24 @@ class LowerMatrixIntrinsics {
11111113 }
11121114
11131115 // / Replace intrinsic calls.
1114- void VisitCallInst (CallInst *Inst) {
1116+ MatrixTy VisitIntrinsicInst (IntrinsicInst *Inst, const ShapeInfo &SI ) {
11151117 assert (Inst->getCalledFunction () &&
11161118 Inst->getCalledFunction ()->isIntrinsic ());
11171119
11181120 switch (Inst->getCalledFunction ()->getIntrinsicID ()) {
11191121 case Intrinsic::matrix_multiply:
1120- LowerMultiply (Inst);
1121- break ;
1122+ return LowerMultiply (Inst);
11221123 case Intrinsic::matrix_transpose:
1123- LowerTranspose (Inst);
1124- break ;
1124+ return LowerTranspose (Inst);
11251125 case Intrinsic::matrix_column_major_load:
1126- LowerColumnMajorLoad (Inst);
1127- break ;
1126+ return LowerColumnMajorLoad (Inst);
11281127 case Intrinsic::matrix_column_major_store:
1129- LowerColumnMajorStore (Inst);
1130- break ;
1128+ return LowerColumnMajorStore (Inst);
11311129 default :
1132- llvm_unreachable (
1133- " only intrinsics supporting shape info should be seen here" );
1130+ break ;
11341131 }
1132+ llvm_unreachable (
1133+ " only intrinsics supporting shape info should be seen here" );
11351134 }
11361135
11371136 // / Compute the alignment for a column/row \p Idx with \p Stride between them.
@@ -1197,26 +1196,24 @@ class LowerMatrixIntrinsics {
11971196 }
11981197
11991198 // / Lower a load instruction with shape information.
1200- void LowerLoad (Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride ,
1201- bool IsVolatile, ShapeInfo Shape) {
1199+ MatrixTy LowerLoad (Instruction *Inst, Value *Ptr, MaybeAlign Align,
1200+ Value *Stride, bool IsVolatile, ShapeInfo Shape) {
12021201 IRBuilder<> Builder (Inst);
1203- finalizeLowering (Inst,
1204- loadMatrix (Inst->getType (), Ptr, Align, Stride, IsVolatile,
1205- Shape, Builder),
1206- Builder);
1202+ return loadMatrix (Inst->getType (), Ptr, Align, Stride, IsVolatile, Shape,
1203+ Builder);
12071204 }
12081205
12091206 // / Lowers llvm.matrix.column.major.load.
12101207 // /
12111208 // / The intrinsic loads a matrix from memory using a stride between columns.
1212- void LowerColumnMajorLoad (CallInst *Inst) {
1209+ MatrixTy LowerColumnMajorLoad (CallInst *Inst) {
12131210 assert (MatrixLayout == MatrixLayoutTy::ColumnMajor &&
12141211 " Intrinsic only supports column-major layout!" );
12151212 Value *Ptr = Inst->getArgOperand (0 );
12161213 Value *Stride = Inst->getArgOperand (1 );
1217- LowerLoad (Inst, Ptr, Inst->getParamAlign (0 ), Stride,
1218- cast<ConstantInt>(Inst->getArgOperand (2 ))->isOne (),
1219- {Inst->getArgOperand (3 ), Inst->getArgOperand (4 )});
1214+ return LowerLoad (Inst, Ptr, Inst->getParamAlign (0 ), Stride,
1215+ cast<ConstantInt>(Inst->getArgOperand (2 ))->isOne (),
1216+ {Inst->getArgOperand (3 ), Inst->getArgOperand (4 )});
12201217 }
12211218
12221219 // / Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
@@ -1259,28 +1256,27 @@ class LowerMatrixIntrinsics {
12591256 }
12601257
12611258 // / Lower a store instruction with shape information.
1262- void LowerStore (Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1263- Value *Stride, bool IsVolatile, ShapeInfo Shape) {
1259+ MatrixTy LowerStore (Instruction *Inst, Value *Matrix, Value *Ptr,
1260+ MaybeAlign A, Value *Stride, bool IsVolatile,
1261+ ShapeInfo Shape) {
12641262 IRBuilder<> Builder (Inst);
12651263 auto StoreVal = getMatrix (Matrix, Shape, Builder);
1266- finalizeLowering (Inst,
1267- storeMatrix (Matrix->getType (), StoreVal, Ptr, A, Stride,
1268- IsVolatile, Builder),
1269- Builder);
1264+ return storeMatrix (Matrix->getType (), StoreVal, Ptr, A, Stride, IsVolatile,
1265+ Builder);
12701266 }
12711267
12721268 // / Lowers llvm.matrix.column.major.store.
12731269 // /
12741270 // / The intrinsic store a matrix back memory using a stride between columns.
1275- void LowerColumnMajorStore (CallInst *Inst) {
1271+ MatrixTy LowerColumnMajorStore (CallInst *Inst) {
12761272 assert (MatrixLayout == MatrixLayoutTy::ColumnMajor &&
12771273 " Intrinsic only supports column-major layout!" );
12781274 Value *Matrix = Inst->getArgOperand (0 );
12791275 Value *Ptr = Inst->getArgOperand (1 );
12801276 Value *Stride = Inst->getArgOperand (2 );
1281- LowerStore (Inst, Matrix, Ptr, Inst->getParamAlign (1 ), Stride,
1282- cast<ConstantInt>(Inst->getArgOperand (3 ))->isOne (),
1283- {Inst->getArgOperand (4 ), Inst->getArgOperand (5 )});
1277+ return LowerStore (Inst, Matrix, Ptr, Inst->getParamAlign (1 ), Stride,
1278+ cast<ConstantInt>(Inst->getArgOperand (3 ))->isOne (),
1279+ {Inst->getArgOperand (4 ), Inst->getArgOperand (5 )});
12841280 }
12851281
12861282 // Set elements I..I+NumElts-1 to Block
@@ -2045,7 +2041,7 @@ class LowerMatrixIntrinsics {
20452041 }
20462042
20472043 // / Lowers llvm.matrix.multiply.
2048- void LowerMultiply (CallInst *MatMul) {
2044+ MatrixTy LowerMultiply (CallInst *MatMul) {
20492045 IRBuilder<> Builder (MatMul);
20502046 auto *EltType = cast<FixedVectorType>(MatMul->getType ())->getElementType ();
20512047 ShapeInfo LShape (MatMul->getArgOperand (2 ), MatMul->getArgOperand (3 ));
@@ -2067,11 +2063,11 @@ class LowerMatrixIntrinsics {
20672063
20682064 emitMatrixMultiply (Result, Lhs, Rhs, Builder, false , false ,
20692065 getFastMathFlags (MatMul));
2070- finalizeLowering (MatMul, Result, Builder) ;
2066+ return Result;
20712067 }
20722068
20732069 // / Lowers llvm.matrix.transpose.
2074- void LowerTranspose (CallInst *Inst) {
2070+ MatrixTy LowerTranspose (CallInst *Inst) {
20752071 MatrixTy Result;
20762072 IRBuilder<> Builder (Inst);
20772073 Value *InputVal = Inst->getArgOperand (0 );
@@ -2101,28 +2097,26 @@ class LowerMatrixIntrinsics {
21012097 // TODO: Improve estimate of operations needed for transposes. Currently we
21022098 // just count the insertelement/extractelement instructions, but do not
21032099 // account for later simplifications/combines.
2104- finalizeLowering (
2105- Inst,
2106- Result.addNumComputeOps (2 * ArgShape.NumRows * ArgShape.NumColumns )
2107- .addNumExposedTransposes (1 ),
2108- Builder);
2100+ return Result.addNumComputeOps (2 * ArgShape.NumRows * ArgShape.NumColumns )
2101+ .addNumExposedTransposes (1 );
21092102 }
21102103
21112104 // / Lower load instructions.
2112- void VisitLoad (LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
2113- IRBuilder<> & Builder) {
2114- LowerLoad (Inst, Ptr, Inst->getAlign (), Builder. getInt64 (SI. getStride () ),
2115- Inst->isVolatile (), SI);
2105+ MatrixTy VisitLoad (LoadInst *Inst, const ShapeInfo &SI, Value *Ptr) {
2106+ IRBuilder<> Builder (Inst);
2107+ return LowerLoad (Inst, Ptr, Inst->getAlign (),
2108+ Builder. getInt64 (SI. getStride ()), Inst->isVolatile (), SI);
21162109 }
21172110
2118- void VisitStore (StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
2119- Value *Ptr, IRBuilder<> &Builder) {
2120- LowerStore (Inst, StoredVal, Ptr, Inst->getAlign (),
2121- Builder.getInt64 (SI.getStride ()), Inst->isVolatile (), SI);
2111+ MatrixTy VisitStore (StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
2112+ Value *Ptr) {
2113+ IRBuilder<> Builder (Inst);
2114+ return LowerStore (Inst, StoredVal, Ptr, Inst->getAlign (),
2115+ Builder.getInt64 (SI.getStride ()), Inst->isVolatile (), SI);
21222116 }
21232117
21242118 // / Lower binary operators.
2125- void VisitBinaryOperator (BinaryOperator *Inst, const ShapeInfo &SI) {
2119+ MatrixTy VisitBinaryOperator (BinaryOperator *Inst, const ShapeInfo &SI) {
21262120 Value *Lhs = Inst->getOperand (0 );
21272121 Value *Rhs = Inst->getOperand (1 );
21282122
@@ -2141,14 +2135,12 @@ class LowerMatrixIntrinsics {
21412135 Result.addVector (Builder.CreateBinOp (Inst->getOpcode (), A.getVector (I),
21422136 B.getVector (I)));
21432137
2144- finalizeLowering (Inst,
2145- Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2146- Result.getNumVectors ()),
2147- Builder);
2138+ return Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2139+ Result.getNumVectors ());
21482140 }
21492141
21502142 // / Lower unary operators.
2151- void VisitUnaryOperator (UnaryOperator *Inst, const ShapeInfo &SI) {
2143+ MatrixTy VisitUnaryOperator (UnaryOperator *Inst, const ShapeInfo &SI) {
21522144 Value *Op = Inst->getOperand (0 );
21532145
21542146 IRBuilder<> Builder (Inst);
@@ -2171,10 +2163,8 @@ class LowerMatrixIntrinsics {
21712163 for (unsigned I = 0 ; I < SI.getNumVectors (); ++I)
21722164 Result.addVector (BuildVectorOp (M.getVector (I)));
21732165
2174- finalizeLowering (Inst,
2175- Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2176- Result.getNumVectors ()),
2177- Builder);
2166+ return Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2167+ Result.getNumVectors ());
21782168 }
21792169
21802170 // / Helper to linearize a matrix expression tree into a string. Currently
0 commit comments