Skip to content

Commit 1a4faa4

Browse files
committed
review feedback: don't take shape info from select conditions
1 parent 6693f56 commit 1a4faa4

File tree

2 files changed

+11
-25
lines changed

2 files changed

+11
-25
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -270,19 +270,11 @@ computeShapeInfoForInst(Instruction *I,
270270
return OpShape->second;
271271
}
272272

273-
if (auto *Select = dyn_cast<SelectInst>(I)) {
274-
Type *CondTy = Select->getCondition()->getType();
275-
for (Use &Op : CondTy->isVectorTy() ? Select->operands()
276-
: drop_begin(Select->operands())) {
277-
auto OpShape = ShapeMap.find(Op);
278-
if (OpShape != ShapeMap.end())
279-
return OpShape->second;
280-
}
281-
}
282-
283-
if (isUniformShape(I)) {
273+
if (isUniformShape(I) || isa<SelectInst>(I)) {
274+
auto Ops = I->operands();
275+
auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
284276
// Find the first operand that has a known shape and use that.
285-
for (auto &Op : I->operands()) {
277+
for (auto &Op : ShapedOps) {
286278
auto OpShape = ShapeMap.find(Op.get());
287279
if (OpShape != ShapeMap.end())
288280
return OpShape->second;
@@ -719,18 +711,12 @@ class LowerMatrixIntrinsics {
719711
} else if (isa<StoreInst>(V)) {
720712
// Nothing to do. We forward-propagated to this so we would just
721713
// backward propagate to an instruction with an already known shape.
722-
} else if (auto *Select = dyn_cast<SelectInst>(V)) {
723-
ShapeInfo Shape = ShapeMap[V];
724-
Type *CondTy = Select->getCondition()->getType();
725-
for (Use &Op : CondTy->isVectorTy() ? Select->operands()
726-
: drop_begin(Select->operands())) {
727-
if (setShapeInfo(Op, Shape))
728-
pushInstruction(Select, WorkList);
729-
}
730-
} else if (isUniformShape(V)) {
714+
} else if (isUniformShape(V) || isa<SelectInst>(V)) {
715+
auto Ops = cast<Instruction>(V)->operands();
716+
auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops;
731717
// Propagate to all operands.
732718
ShapeInfo Shape = ShapeMap[V];
733-
for (Use &U : cast<Instruction>(V)->operands()) {
719+
for (Use &U : ShapedOps) {
734720
if (setShapeInfo(U.get(), Shape))
735721
pushInstruction(U.get(), WorkList);
736722
}

llvm/test/Transforms/LowerMatrixIntrinsics/select.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,12 @@ define void @select_2x2_vcond_shape1(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
7272
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
7373
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
7474
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
75-
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[RHS:%.*]], align 1
76-
; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[RHS]], i64 2
77-
; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x i1>, ptr [[VEC_GEP3]], align 1
75+
; CHECK-NEXT: [[CONDV:%.*]] = load <4 x i1>, ptr [[COND:%.*]], align 1
7876
; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS1:%.*]], align 4
7977
; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr float, ptr [[RHS1]], i64 2
8078
; CHECK-NEXT: [[COL_LOAD7:%.*]] = load <2 x float>, ptr [[VEC_GEP6]], align 4
79+
; CHECK-NEXT: [[COL_LOAD2:%.*]] = shufflevector <4 x i1> [[CONDV]], <4 x i1> poison, <2 x i32> <i32 0, i32 1>
80+
; CHECK-NEXT: [[COL_LOAD4:%.*]] = shufflevector <4 x i1> [[CONDV]], <4 x i1> poison, <2 x i32> <i32 2, i32 3>
8181
; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[COL_LOAD2]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD5]]
8282
; CHECK-NEXT: [[TMP2:%.*]] = select <2 x i1> [[COL_LOAD4]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD7]]
8383
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16

0 commit comments

Comments
 (0)