
Commit a1ffcc1

hanhanW authored and hhkit committed
[DT] Fuse encoding ops more aggressively for multi-use, gather, and slice ops. (iree-org#21830)

The multi-use dispatch fusion constraint is only required by the SetEncoding pass, because that pass has to move consumer dispatches around. It is not required by encoding fusion, which only moves a SetEncoding op into its producer dispatch. The revision also allows the fusion when the dispatch region contains tensor.extract_slice and iree_linalg_ext.gather ops. It reduces the number of dispatches in the Llama fp8 model to 644, the same as without data tiling. The latency drops 25ms, from 378ms to 353ms.

|                             | No Data Tiling | Data Tiling w/o the revision | Data Tiling w/ the revision |
| --------------------------- | -------------- | ---------------------------- | --------------------------- |
| Benchmark latency           | 243ms          | 378ms                        | 353ms                       |
| Memory usage (HIP unpooled) | 15.9GB         | 31.14GB                      | 31.11GB                     |
| Number of dispatches        | 644            | 741                          | 644                         |

| Dispatch (latency in ms)                           | No Data Tiling | Data Tiling w/o the revision | Data Tiling w/ the revision |
| -------------------------------------------------- | -------------- | ---------------------------- | --------------------------- |
| dispatch_15_attention_4x8x4xDx128xf8               | 62.29          | 55.35                        | 59.21                       |
| dispatch_20_matmul_like_Dx14336x4096_f8xf8xf32     | 40.13          | 89.14                        | 93.72                       |
| dispatch_19_matmul_like_Dx14336x4096_f8xf8xf32     | 28.01          | 44.78                        | 44.59                       |
| dispatch_21_matmul_like_Dx4096x14336_f8xf8xf32     | 27.25          | 40.18                        | 39.99                       |
| dispatch_643_matmul_like_Dx128256x4096_f16xf16xf32 | 17.1           | 29.76                        | 29.21                       |
| dispatch_16_matmul_like_Dx4096x4096_f8xf8xf32      | 8.83           | 17.92                        | 17.91                       |
| dispatch_23_matmul_like_Dx4096x4096_f8xf8xf32      | 9.27           | 16.69                        | 16.59                       |
| encoding_10_encode_Dx4096xf8_to_Dx4096xf8          | -              | 32.15                        | -                           |
| encoding_6_encode_Dx14336xf32_to_Dx14336xf32       | -              | 0.318                        | -                           |

---------

Signed-off-by: hanhanW <[email protected]>
Signed-off-by: Ivan Ho <[email protected]>
1 parent c08bc69 commit a1ffcc1
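As a rough illustration of the mechanism the commit message describes (a minimal sketch in the style of the tests below; the shapes, function names, and the #iree_encoding.testing<> encoding are illustrative, not taken from the patch), encoding fusion moves a SetEncoding op into its producer dispatch:

#map = affine_map<(d0) -> (d0)>
#encoding = #iree_encoding.testing<>

// Before: the set_encoding op sits outside its producer dispatch.
util.func public @before(%arg0: tensor<16xf32>) -> tensor<16xf32, #encoding> {
  %0 = tensor.empty() : tensor<16xf32>
  %1 = flow.dispatch.region -> (tensor<16xf32>) {
    %2 = linalg.generic {
        indexing_maps = [#map, #map],
        iterator_types = ["parallel"]}
        ins(%arg0 : tensor<16xf32>) outs(%0 : tensor<16xf32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = arith.addf %in, %in : f32
      linalg.yield %3 : f32
    } -> tensor<16xf32>
    flow.return %2 : tensor<16xf32>
  }
  %4 = iree_encoding.set_encoding %1 : tensor<16xf32> -> tensor<16xf32, #encoding>
  util.return %4 : tensor<16xf32, #encoding>
}

// After encoding fusion: the set_encoding is moved inside the producer
// dispatch, and the region returns the encoded tensor directly.
util.func public @after(%arg0: tensor<16xf32>) -> tensor<16xf32, #encoding> {
  %0 = tensor.empty() : tensor<16xf32>
  %1 = flow.dispatch.region -> (tensor<16xf32, #encoding>) {
    %2 = linalg.generic {
        indexing_maps = [#map, #map],
        iterator_types = ["parallel"]}
        ins(%arg0 : tensor<16xf32>) outs(%0 : tensor<16xf32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = arith.addf %in, %in : f32
      linalg.yield %3 : f32
    } -> tensor<16xf32>
    %4 = iree_encoding.set_encoding %2 : tensor<16xf32> -> tensor<16xf32, #encoding>
    flow.return %4 : tensor<16xf32, #encoding>
  }
  util.return %1 : tensor<16xf32, #encoding>
}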

5 files changed: +101 additions, -26 deletions

compiler/src/iree/compiler/DispatchCreation/FuseEncodingOpsIntoDispatchRegions.cpp

Lines changed: 10 additions & 8 deletions

@@ -31,12 +31,13 @@ namespace mlir::iree_compiler::DispatchCreation {
 namespace {
 
 /// Return true if the op is fusable with a SetEncodingOp consumer. The op's
-/// containing dispatch must contain only reshape ops, encoding ops, linalg ops,
-/// and attention ops. Non ShapedType ops (like arith ops, dim ops, etc.) are
-/// also allowed.
-/// TODO(#20179): It should be done by interface methods.
-static bool isFusableWithSetEncoding(Operation *op) {
-  auto parentRegion = op->getParentOfType<IREE::Flow::DispatchRegionOp>();
+/// containing dispatch must contain only:
+///   - Reshape ops, encoding ops, linalg ops, gather ops, and attention ops.
+///   - Non ShapedType ops, e.g., arith ops, dim ops, etc.
+///   - tensor::ExtractSliceOp, as it can be folded into dispatch tensor load
+///     ops.
+static bool isFusableWithSetEncoding(Operation *target) {
+  auto parentRegion = target->getParentOfType<IREE::Flow::DispatchRegionOp>();
   // Make sure the dispatch region has only one block.
   if (!llvm::hasSingleElement(parentRegion.getBody())) {
     return false;
@@ -49,8 +50,9 @@ static bool isFusableWithSetEncoding(Operation *op) {
       continue;
     }
     if (!isa<tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::EmptyOp,
-             IREE::Encoding::SetEncodingOp, IREE::Encoding::UnsetEncodingOp,
-             linalg::LinalgOp, IREE::LinalgExt::AttentionOp>(op)) {
+             tensor::ExtractSliceOp, IREE::Encoding::SetEncodingOp,
+             IREE::Encoding::UnsetEncodingOp, linalg::LinalgOp,
+             IREE::LinalgExt::AttentionOp, IREE::LinalgExt::GatherOp>(op)) {
       return false;
     }
   }
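To make the predicate concrete, here is a hypothetical negative case (not part of the patch): tensor.pad produces a ShapedType result but is not in the allowlist above, so isFusableWithSetEncoding returns false and the set_encoding stays outside the dispatch.

#encoding = #iree_encoding.testing<>
util.func public @no_fusion_with_pad(%arg0: tensor<8xf32>) -> tensor<10xf32, #encoding> {
  %cst = arith.constant 0.0 : f32
  %0 = flow.dispatch.region -> (tensor<10xf32>) {
    // tensor.pad is not in the isa<> allowlist, so this dispatch is rejected.
    %padded = tensor.pad %arg0 low[1] high[1] {
    ^bb0(%i: index):
      tensor.yield %cst : f32
    } : tensor<8xf32> to tensor<10xf32>
    flow.return %padded : tensor<10xf32>
  }
  // Stays outside the dispatch region after the pass runs.
  %1 = iree_encoding.set_encoding %0 : tensor<10xf32> -> tensor<10xf32, #encoding>
  util.return %1 : tensor<10xf32, #encoding>
}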

compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp

Lines changed: 1 addition & 6 deletions

@@ -145,13 +145,8 @@ getProducerDispatchValueAndOpChain(Value operand) {
 
   auto producerDispatch =
       dyn_cast<IREE::Flow::DispatchRegionOp>(producerValue.getOwner());
-  // TODO(MaheshRavishankar): Multi-result producer dispatches can be supported.
-  // Will require to move the consumer dispatch immediately after the producer
-  // instead of what is done below and move other operands of the consumer
-  // dispatch before the producer dispatch.
   if (!producerDispatch ||
-      !llvm::hasSingleElement(producerDispatch.getBody()) ||
-      producerDispatch->getNumResults() != 1) {
+      !llvm::hasSingleElement(producerDispatch.getBody())) {
     return std::nullopt;
   }
   if (!llvm::hasSingleElement(producerValue.getUses())) {
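A hedged sketch of what the relaxed check now admits (illustrative IR, not from the patch): a producer dispatch with two results is no longer rejected by this helper, so a set_encoding consuming one of its results becomes a fusion candidate. The @multi_result_fusion test at the bottom of this commit exercises the full flow.

#encoding = #iree_encoding.testing<>
util.func public @two_results(%arg0: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32, #encoding>) {
  // Two-result producer dispatch; the helper previously returned std::nullopt
  // for it, which blocked encoding fusion.
  %0:2 = flow.dispatch.region -> (tensor<8xf32>, tensor<8xf32>) {
    flow.return %arg0, %arg0 : tensor<8xf32>, tensor<8xf32>
  }
  // Now a candidate for fusion into the dispatch above.
  %1 = iree_encoding.set_encoding %0#1 : tensor<8xf32> -> tensor<8xf32, #encoding>
  util.return %0#0, %1 : tensor<8xf32>, tensor<8xf32, #encoding>
}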

compiler/src/iree/compiler/DispatchCreation/FusionUtils.h

Lines changed: 2 additions & 2 deletions

@@ -28,8 +28,8 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand,
 
 /// Returns the closest producer dispatch region op result and the chain of
 /// operations being looked past during the traversal to find the producer
-/// dispatch. Returns std::nullopt if the dispatch or any ops in the chain have
-/// multiple uses.
+/// dispatch. Returns std::nullopt if the dispatch cannot be found in the
+/// chain or any op in the chain is not a reshape-like op.
 std::optional<std::pair<OpResult, SmallVector<Operation *>>>
 getProducerDispatchValueAndOpChain(Value operand);
 
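For clarity, a small hypothetical example of the traversal this comment describes (names and shapes are made up): calling getProducerDispatchValueAndOpChain on the expand_shape result below would look past the reshape and return the dispatch result together with the chain [tensor.expand_shape].

util.func public @producer_chain(%arg0: tensor<16xf32>) -> tensor<4x4xf32> {
  %0 = flow.dispatch.region -> (tensor<16xf32>) {
    flow.return %arg0 : tensor<16xf32>
  }
  // Reshape-like op that the traversal looks past to find the dispatch.
  %1 = tensor.expand_shape %0 [[0, 1]] output_shape [4, 4]
      : tensor<16xf32> into tensor<4x4xf32>
  util.return %1 : tensor<4x4xf32>
}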
compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp

Lines changed: 18 additions & 10 deletions

@@ -392,16 +392,24 @@ static SmallVector<unsigned> getOperandsToPad(Operation *op) {
     OpOperand &operand = op->getOpOperand(operandNum);
     std::optional<std::pair<OpResult, SmallVector<Operation *>>>
         dispatchAndOpChain = getProducerDispatchValueAndOpChain(operand.get());
-    if (dispatchAndOpChain.has_value()) {
-      auto producerDispatch = cast<IREE::Flow::DispatchRegionOp>(
-          dispatchAndOpChain->first.getOwner());
-      WalkResult res =
-          producerDispatch->walk([&](IREE::LinalgExt::AttentionOp op) {
-            return WalkResult::interrupt();
-          });
-      if (res.wasInterrupted()) {
-        return {};
-      }
+    if (!dispatchAndOpChain.has_value()) {
+      continue;
+    }
+    auto producerDispatch = cast<IREE::Flow::DispatchRegionOp>(
+        dispatchAndOpChain->first.getOwner());
+    // TODO(MaheshRavishankar): Multi-result producer dispatches can be
+    // supported. This will require moving the consumer dispatch immediately
+    // after the producer instead of what is done below, and moving other
+    // operands of the consumer dispatch before the producer dispatch.
+    if (producerDispatch->getNumResults() != 1) {
+      continue;
+    }
+    WalkResult res =
+        producerDispatch->walk([&](IREE::LinalgExt::AttentionOp op) {
+          return WalkResult::interrupt();
+        });
+    if (res.wasInterrupted()) {
+      return {};
     }
   }

compiler/src/iree/compiler/DispatchCreation/test/fuse_encoding_ops_into_dispatch_regions.mlir

Lines changed: 70 additions & 0 deletions

@@ -280,6 +280,24 @@ util.func public @encoding_fusion(%arg0: tensor<128xf32, #encoding0>) -> tensor<
 
 // -----
 
+#encoding = #iree_encoding.testing<>
+util.func public @extract_slice_fusion(%arg0: tensor<192x1024x64xf32>) -> tensor<96x512x32xf32, #encoding> {
+  %0 = flow.dispatch.region -> (tensor<96x512x32xf32>) {
+    %1 = tensor.extract_slice %arg0[0, 0, 0] [96, 512, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<96x512x32xf32>
+    flow.return %1 : tensor<96x512x32xf32>
+  }
+  %2 = iree_encoding.set_encoding %0 : tensor<96x512x32xf32> -> tensor<96x512x32xf32, #encoding>
+  util.return %2 : tensor<96x512x32xf32, #encoding>
+}
+// CHECK-LABEL: @extract_slice_fusion
+// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region
+// CHECK: tensor.extract_slice
+// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding
+// CHECK: flow.return %[[SET_ENCODING]] :
+// CHECK: util.return %[[DISPATCH0]]
+
+// -----
+
 #encoding = #iree_encoding.testing<>
 util.func public @attention_fusion(
     %query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>,
@@ -309,3 +327,55 @@ util.func public @attention_fusion(
 // CHECK: flow.return %[[SET_ENCODING]] :
 // CHECK: }
 // CHECK: util.return %[[DISPATCH0]]
+
+// -----
+
+#encoding = #iree_encoding.testing<>
+util.func public @gather_fusion(
+    %source: tensor<10x10xi32>, %indices: tensor<1xi32>) -> tensor<1x10xi32, #encoding> {
+  %empty = tensor.empty() : tensor<1x10xi32>
+  %1 = flow.dispatch.region -> (tensor<1x10xi32>) {
+    %3 = iree_linalg_ext.gather dimension_map = [0]
+        ins(%source, %indices : tensor<10x10xi32>, tensor<1xi32>)
+        outs(%empty : tensor<1x10xi32>) -> tensor<1x10xi32>
+    flow.return %3 : tensor<1x10xi32>
+  }
+  %2 = iree_encoding.set_encoding %1 : tensor<1x10xi32> -> tensor<1x10xi32, #encoding>
+  util.return %2 : tensor<1x10xi32, #encoding>
+}
+// CHECK-LABEL: @gather_fusion
+// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region
+// CHECK: iree_linalg_ext.gather
+// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding
+// CHECK: flow.return %[[SET_ENCODING]] :
+// CHECK: util.return %[[DISPATCH0]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#encoding = #iree_encoding.testing<>
+util.func public @multi_result_fusion(%arg0: tensor<123x456xf32>) -> (tensor<123x456xf32>, tensor<123x456xf32, #encoding>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<123x456xf32>
+  %1:2 = flow.dispatch.region -> (tensor<123x456xf32>, tensor<123x456xf32>) {
+    %3:2 = linalg.generic {
+        indexing_maps = [#map, #map, #map, #map],
+        iterator_types = ["parallel", "parallel"]}
+        ins(%arg0, %arg0 : tensor<123x456xf32>, tensor<123x456xf32>)
+        outs(%0, %0 : tensor<123x456xf32>, tensor<123x456xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32, %out2: f32):
+      %4 = arith.addf %in, %in_0 : f32
+      %5 = arith.mulf %in, %in_0 : f32
+      linalg.yield %4, %5 : f32, f32
+    } -> (tensor<123x456xf32>, tensor<123x456xf32>)
+    flow.return %3#0, %3#1 : tensor<123x456xf32>, tensor<123x456xf32>
+  }
+  %2 = iree_encoding.set_encoding %1#1 : tensor<123x456xf32> -> tensor<123x456xf32, #encoding>
+  util.return %1#0, %2 : tensor<123x456xf32>, tensor<123x456xf32, #encoding>
+}
+// CHECK-LABEL: @multi_result_fusion
+// CHECK: %[[DISPATCH0:.+]]:2 = flow.dispatch.region
+// CHECK: %[[ELEM:.+]]:2 = linalg.generic
+// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[ELEM]]#1
+// CHECK: flow.return %[[ELEM]]#0, %[[SET_ENCODING]]
+// CHECK: util.return %[[DISPATCH0]]#0, %[[DISPATCH0]]#1
