Skip to content

Commit 1c10dac

Browse files
authored
Fix result type mismatch of tiled ops (#20211)
Use OpFoldResult wherever possible in LinalgExt tiling implementation. This helps in avoiding result type mismtach issues that otherwise may occur when the tiled producers of slices are fused. Fixes issue: #17526 --------- Signed-off-by: Praveen G <[email protected]>
1 parent 4451b8b commit 1c10dac

File tree

2 files changed

+57
-22
lines changed

2 files changed

+57
-22
lines changed

compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile_and_fuse.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,38 @@ func.func @ukernel_generic(%arg0: tensor<1x192x1x16xf32>, %arg1: tensor<1x768x1x
214214
// CHECK: linalg.generic
215215
// CHECK-SAME: ins(%[[UK_SLICE]], %[[ARG3_SLICE]]
216216
// CHECK-SAME: outs(%[[ITER_SLICE]]
217+
218+
// -----
219+
220+
func.func @tile_linalg_ext_scan() attributes {translation_info = #iree_codegen.translation_info<pipeline = CPUDefault>} {
221+
%c0_i64 = arith.constant 0 : i64
222+
%c0 = arith.constant 0 : index
223+
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<128x2xf32>>
224+
%1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<128x2xi64>>
225+
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 2], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x2xf32>> -> tensor<128x2xf32>
226+
%3 = tensor.empty() : tensor<2xi64>
227+
%4 = tensor.empty() : tensor<128x2xi64>
228+
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<128x2xf32>) outs(%4 : tensor<128x2xi64>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 1], [0, 1], [0, 0], [1, 0]]>} {
229+
^bb0(%in: f32, %out: i64):
230+
%9 = arith.fptosi %in : f32 to i64
231+
linalg.yield %9 : i64
232+
} -> tensor<128x2xi64>
233+
%6 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0], [0], [0], [1]]>} ins(%c0_i64 : i64) outs(%3 : tensor<2xi64>) -> tensor<2xi64>
234+
%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 1], [0, 1], [0, 0], [1, 0]]>} ins(%c0_i64 : i64) outs(%4 : tensor<128x2xi64>) -> tensor<128x2xi64>
235+
%8:2 = iree_linalg_ext.scan {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 1], [0, 1]]>} dimension(0) inclusive(true) ins(%5 : tensor<128x2xi64>) outs(%7, %6 : tensor<128x2xi64>, tensor<2xi64>) {
236+
^bb0(%arg0: i64, %arg1: i64):
237+
%9 = arith.addi %arg0, %arg1 : i64
238+
iree_linalg_ext.yield %9 : i64
239+
} -> tensor<128x2xi64>, tensor<2xi64>
240+
flow.dispatch.tensor.store %8#0, %1, offsets = [0, 0], sizes = [128, 2], strides = [1, 1] : tensor<128x2xi64> -> !flow.dispatch.tensor<writeonly:tensor<128x2xi64>>
241+
return
242+
}
243+
// CHECK-LABEL: func.func @tile_linalg_ext_scan
244+
// CHECK: scf.for
245+
// CHECK-SAME: {
246+
// CHECK: linalg.generic
247+
// CHECK: linalg.fill
248+
// CHECK: linalg.fill
249+
// CHECK: iree_linalg_ext.scan
250+
// CHECK: scf.yield
251+
// CHECK: }

compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ SmallVector<utils::IteratorType> ScatterOp::getLoopIteratorTypes() {
8383

8484
SmallVector<Range> ScatterOp::getIterationDomain(OpBuilder &builder) {
8585
Location loc = getLoc();
86-
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
87-
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
86+
OpFoldResult zero = builder.getIndexAttr(0);
87+
OpFoldResult one = builder.getIndexAttr(1);
8888
SmallVector<Range> ranges;
8989
for (auto dim : llvm::seq<int64_t>(0, getUpdateType().getRank())) {
9090
OpFoldResult ub = getDim(builder, loc, getUpdates(), dim);
@@ -293,12 +293,12 @@ SmallVector<Range> SortOp::getIterationDomain(OpBuilder &builder) {
293293
int64_t operandRank = getOperandRank();
294294
SmallVector<Range> loopBounds(operandRank);
295295
Location loc = getLoc();
296-
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
297-
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
296+
OpFoldResult zero = builder.getIndexAttr(0);
297+
OpFoldResult one = builder.getIndexAttr(1);
298298
Value source = getOperand(0);
299299
for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
300300
loopBounds[dim].offset = zero;
301-
loopBounds[dim].size = getDimValue(builder, loc, source, dim);
301+
loopBounds[dim].size = getDim(builder, loc, source, dim);
302302
loopBounds[dim].stride = one;
303303
}
304304
return loopBounds;
@@ -435,16 +435,16 @@ SmallVector<Range> FftOp::getIterationDomain(OpBuilder &builder) {
435435
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
436436
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
437437
for (auto [idx, val] : llvm::enumerate(getOperandShape().drop_back())) {
438-
Value size;
438+
OpFoldResult size;
439439
if (ShapedType::isDynamic(val)) {
440440
size = getDimValue(builder, loc, getReal(), idx);
441441
} else {
442-
size = builder.create<arith::ConstantIndexOp>(loc, val);
442+
size = builder.getIndexAttr(val);
443443
}
444444
res.emplace_back(Range{/*offset=*/zero, size, /*stride=*/one});
445445
}
446446

447-
Value size = getDimValue(builder, loc, getReal(), getOperandRank() - 1);
447+
OpFoldResult size = getDim(builder, loc, getReal(), getOperandRank() - 1);
448448
Value stride = builder.create<arith::ShLIOp>(loc, one, getStage());
449449
res.emplace_back(Range{/*offset=*/zero, size, /*stride=*/stride});
450450
return res;
@@ -643,12 +643,12 @@ SmallVector<Range> ScanOp::getIterationDomain(OpBuilder &builder) {
643643
int64_t operandRank = getOperandRank();
644644
SmallVector<Range> loopBounds(operandRank);
645645
Location loc = getLoc();
646-
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
647-
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
646+
OpFoldResult zero = builder.getIndexAttr(0);
647+
OpFoldResult one = builder.getIndexAttr(1);
648648
Value source = getInput();
649649
for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
650650
loopBounds[dim].offset = zero;
651-
loopBounds[dim].size = getDimValue(builder, loc, source, dim);
651+
loopBounds[dim].size = getDim(builder, loc, source, dim);
652652
loopBounds[dim].stride = one;
653653
}
654654
return loopBounds;
@@ -836,12 +836,12 @@ SmallVector<Range> TopkOp::getIterationDomain(OpBuilder &builder) {
836836
int64_t operandRank = getInputRank();
837837
SmallVector<Range> loopBounds(operandRank);
838838
Location loc = getLoc();
839-
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
840-
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
839+
OpFoldResult zero = builder.getIndexAttr(0);
840+
OpFoldResult one = builder.getIndexAttr(1);
841841
Value source = getValues();
842842
for (auto [idx, val] : llvm::enumerate(getInputType().getShape())) {
843843
loopBounds[idx].offset = zero;
844-
loopBounds[idx].size = getDimValue(builder, loc, source, idx);
844+
loopBounds[idx].size = getDim(builder, loc, source, idx);
845845
loopBounds[idx].stride = one;
846846
}
847847
return loopBounds;
@@ -1285,7 +1285,7 @@ SmallVector<Range> Im2colOp::getIterationDomain(OpBuilder &builder) {
12851285
SmallVector<Range> loopBounds(getOutputRank());
12861286
for (int dim = 0; dim < getOutputRank(); ++dim) {
12871287
loopBounds[dim].offset = zero;
1288-
loopBounds[dim].size = getDimValue(builder, loc, dest, dim);
1288+
loopBounds[dim].size = getDim(builder, loc, dest, dim);
12891289
loopBounds[dim].stride = one;
12901290
}
12911291
return loopBounds;
@@ -1391,15 +1391,15 @@ LogicalResult Im2colOp::getResultTilePosition(
13911391
SmallVector<Range>
13921392
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
13931393
Location loc = getLoc();
1394-
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1395-
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
1394+
OpFoldResult zero = builder.getIndexAttr(0);
1395+
OpFoldResult one = builder.getIndexAttr(1);
13961396
Value dest = getOutput();
13971397
SmallVector<Range> loopBounds(getIterationDomainRank());
13981398
int count = 0;
13991399
for (auto dim :
14001400
llvm::seq<int64_t>(getImageDimensions().size(), getOutputRank())) {
14011401
loopBounds[count].offset = zero;
1402-
loopBounds[count].size = getDimValue(builder, loc, dest, dim);
1402+
loopBounds[count].size = getDim(builder, loc, dest, dim);
14031403
loopBounds[count].stride = one;
14041404
count++;
14051405
}
@@ -1537,7 +1537,7 @@ WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
15371537
for (auto dim : llvm::seq<int64_t>(numKernelDims, outRank)) {
15381538
int64_t loopDim = dim - numKernelDims;
15391539
loopBounds[loopDim].offset = zero;
1540-
loopBounds[loopDim].size = getDimValue(builder, loc, source, dim);
1540+
loopBounds[loopDim].size = getDim(builder, loc, source, dim);
15411541
loopBounds[loopDim].stride = one;
15421542
}
15431543
return loopBounds;
@@ -1640,15 +1640,15 @@ LogicalResult WinogradFilterTransformOp::getResultTilePosition(
16401640
SmallVector<Range>
16411641
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
16421642
Location loc = getLoc();
1643-
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1644-
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
1643+
OpFoldResult zero = builder.getIndexAttr(0);
1644+
OpFoldResult one = builder.getIndexAttr(1);
16451645
Value source = getInput();
16461646
SmallVector<Range> loopBounds(getIterationDomainRank());
16471647
int count = 0;
16481648
for (auto dim :
16491649
llvm::seq<int64_t>(getImageDimensions().size(), getInputRank())) {
16501650
loopBounds[count].offset = zero;
1651-
loopBounds[count].size = getDimValue(builder, loc, source, dim);
1651+
loopBounds[count].size = getDim(builder, loc, source, dim);
16521652
loopBounds[count].stride = one;
16531653
count++;
16541654
}

0 commit comments

Comments
 (0)