Skip to content

Commit 3d05691

Browse files
authored
[Dispatch] Extend gather fusion pattern (#19862)
This extends `GatherFusionPattern` to support any bit-extend op, instead of just ones that have a single truncate. Needed for #19847 to fix codegen failures. This also improves VAE performance by ~15%. Signed-off-by: Ian Wood <[email protected]>
1 parent b804314 commit 3d05691

File tree

2 files changed

+153
-12
lines changed

2 files changed

+153
-12
lines changed

compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ struct GatherFusionPattern final : public OpRewritePattern<tensor::ExtractOp> {
6363
}
6464

6565
// Check if the producerOp is fusible
66-
if (producerOp.getNumDpsInputs() != 1 || producerOp.getNumResults() != 1 ||
67-
!isElementwise(producerOp) ||
66+
if (producerOp.getNumResults() != 1 || !isElementwise(producerOp) ||
6867
!IREE::LinalgExt::isBitExtendOp(producerOp)) {
6968
return rewriter.notifyMatchFailure(producerOp,
7069
"producer op is not fusible");
@@ -73,21 +72,29 @@ struct GatherFusionPattern final : public OpRewritePattern<tensor::ExtractOp> {
7372
OpBuilder::InsertionGuard g(rewriter);
7473
rewriter.setInsertionPoint(extractOp);
7574

76-
// Create a new extract op that extracts from the original tensor
77-
// (after the original extract). Clone the producerOp's body into the
78-
// consumerOp, inline the cloned block (erases the block) after the new
79-
// extract, and clean up.
80-
auto newExtractOp = rewriter.create<tensor::ExtractOp>(
81-
extractOp.getLoc(), producerOp.getDpsInputOperand(0)->get(),
82-
extractOp.getIndices());
75+
auto result = cast<OpResult>(extractOp.getTensor());
76+
auto resultMap = producerOp.getIndexingMapMatchingResult(result);
77+
SmallVector<Value> extractOps;
78+
for (OpOperand &operand : producerOp->getOpOperands()) {
79+
auto inputMap = producerOp.getMatchingIndexingMap(&operand);
80+
auto composedMap = inputMap.compose(inversePermutation(resultMap));
81+
auto perm = llvm::map_to_vector<4>(
82+
composedMap.getResults(), [](AffineExpr expr) -> int64_t {
83+
return cast<AffineDimExpr>(expr).getPosition();
84+
});
85+
SmallVector<Value, 4> indices = extractOp.getIndices();
86+
indices = applyPermutation(indices, perm);
87+
auto newExtract = rewriter.create<tensor::ExtractOp>(
88+
extractOp.getLoc(), operand.get(), indices);
89+
extractOps.push_back(newExtract);
90+
}
8391
rewriter.cloneRegionBefore(producerOp.getRegion(), consumerOp.getRegion(),
8492
consumerOp.getRegion().begin());
8593
Block &clonedBlock = consumerOp.getRegion().front();
8694
auto producerTermOp = clonedBlock.getTerminator();
8795

88-
rewriter.inlineBlockBefore(
89-
&clonedBlock, extractOp->getNextNode(),
90-
{newExtractOp.getResult(), newExtractOp.getResult()});
96+
rewriter.inlineBlockBefore(&clonedBlock, extractOp->getNextNode(),
97+
extractOps);
9198

9299
// Replace all the references to the original extract result with the
93100
// result from the inlined producerOp.

compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,137 @@ util.func public @fuse_attention_with_broadcast_transpose(%arg0: tensor<4x?x8x12
321321
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>
322322
// CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG0]], %[[ARG3]], %[[ARG4]] :
323323
// CHECK: util.return %[[ATTENTION]]
324+
325+
// -----
326+
327+
util.func public @gather_fusion(%arg0: tensor<2x64x64x640xf16>, %arg1: tensor<2x64x64x640xf16>, %arg2: tensor<2xi64>, %arg3: tensor<640xi64>, %arg4: tensor<128xi64>, %arg5: tensor<640xf16>, %arg6: tensor<f32>) -> tensor<2x128x128x640xi8> {
328+
%cst = arith.constant -1.280000e+02 : f16
329+
%cst_0 = arith.constant 1.270000e+02 : f16
330+
%0 = tensor.empty() : tensor<2x128x128x640xi8>
331+
%1 = tensor.empty() : tensor<2x640x64x64xf32>
332+
%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<2x64x64x640xf16>, tensor<2x64x64x640xf16>) outs(%1 : tensor<2x640x64x64xf32>) {
333+
^bb0(%in: f16, %in_1: f16, %out: f32):
334+
%4 = arith.addf %in, %in_1 : f16
335+
%5 = arith.extf %4 : f16 to f32
336+
linalg.yield %5 : f32
337+
} -> tensor<2x640x64x64xf32>
338+
%3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d2)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %arg3, %arg4, %arg4, %arg5, %arg6 : tensor<2xi64>, tensor<640xi64>, tensor<128xi64>, tensor<128xi64>, tensor<640xf16>, tensor<f32>) outs(%0 : tensor<2x128x128x640xi8>) {
339+
^bb0(%in: i64, %in_1: i64, %in_2: i64, %in_3: i64, %in_4: f16, %in_5: f32, %out: i8):
340+
%4 = arith.index_cast %in : i64 to index
341+
%5 = arith.index_cast %in_1 : i64 to index
342+
%6 = arith.index_cast %in_2 : i64 to index
343+
%7 = arith.index_cast %in_3 : i64 to index
344+
%extracted = tensor.extract %2[%4, %5, %6, %7] : tensor<2x640x64x64xf32>
345+
%8 = arith.truncf %extracted : f32 to f16
346+
%9 = arith.mulf %8, %in_4 : f16
347+
%10 = arith.truncf %in_5 : f32 to f16
348+
%11 = arith.divf %9, %10 : f16
349+
%12 = math.roundeven %11 : f16
350+
%13 = arith.cmpf ult, %12, %cst : f16
351+
%14 = arith.select %13, %cst, %12 : f16
352+
%15 = arith.cmpf ugt, %14, %cst_0 : f16
353+
%16 = arith.select %15, %cst_0, %14 : f16
354+
%17 = arith.fptosi %16 : f16 to i8
355+
linalg.yield %17 : i8
356+
} -> tensor<2x128x128x640xi8>
357+
util.return %3 : tensor<2x128x128x640xi8>
358+
}
359+
360+
// CHECK-LABEL: util.func public @gather_fusion(
361+
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
362+
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
363+
// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]: tensor
364+
// CHECK-SAME: %[[ARG3:[A-Za-z0-9]+]]: tensor
365+
// CHECK-SAME: %[[ARG4:[A-Za-z0-9]+]]: tensor
366+
// CHECK-SAME: %[[ARG5:[A-Za-z0-9]+]]: tensor
367+
// CHECK-SAME: %[[ARG6:[A-Za-z0-9]+]]: tensor
368+
// CHECK: %[[GEN:.+]] = linalg.generic
369+
// CHECK-SAME: indexing_maps =
370+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0)>,
371+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d3)>,
372+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1)>,
373+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d2)>,
374+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d3)>,
375+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> ()>,
376+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
377+
// CHECK-SAME: ins(%[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG4]], %[[ARG5]], %[[ARG6]]
378+
// CHECK: ^bb0(
379+
// CHECK-SAME: %[[IN0:[_a-zA-Z0-9]+]]: i64,
380+
// CHECK-SAME: %[[IN1:[_a-zA-Z0-9]+]]: i64,
381+
// CHECK-SAME: %[[IN2:[_a-zA-Z0-9]+]]: i64,
382+
// CHECK-SAME: %[[IN3:[_a-zA-Z0-9]+]]: i64,
383+
// CHECK-DAG: %[[CAST0:.+]] = arith.index_cast %[[IN0]] : i64 to index
384+
// CHECK-DAG: %[[CAST1:.+]] = arith.index_cast %[[IN1]] : i64 to index
385+
// CHECK-DAG: %[[CAST2:.+]] = arith.index_cast %[[IN2]] : i64 to index
386+
// CHECK-DAG: %[[CAST3:.+]] = arith.index_cast %[[IN3]] : i64 to index
387+
// CHECK: %[[EXTRACT0:.*]] = tensor.extract %[[ARG0]][%[[CAST0]], %[[CAST2]], %[[CAST3]], %[[CAST1]]] : tensor<2x64x64x640xf16>
388+
// CHECK: %[[EXTRACT1:.*]] = tensor.extract %[[ARG1]][%[[CAST0]], %[[CAST2]], %[[CAST3]], %[[CAST1]]] : tensor<2x64x64x640xf16>
389+
// CHECK: %[[ADDF:.+]] = arith.addf %[[EXTRACT0]], %[[EXTRACT1]] : f16
390+
// CHECK: util.return %[[GEN]] : tensor<2x128x128x640xi8>
391+
392+
// -----
393+
394+
util.func public @gather_fusion_compose_maps(%arg0: tensor<2x64x64x640xf16>, %arg1: tensor<2x64x64x640xf16>, %arg2: tensor<2xi64>, %arg3: tensor<640xi64>, %arg4: tensor<128xi64>, %arg5: tensor<640xf16>, %arg6: tensor<f32>) -> tensor<2x128x128x640xi8> {
395+
%cst = arith.constant -1.280000e+02 : f16
396+
%cst_0 = arith.constant 1.270000e+02 : f16
397+
%0 = tensor.empty() : tensor<2x128x128x640xi8>
398+
%1 = tensor.empty() : tensor<2x640x64x64xf32>
399+
%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<2x64x64x640xf16>, tensor<2x64x64x640xf16>) outs(%1 : tensor<2x640x64x64xf32>) {
400+
^bb0(%in: f16, %in_1: f16, %out: f32):
401+
%4 = arith.addf %in, %in_1 : f16
402+
%5 = arith.extf %4 : f16 to f32
403+
linalg.yield %5 : f32
404+
} -> tensor<2x640x64x64xf32>
405+
%3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d2)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %arg3, %arg4, %arg4, %arg5, %arg6 : tensor<2xi64>, tensor<640xi64>, tensor<128xi64>, tensor<128xi64>, tensor<640xf16>, tensor<f32>) outs(%0 : tensor<2x128x128x640xi8>) {
406+
^bb0(%in: i64, %in_1: i64, %in_2: i64, %in_3: i64, %in_4: f16, %in_5: f32, %out: i8):
407+
%4 = arith.index_cast %in : i64 to index
408+
%5 = arith.index_cast %in_1 : i64 to index
409+
%6 = arith.index_cast %in_2 : i64 to index
410+
%7 = arith.index_cast %in_3 : i64 to index
411+
%extracted = tensor.extract %2[%4, %5, %6, %7] : tensor<2x640x64x64xf32>
412+
%8 = arith.truncf %extracted : f32 to f16
413+
%9 = arith.mulf %8, %in_4 : f16
414+
%10 = arith.truncf %in_5 : f32 to f16
415+
%11 = arith.divf %9, %10 : f16
416+
%12 = math.roundeven %11 : f16
417+
%13 = arith.cmpf ult, %12, %cst : f16
418+
%14 = arith.select %13, %cst, %12 : f16
419+
%15 = arith.cmpf ugt, %14, %cst_0 : f16
420+
%16 = arith.select %15, %cst_0, %14 : f16
421+
%17 = arith.fptosi %16 : f16 to i8
422+
linalg.yield %17 : i8
423+
} -> tensor<2x128x128x640xi8>
424+
util.return %3 : tensor<2x128x128x640xi8>
425+
}
426+
427+
// CHECK-LABEL: util.func public @gather_fusion_compose_maps(
428+
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
429+
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
430+
// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]: tensor
431+
// CHECK-SAME: %[[ARG3:[A-Za-z0-9]+]]: tensor
432+
// CHECK-SAME: %[[ARG4:[A-Za-z0-9]+]]: tensor
433+
// CHECK-SAME: %[[ARG5:[A-Za-z0-9]+]]: tensor
434+
// CHECK-SAME: %[[ARG6:[A-Za-z0-9]+]]: tensor
435+
// CHECK: %[[GEN:.+]] = linalg.generic
436+
// CHECK-SAME: indexing_maps =
437+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0)>,
438+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d3)>,
439+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1)>,
440+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d2)>,
441+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d3)>,
442+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> ()>,
443+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
444+
// CHECK-SAME: ins(%[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG4]], %[[ARG5]], %[[ARG6]]
445+
// CHECK: ^bb0(
446+
// CHECK-SAME: %[[IN0:[_a-zA-Z0-9]+]]: i64,
447+
// CHECK-SAME: %[[IN1:[_a-zA-Z0-9]+]]: i64,
448+
// CHECK-SAME: %[[IN2:[_a-zA-Z0-9]+]]: i64,
449+
// CHECK-SAME: %[[IN3:[_a-zA-Z0-9]+]]: i64,
450+
// CHECK-DAG: %[[CAST0:.+]] = arith.index_cast %[[IN0]] : i64 to index
451+
// CHECK-DAG: %[[CAST1:.+]] = arith.index_cast %[[IN1]] : i64 to index
452+
// CHECK-DAG: %[[CAST2:.+]] = arith.index_cast %[[IN2]] : i64 to index
453+
// CHECK-DAG: %[[CAST3:.+]] = arith.index_cast %[[IN3]] : i64 to index
454+
// CHECK: %[[EXTRACT0:.*]] = tensor.extract %[[ARG0]][%[[CAST0]], %[[CAST2]], %[[CAST3]], %[[CAST1]]] : tensor<2x64x64x640xf16>
455+
// CHECK: %[[EXTRACT1:.*]] = tensor.extract %[[ARG1]][%[[CAST0]], %[[CAST3]], %[[CAST2]], %[[CAST1]]] : tensor<2x64x64x640xf16>
456+
// CHECK: %[[ADDF:.+]] = arith.addf %[[EXTRACT0]], %[[EXTRACT1]] : f16
457+
// CHECK: util.return %[[GEN]] : tensor<2x128x128x640xi8>

0 commit comments

Comments
 (0)