Skip to content

Commit 3470dbb

Browse files
authored
[Dispatch Creation] Fix GatherFusionPattern crash (#20887)
Fixes a crash when one of the producer's maps is a projected permutation, and fixes a bug by replacing `linalg.index` ops in the producer with the operands of the `tensor.extract`. Signed-off-by: Ian Wood <[email protected]>
1 parent 8689052 commit 3470dbb

File tree

2 files changed

+114
-13
lines changed

2 files changed

+114
-13
lines changed

compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,23 @@ namespace mlir::iree_compiler::DispatchCreation {
2929
#include "iree/compiler/DispatchCreation/Passes.h.inc"
3030

3131
namespace {
32-
3332
struct ElementwiseOpFusionPass final
3433
: public impl::ElementwiseOpFusionPassBase<ElementwiseOpFusionPass> {
3534
using Base::Base;
3635
void runOnOperation() override;
3736
};
37+
} // namespace
38+
39+
template <typename T>
40+
static SmallVector<T> applyProjectedPermutation(const SmallVectorImpl<T> &input,
41+
ArrayRef<int64_t> projPerm) {
42+
SmallVector<T> result;
43+
result.reserve(projPerm.size());
44+
for (int64_t idx : projPerm) {
45+
result.push_back(input[idx]);
46+
}
47+
return result;
48+
}
3849

3950
//===----------------------------------------------------------------------===//
4051
// GatherFusionPattern
@@ -44,6 +55,7 @@ struct ElementwiseOpFusionPass final
4455
// cannot be fused because there is no producer-consumer
4556
// relationship between the two generics. This is because the indexing
4657
// is not affine (index values come from a tensor).
58+
namespace {
4759
struct GatherFusionPattern final : public OpRewritePattern<tensor::ExtractOp> {
4860
using OpRewritePattern::OpRewritePattern;
4961
LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
@@ -83,14 +95,23 @@ struct GatherFusionPattern final : public OpRewritePattern<tensor::ExtractOp> {
8395
return cast<AffineDimExpr>(expr).getPosition();
8496
});
8597
SmallVector<Value, 4> indices = extractOp.getIndices();
86-
indices = applyPermutation(indices, perm);
98+
indices = applyProjectedPermutation(indices, perm);
8799
auto newExtract = rewriter.create<tensor::ExtractOp>(
88100
extractOp.getLoc(), operand.get(), indices);
89101
extractOps.push_back(newExtract);
90102
}
91103
rewriter.cloneRegionBefore(producerOp.getRegion(), consumerOp.getRegion(),
92104
consumerOp.getRegion().begin());
93105
Block &clonedBlock = consumerOp.getRegion().front();
106+
107+
// Replace `linalg.index` ops with the value of the index from `indices`.
108+
SmallVector<Value, 4> indices = extractOp.getIndices();
109+
indices = applyPermutationMap(resultMap, ArrayRef(indices));
110+
SmallVector<linalg::IndexOp> indexOps(
111+
clonedBlock.getOps<linalg::IndexOp>());
112+
for (linalg::IndexOp indexOp : indexOps) {
113+
rewriter.replaceOp(indexOp, indices[indexOp.getDim()]);
114+
}
94115
auto producerTermOp = clonedBlock.getTerminator();
95116

96117
rewriter.inlineBlockBefore(&clonedBlock, extractOp->getNextNode(),
@@ -105,7 +126,6 @@ struct GatherFusionPattern final : public OpRewritePattern<tensor::ExtractOp> {
105126
return success();
106127
}
107128
};
108-
109129
} // namespace
110130

111131
void ElementwiseOpFusionPass::runOnOperation() {

compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir

Lines changed: 91 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ util.func public @transpose_attention(%arg0: tensor<4x64x32x128xf16>, %arg1: ten
2727
%collapsed = tensor.collapse_shape %7 [[0], [1], [2, 3]] : tensor<4x64x32x128xf16> into tensor<4x64x4096xf16>
2828
util.return %collapsed : tensor<4x64x4096xf16>
2929
}
30-
3130
// CHECK-LABEL: util.func public @transpose_attention
3231
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
3332
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
@@ -76,7 +75,6 @@ util.func public @transposed_attention_masked(%arg0: tensor<4x64x32x128xf16>, %a
7675
%collapsed = tensor.collapse_shape %8 [[0], [1], [2, 3]] : tensor<4x64x32x128xf16> into tensor<4x64x4096xf16>
7776
util.return %collapsed : tensor<4x64x4096xf16>
7877
}
79-
8078
// CHECK-LABEL: util.func public @transposed_attention_masked
8179
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
8280
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
@@ -115,7 +113,6 @@ util.func public @transpose_matmul(%arg0 : tensor<100x100xf16>, %arg1 : tensor<1
115113
} -> tensor<100x100xf16>
116114
util.return %4 : tensor<100x100xf16>
117115
}
118-
119116
// CHECK-LABEL: util.func public @transpose_matmul
120117
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
121118
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
@@ -156,7 +153,6 @@ util.func public @fuse_generic_gather(
156153
} -> tensor<4x?x4096xf32>
157154
util.return %16 : tensor<4x?x4096xf32>
158155
}
159-
160156
// CHECK: %[[INDEX0:[a-zA-Z0-9]+]] = arith.index_cast %in : i64 to index
161157
// CHECK: %[[INDEX1:[a-zA-Z0-9]+]] = linalg.index 2 : index
162158
// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[TENSOR0:.+]][%[[INDEX0]], %[[INDEX1]]] : tensor<128256x4096xf16>
@@ -198,7 +194,6 @@ util.func public @fuse_generic_gather2(
198194
} -> tensor<4x?x4096xf32>
199195
util.return %16 : tensor<4x?x4096xf32>
200196
}
201-
202197
// CHECK: %[[INDEX0:[a-zA-Z0-9]+]] = arith.index_cast %in : i64 to index
203198
// CHECK: %[[INDEX1:[a-zA-Z0-9]+]] = linalg.index 2 : index
204199
// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[TENSOR0:.+]][%[[INDEX0]], %[[INDEX1]]] : tensor<128256x4096xf16>
@@ -237,7 +232,6 @@ util.func public @fuse_transpose_attention_to_producer(%q: tensor<2x10x4096x64xf
237232
} -> tensor<2x10x4096x64xf16>
238233
util.return %attention : tensor<2x10x4096x64xf16>
239234
}
240-
241235
// CHECK-LABEL: util.func public @fuse_transpose_attention_to_producer
242236
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
243237
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
@@ -274,7 +268,6 @@ util.func public @fuse_attention_with_broadcast(%arg0: tensor<4x8x128x?xf16>, %a
274268
} -> tensor<4x8x4x?x32x128xf16>
275269
util.return %1 : tensor<4x8x4x?x32x128xf16>
276270
}
277-
278271
// CHECK-LABEL: func public @fuse_attention_with_broadcast
279272
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
280273
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
@@ -305,7 +298,6 @@ util.func public @fuse_attention_with_broadcast_transpose(%arg0: tensor<4x?x8x12
305298
} -> tensor<4x8x4x?x32x128xf16>
306299
util.return %1 : tensor<4x8x4x?x32x128xf16>
307300
}
308-
309301
// CHECK-LABEL: func public @fuse_attention_with_broadcast_transpose
310302
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
311303
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
@@ -356,7 +348,6 @@ util.func public @gather_fusion(%arg0: tensor<2x64x64x640xf16>, %arg1: tensor<2x
356348
} -> tensor<2x128x128x640xi8>
357349
util.return %3 : tensor<2x128x128x640xi8>
358350
}
359-
360351
// CHECK-LABEL: util.func public @gather_fusion(
361352
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
362353
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
@@ -423,7 +414,6 @@ util.func public @gather_fusion_compose_maps(%arg0: tensor<2x64x64x640xf16>, %ar
423414
} -> tensor<2x128x128x640xi8>
424415
util.return %3 : tensor<2x128x128x640xi8>
425416
}
426-
427417
// CHECK-LABEL: util.func public @gather_fusion_compose_maps(
428418
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
429419
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
@@ -455,3 +445,94 @@ util.func public @gather_fusion_compose_maps(%arg0: tensor<2x64x64x640xf16>, %ar
455445
// CHECK: %[[EXTRACT1:.*]] = tensor.extract %[[ARG1]][%[[CAST0]], %[[CAST3]], %[[CAST2]], %[[CAST1]]] : tensor<2x64x64x640xf16>
456446
// CHECK: %[[ADDF:.+]] = arith.addf %[[EXTRACT0]], %[[EXTRACT1]] : f16
457447
// CHECK: util.return %[[GEN]] : tensor<2x128x128x640xi8>
448+
449+
// -----
450+
451+
util.func public @gather_0d_producer(%arg0 : tensor<f16>, %arg1 : tensor<100xindex>, %arg2 : tensor<256xf16>) -> (tensor<100xf32>) {
452+
%empty0 = tensor.empty() : tensor<256xf32>
453+
%0 = linalg.generic {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0, %arg2 : tensor<f16>, tensor<256xf16>) outs(%empty0 : tensor<256xf32>) {
454+
^bb0(%in: f16, %in0 : f16, %out: f32):
455+
%0 = arith.extf %in : f16 to f32
456+
%1 = arith.extf %in0 : f16 to f32
457+
%2 = arith.addf %0, %1 : f32
458+
linalg.yield %2 : f32
459+
} -> tensor<256xf32>
460+
%empty1 = tensor.empty() : tensor<100xf32>
461+
%gather = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg1: tensor<100xindex>) outs(%empty1 : tensor<100xf32>) {
462+
^bb0(%in: index, %out: f32):
463+
%1 = tensor.extract %0[%in] : tensor<256xf32>
464+
linalg.yield %1 : f32
465+
} -> tensor<100xf32>
466+
util.return %gather : tensor<100xf32>
467+
}
468+
// CHECK-LABEL: util.func public @gather_0d_producer(
469+
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
470+
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
471+
// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]: tensor
472+
// CHECK: %[[GATHER:.+]] = linalg.generic
473+
// CHECK-SAME: ins(%[[ARG1]] : tensor<100xindex>
474+
// CHECK-NEXT: ^bb0(%[[IN:.+]]: index
475+
// CHECK-DAG: %[[EXTRACT0:.+]] = tensor.extract %[[ARG0]][]
476+
// CHECK-DAG: %[[EXTRACT1:.+]] = tensor.extract %[[ARG2]][%[[IN]]]
477+
// CHECK: return %[[GATHER]]
478+
479+
// -----
480+
481+
util.func public @gather_replace_linalg_index(%arg0 : tensor<256x256xf16>, %arg1 : tensor<100xindex>) -> (tensor<100xf32>) {
482+
%empty0 = tensor.empty() : tensor<256x256xf32>
483+
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<256x256xf16>) outs(%empty0 : tensor<256x256xf32>) {
484+
^bb0(%in: f16, %out: f32):
485+
%0 = arith.extf %in : f16 to f32
486+
%1 = linalg.index 1 : index
487+
%2 = arith.index_cast %1 : index to i32
488+
%3 = arith.uitofp %2 : i32 to f32
489+
%4 = arith.addf %0, %3 : f32
490+
linalg.yield %4 : f32
491+
} -> tensor<256x256xf32>
492+
%empty1 = tensor.empty() : tensor<100xf32>
493+
%gather = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg1: tensor<100xindex>) outs(%empty1 : tensor<100xf32>) {
494+
^bb0(%in: index, %out: f32):
495+
%cst0 = arith.constant 0 : index
496+
%1 = tensor.extract %0[%cst0, %in] : tensor<256x256xf32>
497+
linalg.yield %1 : f32
498+
} -> tensor<100xf32>
499+
util.return %gather : tensor<100xf32>
500+
}
501+
// CHECK-LABEL: util.func public @gather_replace_linalg_index(
502+
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
503+
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
504+
// CHECK: %[[GATHER:.+]] = linalg.generic
505+
// CHECK-SAME: ins(%[[ARG1]] : tensor<100xindex>
506+
// CHECK-NEXT: ^bb0(%[[IN:.+]]: index
507+
// CHECK: arith.index_cast %[[IN]]
508+
// CHECK: return %[[GATHER]]
509+
510+
// -----
511+
512+
util.func public @gather_replace_linalg_index_transpose(%arg0 : tensor<256x256xf16>, %arg1 : tensor<100xindex>, %arg2 : index) -> (tensor<100xf32>) {
513+
%empty0 = tensor.empty() : tensor<256x256xf32>
514+
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<256x256xf16>) outs(%empty0 : tensor<256x256xf32>) {
515+
^bb0(%in: f16, %out: f32):
516+
%0 = arith.extf %in : f16 to f32
517+
%1 = linalg.index 1 : index
518+
%2 = arith.index_cast %1 : index to i32
519+
%3 = arith.uitofp %2 : i32 to f32
520+
%4 = arith.addf %0, %3 : f32
521+
linalg.yield %4 : f32
522+
} -> tensor<256x256xf32>
523+
%empty1 = tensor.empty() : tensor<100xf32>
524+
%gather = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg1: tensor<100xindex>) outs(%empty1 : tensor<100xf32>) {
525+
^bb0(%in: index, %out: f32):
526+
%1 = tensor.extract %0[%arg2, %in] : tensor<256x256xf32>
527+
linalg.yield %1 : f32
528+
} -> tensor<100xf32>
529+
util.return %gather : tensor<100xf32>
530+
}
531+
// CHECK-LABEL: util.func public @gather_replace_linalg_index_transpose(
532+
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
533+
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
534+
// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]: index
535+
// CHECK: %[[GATHER:.+]] = linalg.generic
536+
// CHECK-SAME: ins(%[[ARG1]] : tensor<100xindex>
537+
// CHECK: arith.index_cast %[[ARG2]]
538+
// CHECK: return %[[GATHER]]

0 commit comments

Comments
 (0)