@@ -270,19 +270,11 @@ computeShapeInfoForInst(Instruction *I,
270270 return OpShape->second ;
271271 }
272272
273- if (auto *Select = dyn_cast<SelectInst>(I)) {
274- Type *CondTy = Select->getCondition ()->getType ();
275- for (Use &Op : CondTy->isVectorTy () ? Select->operands ()
276- : drop_begin (Select->operands ())) {
277- auto OpShape = ShapeMap.find (Op);
278- if (OpShape != ShapeMap.end ())
279- return OpShape->second ;
280- }
281- }
282-
283- if (isUniformShape (I)) {
273+ if (isUniformShape (I) || isa<SelectInst>(I)) {
274+ auto Ops = I->operands ();
275+ auto ShapedOps = isa<SelectInst>(I) ? drop_begin (Ops) : Ops;
284276 // Find the first operand that has a known shape and use that.
285- for (auto &Op : I-> operands () ) {
277+ for (auto &Op : ShapedOps ) {
286278 auto OpShape = ShapeMap.find (Op.get ());
287279 if (OpShape != ShapeMap.end ())
288280 return OpShape->second ;
@@ -719,18 +711,12 @@ class LowerMatrixIntrinsics {
719711 } else if (isa<StoreInst>(V)) {
720712 // Nothing to do. We forward-propagated to this so we would just
721713 // backward propagate to an instruction with an already known shape.
722- } else if (auto *Select = dyn_cast<SelectInst>(V)) {
723- ShapeInfo Shape = ShapeMap[V];
724- Type *CondTy = Select->getCondition ()->getType ();
725- for (Use &Op : CondTy->isVectorTy () ? Select->operands ()
726- : drop_begin (Select->operands ())) {
727- if (setShapeInfo (Op, Shape))
728- pushInstruction (Select, WorkList);
729- }
730- } else if (isUniformShape (V)) {
714+ } else if (isUniformShape (V) || isa<SelectInst>(V)) {
715+ auto Ops = cast<Instruction>(V)->operands ();
716+ auto ShapedOps = isa<SelectInst>(V) ? drop_begin (Ops) : Ops;
731717 // Propagate to all operands.
732718 ShapeInfo Shape = ShapeMap[V];
733- for (Use &U : cast<Instruction>(V)-> operands () ) {
719+ for (Use &U : ShapedOps ) {
734720 if (setShapeInfo (U.get (), Shape))
735721 pushInstruction (U.get (), WorkList);
736722 }
0 commit comments