Skip to content

Commit 138e0ff

Browse files
authored
[Matrix] (NFC) Refactor sharing of shape information (llvm#164774)
1 parent 2527b07 commit 138e0ff

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,14 @@ raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) {
245245

246246
} // namespace
247247

248-
static bool isUniformShape(Value *V) {
248+
static bool isShapePreserving(Value *V) {
249249
Instruction *I = dyn_cast<Instruction>(V);
250250
if (!I)
251251
return true;
252252

253+
if (isa<SelectInst>(I))
254+
return true;
255+
253256
if (I->isBinaryOp())
254257
return true;
255258

@@ -300,6 +303,16 @@ static bool isUniformShape(Value *V) {
300303
}
301304
}
302305

306+
/// Return an iterator over the operands of \p I that should share shape
307+
/// information with \p I.
308+
static iterator_range<Use *> getShapedOperandsForInst(Instruction *I) {
309+
assert(isShapePreserving(I) &&
310+
"Can't retrieve shaped operands for an instruction that does not "
311+
"preserve shape information");
312+
auto Ops = I->operands();
313+
return isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
314+
}
315+
303316
/// Return the ShapeInfo for the result of \p I, it it can be determined.
304317
static std::optional<ShapeInfo>
305318
computeShapeInfoForInst(Instruction *I,
@@ -329,9 +342,8 @@ computeShapeInfoForInst(Instruction *I,
329342
return OpShape->second;
330343
}
331344

332-
if (isUniformShape(I) || isa<SelectInst>(I)) {
333-
auto Ops = I->operands();
334-
auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
345+
if (isShapePreserving(I)) {
346+
auto ShapedOps = getShapedOperandsForInst(I);
335347
// Find the first operand that has a known shape and use that.
336348
for (auto &Op : ShapedOps) {
337349
auto OpShape = ShapeMap.find(Op.get());
@@ -710,10 +722,9 @@ class LowerMatrixIntrinsics {
710722
case Intrinsic::matrix_column_major_store:
711723
return true;
712724
default:
713-
return isUniformShape(II);
725+
break;
714726
}
715-
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
716-
isa<SelectInst>(V);
727+
return isShapePreserving(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
717728
}
718729

719730
/// Propagate the shape information of instructions to their users.
@@ -800,9 +811,8 @@ class LowerMatrixIntrinsics {
800811
} else if (isa<StoreInst>(V)) {
801812
// Nothing to do. We forward-propagated to this so we would just
802813
// backward propagate to an instruction with an already known shape.
803-
} else if (isUniformShape(V) || isa<SelectInst>(V)) {
804-
auto Ops = cast<Instruction>(V)->operands();
805-
auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops;
814+
} else if (isShapePreserving(V)) {
815+
auto ShapedOps = getShapedOperandsForInst(cast<Instruction>(V));
806816
// Propagate to all operands.
807817
ShapeInfo Shape = ShapeMap[V];
808818
for (Use &U : ShapedOps) {

0 commit comments

Comments
 (0)