Skip to content

Commit b6064bf

Browse files
Groverkss authored and
giacs-epic committed
[VectorDistribution] Allow 0-d vectors in scf.for distribution (iree-org#19317)
Signed-off-by: Giacomo Serafini <[email protected]>
1 parent b1da7a8 commit b6064bf

File tree

2 files changed

+40
-16
lines changed

2 files changed

+40
-16
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp

Lines changed: 6 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -145,10 +145,8 @@ struct DistributeScfFor final : OpDistributionPattern<scf::ForOp> {
145145
SmallVector<Value> newInitArgs;
146146
for (Value initArg : forOp.getInitArgs()) {
147147
if (auto vectorInitArg = dyn_cast<VectorValue>(initArg)) {
148-
if (isNonZeroRank(vectorInitArg)) {
149-
initArg =
150-
getDistributed(rewriter, vectorInitArg, signature[vectorInitArg]);
151-
}
148+
initArg =
149+
getDistributed(rewriter, vectorInitArg, signature[vectorInitArg]);
152150
}
153151
newInitArgs.push_back(initArg);
154152
}
@@ -193,14 +191,8 @@ struct DistributeScfFor final : OpDistributionPattern<scf::ForOp> {
193191
SmallVector<Value> operands;
194192
for (Value operand : yieldOp->getOperands()) {
195193
if (auto vectorOperand = dyn_cast<VectorValue>(operand)) {
196-
// Distributing the operand requires it to have a non-zero rank, meaning
197-
// it must have at least one dimension. If the vector has a non-zero
198-
// rank, the operand is distributed according to the provided layout
199-
// signature.
200-
if (isNonZeroRank(vectorOperand)) {
201-
operand = DistributionPattern::getDistributed(
202-
rewriter, vectorOperand, signature[vectorOperand]);
203-
}
194+
operand = DistributionPattern::getDistributed(rewriter, vectorOperand,
195+
signature[vectorOperand]);
204196
}
205197
operands.push_back(operand);
206198
}
@@ -223,10 +215,8 @@ struct DistributeScfFor final : OpDistributionPattern<scf::ForOp> {
223215
for (auto [bbArg, oldInit] : llvm::zip_equal(bbArgs, oldInits)) {
224216
Value val = bbArg;
225217
if (auto oldVectorInit = dyn_cast<VectorValue>(oldInit)) {
226-
if (isNonZeroRank(oldVectorInit)) {
227-
val = rewriter.create<IREE::VectorExt::ToSIMDOp>(
228-
oldVectorInit.getLoc(), oldVectorInit.getType(), val);
229-
}
218+
val = rewriter.create<IREE::VectorExt::ToSIMDOp>(
219+
oldVectorInit.getLoc(), oldVectorInit.getType(), val);
230220
}
231221
replacements.push_back(val);
232222
}

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,40 @@ func.func @distribute_scf_for(%a: vector<16x16xi32>, %b: vector<16x16xi32>) -> v
6262
return %out : vector<16x16xi32>
6363
}
6464

65+
#layout_0d = #iree_vector_ext.nested_layout<
66+
subgroup_tile = [],
67+
batch_tile = [],
68+
outer_tile = [],
69+
thread_tile = [],
70+
element_tile = [],
71+
72+
subgroup_strides = [],
73+
thread_strides = []
74+
>
75+
76+
// CHECK-LABEL: @distribute_scf_for_0d
77+
func.func @distribute_scf_for_0d(%a: vector<i32>, %b: vector<i32>) -> vector<i32> {
78+
%c0 = arith.constant 0 : index
79+
%c1 = arith.constant 1 : index
80+
%c128 = arith.constant 128 : index
81+
%cst_0 = arith.constant 0 : i32
82+
// CHECK: %[[ROOT:.*]] = arith.constant dense<0> : vector<i32>
83+
%root = arith.constant dense<0> : vector<i32>
84+
%rootl = iree_vector_ext.to_layout %root to layout(#layout_0d) : vector<i32>
85+
// CHECK: iter_args(%[[ARG0:.*]] = %[[ROOT]]) -> (vector<i32>)
86+
%out = scf.for %i = %c0 to %c128 step %c1 iter_args(%arg0 = %rootl) -> (vector<i32>) {
87+
// CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<i32> -> vector<i32>
88+
// CHECK-DAG: %[[C:.*]] = arith.muli %[[ARG0]], %[[B]] {{.*}} : vector<i32>
89+
%c = arith.muli %arg0, %b : vector<i32>
90+
// CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<i32> -> vector<i32>
91+
// CHECK-DAG: %[[D:.*]] = arith.addi %[[C]], %[[A]] {{.*}} : vector<i32>
92+
%d = arith.addi %c, %a : vector<i32>
93+
// CHECK: scf.yield %[[D]] : vector<i32>
94+
scf.yield %d : vector<i32>
95+
}
96+
return %out : vector<i32>
97+
}
98+
6599
builtin.module attributes { transform.with_named_sequence } {
66100
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
67101
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op

0 commit comments

Comments
 (0)