@@ -1073,8 +1073,8 @@ class LowerMatrixIntrinsics {
10731073 VisitBinaryOperator (BinOp, SI);
10741074 else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
10751075 VisitUnaryOperator (UnOp, SI);
1076- else if (CallInst *CInst = dyn_cast<CallInst >(Inst))
1077- VisitCallInst (CInst );
1076+ else if (IntrinsicInst *Intr = dyn_cast<IntrinsicInst >(Inst))
1077+ VisitIntrinsicInst (Intr, SI );
10781078 else if (match (Inst, m_Load (m_Value (Op1))))
10791079 VisitLoad (cast<LoadInst>(Inst), SI, Op1, Builder);
10801080 else if (match (Inst, m_Store (m_Value (Op1), m_Value (Op2))))
@@ -1120,11 +1120,8 @@ class LowerMatrixIntrinsics {
11201120 }
11211121
11221122 // / Replace intrinsic calls.
1123- void VisitCallInst (CallInst *Inst) {
1124- assert (Inst->getCalledFunction () &&
1125- Inst->getCalledFunction ()->isIntrinsic ());
1126-
1127- switch (Inst->getCalledFunction ()->getIntrinsicID ()) {
1123+ void VisitIntrinsicInst (IntrinsicInst *Inst, const ShapeInfo &Shape) {
1124+ switch (Inst->getIntrinsicID ()) {
11281125 case Intrinsic::matrix_multiply:
11291126 LowerMultiply (Inst);
11301127 break ;
@@ -1138,8 +1135,36 @@ class LowerMatrixIntrinsics {
11381135 LowerColumnMajorStore (Inst);
11391136 break ;
11401137 case Intrinsic::abs:
1141- case Intrinsic::fabs:
1142- return VisitUniformIntrinsic (cast<IntrinsicInst>(Inst));
1138+ case Intrinsic::fabs: {
1139+ IRBuilder<> Builder (Inst);
1140+
1141+ MatrixTy Result;
1142+
1143+ MatrixTy M = getMatrix (Inst->getOperand (0 ), Shape, Builder);
1144+
1145+ Builder.setFastMathFlags (getFastMathFlags (Inst));
1146+
1147+ for (auto &Vector : M.vectors ()) {
1148+ switch (Inst->getIntrinsicID ()) {
1149+ case Intrinsic::abs:
1150+ Result.addVector (Builder.CreateBinaryIntrinsic (Intrinsic::abs, Vector,
1151+ Inst->getOperand (1 )));
1152+ break ;
1153+ case Intrinsic::fabs:
1154+ Result.addVector (
1155+ Builder.CreateUnaryIntrinsic (Inst->getIntrinsicID (), Vector));
1156+ break ;
1157+ default :
1158+ llvm_unreachable (" unexpected intrinsic" );
1159+ }
1160+ }
1161+
1162+ finalizeLowering (Inst,
1163+ Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
1164+ Result.getNumVectors ()),
1165+ Builder);
1166+ return ;
1167+ }
11431168 default :
11441169 llvm_unreachable (
11451170 " only intrinsics supporting shape info should be seen here" );
@@ -2189,49 +2214,6 @@ class LowerMatrixIntrinsics {
21892214 Builder);
21902215 }
21912216
2192- // / Lower uniform shape intrinsics, if shape information is available.
2193- bool VisitUniformIntrinsic (IntrinsicInst *Inst) {
2194- auto I = ShapeMap.find (Inst);
2195- assert (I != ShapeMap.end () &&
2196- " must only visit instructions with shape info" );
2197-
2198- IRBuilder<> Builder (Inst);
2199- ShapeInfo &Shape = I->second ;
2200-
2201- MatrixTy Result;
2202-
2203- switch (Inst->getIntrinsicID ()) {
2204- case Intrinsic::abs:
2205- case Intrinsic::fabs: {
2206- MatrixTy M = getMatrix (Inst->getOperand (0 ), Shape, Builder);
2207-
2208- Builder.setFastMathFlags (getFastMathFlags (Inst));
2209-
2210- for (auto &Vector : M.vectors ())
2211- switch (Inst->getIntrinsicID ()) {
2212- case Intrinsic::abs:
2213- Result.addVector (Builder.CreateBinaryIntrinsic (Intrinsic::abs, Vector,
2214- Inst->getOperand (1 )));
2215- break ;
2216- case Intrinsic::fabs:
2217- Result.addVector (
2218- Builder.CreateUnaryIntrinsic (Inst->getIntrinsicID (), Vector));
2219- break ;
2220- default :
2221- llvm_unreachable (" unexpected intrinsic" );
2222- }
2223-
2224- finalizeLowering (Inst,
2225- Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2226- Result.getNumVectors ()),
2227- Builder);
2228- return true ;
2229- }
2230- default :
2231- llvm_unreachable (" unexpected intrinsic" );
2232- }
2233- }
2234-
22352217 // / Helper to linearize a matrix expression tree into a string. Currently
22362218 // / matrix expressions are linarized by starting at an expression leaf and
22372219 // / linearizing bottom up.
0 commit comments