@@ -232,6 +232,15 @@ static bool isUniformShape(Value *V) {
232232 if (I->isBinaryOp ())
233233 return true ;
234234
235+ if (auto *II = dyn_cast<IntrinsicInst>(V))
236+ switch (II->getIntrinsicID ()) {
237+ case Intrinsic::abs:
238+ case Intrinsic::fabs:
239+ return true ;
240+ default :
241+ return false ;
242+ }
243+
235244 switch (I->getOpcode ()) {
236245 case Instruction::FNeg:
237246 return true ;
@@ -618,7 +627,7 @@ class LowerMatrixIntrinsics {
618627 case Intrinsic::matrix_column_major_store:
619628 return true ;
620629 default :
621- return false ;
630+ return isUniformShape (II) ;
622631 }
623632 return isUniformShape (V) || isa<StoreInst>(V) || isa<LoadInst>(V);
624633 }
@@ -1064,8 +1073,8 @@ class LowerMatrixIntrinsics {
10641073 VisitBinaryOperator (BinOp, SI);
10651074 else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
10661075 VisitUnaryOperator (UnOp, SI);
1067- else if (CallInst *CInst = dyn_cast<CallInst >(Inst))
1068- VisitCallInst (CInst );
1076+ else if (IntrinsicInst *Intr = dyn_cast<IntrinsicInst >(Inst))
1077+ VisitIntrinsicInst (Intr, SI );
10691078 else if (match (Inst, m_Load (m_Value (Op1))))
10701079 VisitLoad (cast<LoadInst>(Inst), SI, Op1, Builder);
10711080 else if (match (Inst, m_Store (m_Value (Op1), m_Value (Op2))))
@@ -1111,23 +1120,48 @@ class LowerMatrixIntrinsics {
11111120 }
11121121
11131122 /// Replace intrinsic calls.
1114- void VisitCallInst (CallInst *Inst) {
1115- assert (Inst->getCalledFunction () &&
1116- Inst->getCalledFunction ()->isIntrinsic ());
1117-
1118- switch (Inst->getCalledFunction ()->getIntrinsicID ()) {
1123+ void VisitIntrinsicInst (IntrinsicInst *Inst, const ShapeInfo &Shape) {
1124+ switch (Inst->getIntrinsicID ()) {
11191125 case Intrinsic::matrix_multiply:
11201126 LowerMultiply (Inst);
1121- break ;
1127+ return ;
11221128 case Intrinsic::matrix_transpose:
11231129 LowerTranspose (Inst);
1124- break ;
1130+ return ;
11251131 case Intrinsic::matrix_column_major_load:
11261132 LowerColumnMajorLoad (Inst);
1127- break ;
1133+ return ;
11281134 case Intrinsic::matrix_column_major_store:
11291135 LowerColumnMajorStore (Inst);
1130- break ;
1136+ return ;
1137+ case Intrinsic::abs:
1138+ case Intrinsic::fabs: {
1139+ IRBuilder<> Builder (Inst);
1140+ MatrixTy Result;
1141+ MatrixTy M = getMatrix (Inst->getOperand (0 ), Shape, Builder);
1142+ Builder.setFastMathFlags (getFastMathFlags (Inst));
1143+
1144+ for (auto &Vector : M.vectors ()) {
1145+ switch (Inst->getIntrinsicID ()) {
1146+ case Intrinsic::abs:
1147+ Result.addVector (Builder.CreateBinaryIntrinsic (Intrinsic::abs, Vector,
1148+ Inst->getOperand (1 )));
1149+ continue ;
1150+ case Intrinsic::fabs:
1151+ Result.addVector (
1152+ Builder.CreateUnaryIntrinsic (Inst->getIntrinsicID (), Vector));
1153+ continue ;
1154+ default :
1155+ llvm_unreachable (" unexpected intrinsic" );
1156+ }
1157+ }
1158+
1159+ finalizeLowering (Inst,
1160+ Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
1161+ Result.getNumVectors ()),
1162+ Builder);
1163+ return ;
1164+ }
11311165 default :
11321166 llvm_unreachable (
11331167 " only intrinsics supporting shape info should be seen here" );
0 commit comments