Skip to content

Commit c0c63f3

Browse files
committed
select with mismatched shape
1 parent e952bcb commit c0c63f3

File tree

2 files changed

+56
-24
lines changed

2 files changed

+56
-24
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,9 @@ computeShapeInfoForInst(Instruction *I,
271271
}
272272

273273
if (auto *Select = dyn_cast<SelectInst>(I)) {
274-
for (Use &Op : Select->getCondition()->getType()->isVectorTy() ? I->operands() : drop_begin(I->operands())) {
274+
Type *CondTy = Select->getCondition()->getType();
275+
for (Use &Op : CondTy->isVectorTy() ? Select->operands()
276+
: drop_begin(Select->operands())) {
275277
auto OpShape = ShapeMap.find(Op);
276278
if (OpShape != ShapeMap.end())
277279
return OpShape->second;
@@ -719,10 +721,12 @@ class LowerMatrixIntrinsics {
719721
// backward propagate to an instruction with an already known shape.
720722
} else if (auto *Select = dyn_cast<SelectInst>(V)) {
721723
ShapeInfo Shape = ShapeMap[V];
722-
if (setShapeInfo(Select->getOperand(1), Shape))
723-
pushInstruction(Select, WorkList);
724-
if (setShapeInfo(Select->getOperand(2), Shape))
725-
pushInstruction(Select, WorkList);
724+
Type *CondTy = Select->getCondition()->getType();
725+
for (Use &Op : CondTy->isVectorTy() ? Select->operands()
726+
: drop_begin(Select->operands())) {
727+
if (setShapeInfo(Op, Shape))
728+
pushInstruction(Select, WorkList);
729+
}
726730
} else if (isUniformShape(V)) {
727731
// Propagate to all operands.
728732
ShapeInfo Shape = ShapeMap[V];

llvm/test/Transforms/LowerMatrixIntrinsics/select.ll

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -67,39 +67,41 @@ define void @select_2x2_rhs(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
6767
ret void
6868
}
6969

70-
define void @select_2x2_vcond(<4 x i1> %cond, ptr %lhs, ptr %rhs, ptr %out) {
71-
; CHECK-LABEL: @select_2x2_vcond(
70+
define void @select_2x2_vcond_shape1(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
71+
; CHECK-LABEL: @select_2x2_vcond_shape1(
7272
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
7373
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
7474
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
75-
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 4
76-
; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
77-
; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 4
78-
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i1> [[COND:%.*]], <4 x i1> poison, <2 x i32> <i32 0, i32 1>
79-
; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <4 x i1> [[COND]], <4 x i1> poison, <2 x i32> <i32 2, i32 3>
80-
; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[SPLIT]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
81-
; CHECK-NEXT: [[TMP2:%.*]] = select <2 x i1> [[SPLIT5]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
75+
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[RHS:%.*]], align 1
76+
; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[RHS]], i64 2
77+
; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x i1>, ptr [[VEC_GEP3]], align 1
78+
; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS1:%.*]], align 4
79+
; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr float, ptr [[RHS1]], i64 2
80+
; CHECK-NEXT: [[COL_LOAD7:%.*]] = load <2 x float>, ptr [[VEC_GEP6]], align 4
81+
; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[COL_LOAD2]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD5]]
82+
; CHECK-NEXT: [[TMP2:%.*]] = select <2 x i1> [[COL_LOAD4]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD7]]
8283
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
83-
; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr float, ptr [[OUT]], i64 2
84-
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP6]], align 8
84+
; CHECK-NEXT: [[VEC_GEP8:%.*]] = getelementptr float, ptr [[OUT]], i64 2
85+
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP8]], align 8
8586
; CHECK-NEXT: ret void
8687
;
8788
%lhsv = load <4 x float>, ptr %lhs
89+
%condv = load <4 x i1>, ptr %cond
8890
%rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
89-
%op = select <4 x i1> %cond, <4 x float> %lhsv, <4 x float> %rhsv
91+
%op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
9092
store <4 x float> %op, ptr %out
9193
ret void
9294
}
9395

94-
define void @select_2x2_vcond_shape(ptr %lhs, ptr %rhs, ptr %out) {
95-
; CHECK-LABEL: @select_2x2_vcond_shape(
96+
define void @select_2x2_vcond_shape2(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
97+
; CHECK-LABEL: @select_2x2_vcond_shape2(
9698
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
9799
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
98100
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
99-
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[RHS:%.*]], align 1
100-
; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[RHS]], i64 2
101+
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[COND:%.*]], align 1
102+
; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[COND]], i64 2
101103
; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x i1>, ptr [[VEC_GEP3]], align 1
102-
; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS]], align 4
104+
; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 4
103105
; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr float, ptr [[RHS]], i64 2
104106
; CHECK-NEXT: [[COL_LOAD7:%.*]] = load <2 x float>, ptr [[VEC_GEP6]], align 4
105107
; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[COL_LOAD2]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD5]]
@@ -110,9 +112,35 @@ define void @select_2x2_vcond_shape(ptr %lhs, ptr %rhs, ptr %out) {
110112
; CHECK-NEXT: ret void
111113
;
112114
%lhsv = load <4 x float>, ptr %lhs
113-
%cond = call <4 x i1> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
115+
%condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 2, i1 false, i32 2, i32 2)
116+
%rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
117+
%op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
118+
store <4 x float> %op, ptr %out
119+
ret void
120+
}
121+
122+
define void @select_2x2_vcond_shape3(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
123+
; CHECK-LABEL: @select_2x2_vcond_shape3(
124+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
125+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
126+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
127+
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <4 x i1>, ptr [[COND:%.*]], align 1
128+
; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 4
129+
; CHECK-NEXT: [[VEC_GEP4:%.*]] = getelementptr float, ptr [[RHS]], i64 2
130+
; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[VEC_GEP4]], align 4
131+
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i1> [[COL_LOAD2]], <4 x i1> poison, <2 x i32> <i32 0, i32 1>
132+
; CHECK-NEXT: [[SPLIT6:%.*]] = shufflevector <4 x i1> [[COL_LOAD2]], <4 x i1> poison, <2 x i32> <i32 2, i32 3>
133+
; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[SPLIT]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD3]]
134+
; CHECK-NEXT: [[TMP2:%.*]] = select <2 x i1> [[SPLIT6]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD5]]
135+
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
136+
; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr float, ptr [[OUT]], i64 2
137+
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP7]], align 8
138+
; CHECK-NEXT: ret void
139+
;
140+
%lhsv = load <4 x float>, ptr %lhs
141+
%condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 4, i1 false, i32 4, i32 1)
114142
%rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
115-
%op = select <4 x i1> %cond, <4 x float> %lhsv, <4 x float> %rhsv
143+
%op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
116144
store <4 x float> %op, ptr %out
117145
ret void
118146
}

0 commit comments

Comments
 (0)