@@ -245,11 +245,14 @@ raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) {
245245
246246} // namespace
247247
248- static bool isUniformShape (Value *V) {
248+ static bool isShapePreserving (Value *V) {
249249 Instruction *I = dyn_cast<Instruction>(V);
250250 if (!I)
251251 return true ;
252252
253+ if (isa<SelectInst>(I))
254+ return true ;
255+
253256 if (I->isBinaryOp ())
254257 return true ;
255258
@@ -300,6 +303,16 @@ static bool isUniformShape(Value *V) {
300303 }
301304}
302305
306+ // / Return an iterator over the operands of \p I that should share shape
307+ // / information with \p I.
308+ static iterator_range<Use *> getShapedOperandsForInst (Instruction *I) {
309+ assert (isShapePreserving (I) &&
310+ " Can't retrieve shaped operands for an instruction that does not "
311+ " preserve shape information" );
312+ auto Ops = I->operands ();
313+ return isa<SelectInst>(I) ? drop_begin (Ops) : Ops;
314+ }
315+
303316// / Return the ShapeInfo for the result of \p I, it it can be determined.
304317static std::optional<ShapeInfo>
305318computeShapeInfoForInst (Instruction *I,
@@ -329,9 +342,8 @@ computeShapeInfoForInst(Instruction *I,
329342 return OpShape->second ;
330343 }
331344
332- if (isUniformShape (I) || isa<SelectInst>(I)) {
333- auto Ops = I->operands ();
334- auto ShapedOps = isa<SelectInst>(I) ? drop_begin (Ops) : Ops;
345+ if (isShapePreserving (I)) {
346+ auto ShapedOps = getShapedOperandsForInst (I);
335347 // Find the first operand that has a known shape and use that.
336348 for (auto &Op : ShapedOps) {
337349 auto OpShape = ShapeMap.find (Op.get ());
@@ -710,10 +722,9 @@ class LowerMatrixIntrinsics {
710722 case Intrinsic::matrix_column_major_store:
711723 return true ;
712724 default :
713- return isUniformShape (II) ;
725+ break ;
714726 }
715- return isUniformShape (V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
716- isa<SelectInst>(V);
727+ return isShapePreserving (V) || isa<StoreInst>(V) || isa<LoadInst>(V);
717728 }
718729
719730 // / Propagate the shape information of instructions to their users.
@@ -800,9 +811,8 @@ class LowerMatrixIntrinsics {
800811 } else if (isa<StoreInst>(V)) {
801812 // Nothing to do. We forward-propagated to this so we would just
802813 // backward propagate to an instruction with an already known shape.
803- } else if (isUniformShape (V) || isa<SelectInst>(V)) {
804- auto Ops = cast<Instruction>(V)->operands ();
805- auto ShapedOps = isa<SelectInst>(V) ? drop_begin (Ops) : Ops;
814+ } else if (isShapePreserving (V)) {
815+ auto ShapedOps = getShapedOperandsForInst (cast<Instruction>(V));
806816 // Propagate to all operands.
807817 ShapeInfo Shape = ShapeMap[V];
808818 for (Use &U : ShapedOps) {
0 commit comments