Skip to content

Commit a09be42

Browse files
authored
[Dispatch] Bubble extract_slice through all parallel generics (#20161)
Fixes the llama fp8 performance regression introduced by #20106. That PR stopped the `linalg.generic` from being hoisted, which caused a broadcast to be fused and `tensor<1x1x131072x131072xi1>` to be recomputed on each prefill call.
---------
Signed-off-by: Ian Wood <[email protected]>
1 parent ec128bf commit a09be42

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -58,11 +58,6 @@ struct BubbleUpExtract : OpRewritePattern<tensor::ExtractSliceOp> {
5858
"expected generic op to have all projected permutation maps");
5959
}
6060

61-
if (genericOp.hasIndexSemantics()) {
62-
return rewriter.notifyMatchFailure(
63-
genericOp, "pattern doesn't support index semantics");
64-
}
65-
6661
Value replacement;
6762
linalg::GenericOp swappedOp;
6863
{

compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -141,3 +141,26 @@ func.func @bubble_up_extract_slice_single_use(%arg0: tensor<131072xi64>, %arg1:
141141
// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] :
142142
// CHECK-SAME: outs(%[[EMPTY]] :
143143
// CHECK: return %[[GENERIC]]
144+
145+
// -----
146+
147+
func.func @bubble_extract_broadcast(%arg0: tensor<1x1x131072xi64>, %arg2: index) -> tensor<?x?xi1> {
148+
%empty = tensor.empty() : tensor<1x1x131072x131072xi1>
149+
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0: tensor<1x1x131072xi64>) outs(%empty: tensor<1x1x131072x131072xi1>) {
150+
^bb0(%in: i64, %out: i1):
151+
%899 = linalg.index 3 : index
152+
%900 = arith.index_cast %899 : index to i64
153+
%901 = arith.cmpi sge, %900, %in : i64
154+
linalg.yield %901 : i1
155+
} -> tensor<1x1x131072x131072xi1>
156+
%extracted_slice = tensor.extract_slice %0[0, 0, 0, 0] [1, 1, %arg2, %arg2] [1, 1, 1, 1] : tensor<1x1x131072x131072xi1> to tensor<?x?xi1>
157+
return %extracted_slice : tensor<?x?xi1>
158+
}
159+
// CHECK-LABEL: func @bubble_extract_broadcast
160+
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x131072xi64>
161+
// CHECK-SAME: %[[ARG2:.+]]: index
162+
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]]
163+
// CHECK-SAME: tensor<1x1x131072xi64> to tensor<?xi64>
164+
// CHECK: %[[GENERIC:.+]] = linalg.generic
165+
// CHECK-SAME: ins(%[[EXTRACT]] : tensor<?xi64>)
166+
// CHECK: return %[[GENERIC]]

0 commit comments

Comments
 (0)