@@ -288,6 +288,7 @@ static bool isUniformShape(Value *V) {
288288 }
289289
290290 switch (I->getOpcode ()) {
291+ case Instruction::PHI:
291292 case Instruction::FNeg:
292293 return true ;
293294 default :
@@ -1136,7 +1137,27 @@ class LowerMatrixIntrinsics {
11361137
11371138 Changed |= !FusedInsts.empty ();
11381139
1139- // Fourth, lower remaining instructions with shape information.
1140+ // Fourth, pre-process all the PHINode's. The incoming values will be
1141+ // assigned later in VisitPHI.
1142+ for (Instruction *Inst : MatrixInsts) {
1143+ auto *PHI = dyn_cast<PHINode>(Inst);
1144+ if (!PHI)
1145+ continue ;
1146+
1147+ const ShapeInfo &SI = ShapeMap.at (Inst);
1148+ auto *EltTy = cast<FixedVectorType>(PHI->getType ())->getElementType ();
1149+ MatrixTy PhiM (SI.NumRows , SI.NumColumns , EltTy);
1150+
1151+ IRBuilder<> Builder (Inst);
1152+ for (unsigned VI = 0 , VE = PhiM.getNumVectors (); VI != VE; ++VI)
1153+ PhiM.setVector (VI, Builder.CreatePHI (PhiM.getVectorTy (),
1154+ PHI->getNumIncomingValues (),
1155+ PHI->getName ()));
1156+ assert (!Inst2ColumnMatrix.contains (PHI) && " map already contains phi?" );
1157+ Inst2ColumnMatrix[PHI] = PhiM;
1158+ }
1159+
1160+ // Fifth, lower remaining instructions with shape information.
11401161 for (Instruction *Inst : MatrixInsts) {
11411162 if (FusedInsts.count (Inst))
11421163 continue ;
@@ -1161,6 +1182,8 @@ class LowerMatrixIntrinsics {
11611182 Result = VisitLoad (cast<LoadInst>(Inst), SI, Op1, Builder);
11621183 else if (match (Inst, m_Store (m_Value (Op1), m_Value (Op2))))
11631184 Result = VisitStore (cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
1185+ else if (auto *PHI = dyn_cast<PHINode>(Inst))
1186+ Result = VisitPHI (PHI, SI, Builder);
11641187 else
11651188 continue ;
11661189
@@ -1458,7 +1481,8 @@ class LowerMatrixIntrinsics {
14581481 IRBuilder<> &Builder) {
14591482 auto inserted = Inst2ColumnMatrix.insert (std::make_pair (Inst, Matrix));
14601483 (void )inserted;
1461- assert (inserted.second && " multiple matrix lowering mapping" );
1484+ assert ((inserted.second || isa<PHINode>(Inst)) &&
1485+ " multiple matrix lowering mapping" );
14621486
14631487 ToRemove.push_back (Inst);
14641488 Value *Flattened = nullptr ;
@@ -2239,6 +2263,35 @@ class LowerMatrixIntrinsics {
22392263 Builder);
22402264 }
22412265
2266+ MatrixTy VisitPHI (PHINode *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) {
2267+ auto BlockIP = Inst->getParent ()->getFirstInsertionPt ();
2268+ Builder.SetInsertPoint (BlockIP);
2269+ MatrixTy PhiM = getMatrix (Inst, SI, Builder);
2270+
2271+ for (auto [IncomingV, IncomingB] :
2272+ llvm::zip_equal (Inst->incoming_values (), Inst->blocks ())) {
2273+ // getMatrix() may insert some instructions to help with reshaping. The
2274+ // safest place for those is at the top of the block after the rest of the
2275+ // PHI's. Even better, if we can put it in the incoming block.
2276+ Builder.SetInsertPoint (BlockIP);
2277+ if (auto *IncomingInst = dyn_cast<Instruction>(IncomingV))
2278+ if (auto MaybeIP = IncomingInst->getInsertionPointAfterDef ())
2279+ Builder.SetInsertPoint (*MaybeIP);
2280+
2281+ MatrixTy OpM = getMatrix (IncomingV, SI, Builder);
2282+
2283+ for (unsigned VI = 0 , VE = PhiM.getNumVectors (); VI != VE; ++VI) {
2284+ PHINode *NewPHI = cast<PHINode>(PhiM.getVector (VI));
2285+ NewPHI->addIncoming (OpM.getVector (VI), IncomingB);
2286+ }
2287+ }
2288+
2289+ // finalizeLowering() may also insert instructions in some cases. The safe
2290+ // place for those is at the end of the initial block of PHIs.
2291+ Builder.SetInsertPoint (BlockIP);
2292+ return PhiM;
2293+ }
2294+
22422295 // / Lower binary operators.
22432296 MatrixTy VisitBinaryOperator (BinaryOperator *Inst, const ShapeInfo &SI,
22442297 IRBuilder<> &Builder) {
0 commit comments