Skip to content

Commit eef2c3a

Browse files
Groverkss and saienduri authored
[GPU] Do not generate insert_strided_slice for 0-d vectors (#19149)
Co-authored-by: saienduri <[email protected]>
1 parent bf711a1 commit eef2c3a

File tree

2 files changed

+46
-3
lines changed

2 files changed

+46
-3
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,20 @@ struct DistributeTransferRead final
200200
rewriter, indices, offsets, vectorLayout, readOp.getPermutationMap(),
201201
warpIndices, threadIndices);
202202

203-
Value slicedRead = rewriter.create<vector::TransferReadOp>(
203+
VectorValue slicedRead = rewriter.create<vector::TransferReadOp>(
204204
readOp.getLoc(), innerVectorType, readOp.getSource(), slicedIndices,
205205
readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
206206
readOp.getInBoundsAttr());
207207

208-
acc = rewriter.create<vector::InsertStridedSliceOp>(
209-
readOp.getLoc(), slicedRead, acc, offsets, strides);
208+
if (acc.getType().getRank() == 0) {
209+
// TODO: This should really be a folding pattern in
210+
// insert_strided_slice, but instead insert_strided_slice just doesn't
211+
// support 0-d vectors...
212+
acc = slicedRead;
213+
} else {
214+
acc = rewriter.create<vector::InsertStridedSliceOp>(
215+
readOp.getLoc(), slicedRead, acc, offsets, strides);
216+
}
210217
}
211218

212219
replaceOpWithDistributedValues(rewriter, readOp, acc);

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,42 @@ builtin.module attributes { transform.with_named_sequence } {
331331

332332
// -----
333333

334+
#layout = #iree_vector_ext.nested_layout<
335+
subgroup_tile = [],
336+
batch_tile = [],
337+
outer_tile = [],
338+
thread_tile = [],
339+
element_tile = [],
340+
341+
subgroup_strides = [],
342+
thread_strides = []
343+
>
344+
345+
// CHECK-LABEL: @distribute_transfer_read_0d
346+
func.func @distribute_transfer_read_0d(%arg0: memref<128xf16>) -> vector<f16> {
347+
%c0 = arith.constant 0 : index
348+
%cst = arith.constant 0.0 : f16
349+
%root = vector.transfer_read %arg0[%c0], %cst
350+
{in_bounds = []} : memref<128xf16>, vector<f16>
351+
%rootl = iree_vector_ext.to_layout %root to layout(#layout) : vector<f16>
352+
func.return %rootl : vector<f16>
353+
}
354+
355+
356+
builtin.module attributes { transform.with_named_sequence } {
357+
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
358+
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
359+
transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
360+
transform.yield
361+
}
362+
}
363+
364+
// CHECK: %[[RD:.+]] = vector.transfer_read %{{.*}}[%c0]
365+
// CHECK-SAME: memref<128xf16>, vector<f16>
366+
// CHECK: iree_vector_ext.to_simd %[[RD]]
367+
368+
// -----
369+
334370
#layout_row_major = #iree_vector_ext.nested_layout<
335371
subgroup_tile = [1, 1],
336372
batch_tile = [2, 2],

0 commit comments

Comments (0)