Commit 5767be3

Reland "Support fusing broadcast transposes with attention" (#19962)
Reland the changes to fold attention ops with broadcasts, with a small tweak to `AttentionOpDetail` so that the batch dimensions are properly computed when an operand is broadcasted.

Original PR: #19828
Revert PR: #19835
Issue causing revert: #19833

Signed-off-by: Ian Wood <[email protected]>
1 parent d11b876 commit 5767be3

File tree: 3 files changed, +88 −24 lines

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp

Lines changed: 23 additions & 20 deletions
```diff
@@ -9,6 +9,7 @@
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
@@ -58,42 +59,44 @@ struct FuseTransposeWithAttentionOp final
 
   LogicalResult matchAndRewrite(LinalgExt::AttentionOp attentionOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *transposeOperand = nullptr;
-    linalg::LinalgOp transposeOp;
+    OpOperand *operand = nullptr;
+    linalg::LinalgOp producer;
     for (OpOperand *input : attentionOp.getDpsInputOperands()) {
       if (controlFn && !controlFn(input)) {
         continue;
       }
 
-      auto maybeTransposeOp = input->get().getDefiningOp<linalg::LinalgOp>();
-      if (maybeTransposeOp && isaTranspose(maybeTransposeOp) &&
-          maybeTransposeOp->hasOneUse()) {
-        transposeOp = maybeTransposeOp;
-        transposeOperand = input;
+      auto maybeProducer = input->get().getDefiningOp<linalg::GenericOp>();
+      if (maybeProducer && maybeProducer.isSingleYieldOp()) {
+        producer = maybeProducer;
+        operand = input;
         break;
       }
     }
-    if (!transposeOperand) {
-      return rewriter.notifyMatchFailure(attentionOp, "no transpose operand");
+    if (!operand) {
+      return rewriter.notifyMatchFailure(attentionOp, "no operand found");
     }
 
-    int64_t inputIndex = transposeOperand->getOperandNumber();
-    SmallVector<int64_t> perm = getPermutation(transposeOp);
-    auto invPerm = invertPermutationVector(perm);
+    int64_t inputIndex = operand->getOperandNumber();
+
+    auto producerMaps = producer.getIndexingMapsArray();
+    AffineMap producerInputMap = producerMaps[0];
+    AffineMap producerResultMap = producerMaps[1];
+    if (!producerInputMap.isProjectedPermutation() ||
+        !producerResultMap.isPermutation()) {
+      return failure();
+    }
 
     rewriter.modifyOpInPlace(attentionOp, [&]() {
       SmallVector<AffineMap> newIndexingMaps =
           attentionOp.getIndexingMapsArray();
-      AffineMap inputMap = attentionOp.getMatchingIndexingMap(transposeOperand);
-      SmallVector<AffineExpr> newExprs =
-          applyPermutation(inputMap.getResults(), invPerm);
-      AffineMap transposedMap =
-          AffineMap::get(inputMap.getNumDims(), inputMap.getNumSymbols(),
-                         newExprs, rewriter.getContext());
-      newIndexingMaps[inputIndex] = transposedMap;
+      AffineMap consumerInputMap = attentionOp.getMatchingIndexingMap(operand);
+      AffineMap composedMap =
+          producerInputMap.compose(inversePermutation(producerResultMap));
+      newIndexingMaps[inputIndex] = composedMap.compose(consumerInputMap);
       attentionOp.setIndexingMapsAttr(
           rewriter.getAffineMapArrayAttr(newIndexingMaps));
-      attentionOp.setOperand(inputIndex, transposeOp.getDpsInputs()[0]);
+      attentionOp.setOperand(inputIndex, producer.getDpsInputs()[0]);
     });
 
     return success();
```
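
The rewritten pattern no longer extracts an explicit permutation vector: it folds any single-yield `linalg.generic` producer (broadcast, transpose, or both) by composing the producer's input map with the inverse of its result map, then with the consumer's map. As a sanity check of that composition (not part of the commit), here is a minimal standalone sketch, assuming an MLIR development tree to compile and link against; it replays the pattern's arithmetic on the producer/consumer maps of the `@fuse_attention_with_broadcast_transpose` test added below:

```cpp
// Sketch only: replays the pattern's map composition outside the compiler.
// Assumes MLIR headers/libraries (e.g. linked against MLIRIR).
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  auto d = [&](unsigned i) { return getAffineDimExpr(i, &ctx); };

  // Producer (broadcast + transpose of V) from the test below:
  //   input : (d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)
  //   result: (d0, d1, d2, d3, d4) -> (d0, d2, d3, d4, d1)
  AffineMap producerInputMap =
      AffineMap::get(5, 0, {d(0), d(1), d(2), d(4)}, &ctx);
  AffineMap producerResultMap =
      AffineMap::get(5, 0, {d(0), d(2), d(3), d(4), d(1)}, &ctx);

  // The attention op's indexing map for the V operand:
  //   (d0, ..., d7) -> (d0, d1, d2, d5, d7)
  AffineMap consumerInputMap =
      AffineMap::get(8, 0, {d(0), d(1), d(2), d(5), d(7)}, &ctx);

  // Same composition as in the pattern: route the consumer's iteration
  // space through the producer's result space into its input space.
  AffineMap composedMap =
      producerInputMap.compose(inversePermutation(producerResultMap));
  AffineMap newMap = composedMap.compose(consumerInputMap);

  // Prints: (d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d5),
  // matching the CHECK line for the fused V map in the test below.
  newMap.dump();
  return 0;
}
```

The `isProjectedPermutation`/`isPermutation` guards in the pattern are what keep this composition well-defined: a projected permutation on the producer's input models a broadcast (dims may be dropped but never repeated), while a full permutation on its result guarantees that `inversePermutation` exists.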

compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp

Lines changed: 2 additions & 4 deletions
```diff
@@ -37,12 +37,10 @@ void AttentionOpDetail::inferFromIndexingMaps(AffineMap qMap, AffineMap kMap,
   llvm::SmallDenseSet<int64_t> vSet = findPermutationsIndexingOperand(vMap);
   llvm::SmallDenseSet<int64_t> oSet = findPermutationsIndexingOperand(oMap);
 
-  // B = (Q & K & O) U (K & V & O)
+  // B = (Q & V) U (K & O)
   llvm::SmallDenseSet<int64_t> b1Set = qSet;
-  llvm::set_intersect(b1Set, kSet);
-  llvm::set_intersect(b1Set, oSet);
+  llvm::set_intersect(b1Set, vSet);
   llvm::SmallDenseSet<int64_t> b2Set = kSet;
-  llvm::set_intersect(b2Set, vSet);
   llvm::set_intersect(b2Set, oSet);
   llvm::SmallDenseSet<int64_t> bSet = b1Set;
   llvm::set_union(bSet, b2Set);
```
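
Under the old grouping `B = (Q & K & O) U (K & V & O)`, every batch dimension had to be indexed by both K and O, which breaks once one of those operands is broadcasted; the new grouping only requires each batch dimension to survive in one of the two operand pairs. A minimal sketch (not part of the commit, assuming only LLVM's ADT and Support headers) that evaluates the new formula on the dim sets read off the fused maps of the `@fuse_attention_with_broadcast` test added below, where the broadcasted V no longer indexes batch dim d2:

```cpp
// Sketch only: evaluates B = (Q & V) U (K & O) on concrete dim sets.
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  // Dimensions indexed by each operand, read off the fused indexing maps
  // in @fuse_attention_with_broadcast; V has lost the broadcasted d2.
  llvm::SmallDenseSet<int64_t> qSet = {0, 1, 2, 3, 4, 6};
  llvm::SmallDenseSet<int64_t> kSet = {0, 1, 2, 6, 7};
  llvm::SmallDenseSet<int64_t> vSet = {0, 1, 5, 7};
  llvm::SmallDenseSet<int64_t> oSet = {0, 1, 2, 3, 4, 5};

  // B = (Q & V) U (K & O)
  llvm::SmallDenseSet<int64_t> b1Set = qSet;
  llvm::set_intersect(b1Set, vSet); // {0, 1}
  llvm::SmallDenseSet<int64_t> b2Set = kSet;
  llvm::set_intersect(b2Set, oSet); // {0, 1, 2}
  llvm::SmallDenseSet<int64_t> bSet = b1Set;
  llvm::set_union(bSet, b2Set);     // {0, 1, 2}

  // d2 is recovered from the K/O pair even though V dropped it.
  for (int64_t dim : bSet)
    llvm::outs() << dim << " "; // batch dims, in unspecified set order
  llvm::outs() << "\n";
  return 0;
}
```

With the old formula the same sets also happen to yield {0, 1, 2} here, but any configuration where the broadcast removes a batch dim from K or O would lose it; the pairwise grouping keeps the batch inference robust to a broadcast on any single operand.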

compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir

Lines changed: 63 additions & 0 deletions
```diff
@@ -208,6 +208,8 @@ util.func public @fuse_generic_gather2(
 // CHECK-NEXT: %[[RES4:[a-zA-Z0-9]+]] = arith.addf %[[RES2]], %[[RES3]] : f32
 // CHECK-NEXT: linalg.yield %[[RES4]] : f32
 
+// -----
+
 util.func public @fuse_transpose_attention_to_producer(%q: tensor<2x10x4096x64xf16>, %k: tensor<2x10x4096x64xf16>, %quantized_v: tensor<2x10x4096x64xi32>, %quant_offset: tensor<10x64xi32>, %quant_scale: tensor<10x64xf32>, %scale: f16) -> tensor<2x10x4096x64xf16> {
   // Dequantize int-quantization of V
   %init_dequant = tensor.empty() : tensor<2x10x4096x64xf16>
@@ -258,3 +260,64 @@ util.func public @fuse_transpose_attention_to_producer(%q: tensor<2x10x4096x64xf
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[DEQUANT_V]], %[[ARG5]]
+
+// -----
+
+util.func public @fuse_attention_with_broadcast(%arg0: tensor<4x8x128x?xf16>, %arg1: tensor<4x8x4x?x32x128xf16>, %arg2: tensor<4x8x4x?x128xf16>, %arg3: f16, %arg4: tensor<4x8x4x?x32x?xf16>, %arg5: tensor<4x8x4x?x32x128xf16>, %arg6: tensor<4x8x4x128x?xf16>) -> tensor<4x8x4x?x32x128xf16> {
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x8x128x?xf16>) outs(%arg6 : tensor<4x8x4x128x?xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    linalg.yield %in : f16
+  } -> tensor<4x8x4x128x?xf16>
+  %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d5, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>]} ins(%arg1, %arg2, %0, %arg3, %arg4 : tensor<4x8x4x?x32x128xf16>, tensor<4x8x4x?x128xf16>, tensor<4x8x4x128x?xf16>, f16, tensor<4x8x4x?x32x?xf16>) outs(%arg5 : tensor<4x8x4x?x32x128xf16>) {
+  ^bb0(%arg7: f32):
+    iree_linalg_ext.yield %arg7 : f32
+  } -> tensor<4x8x4x?x32x128xf16>
+  util.return %1 : tensor<4x8x4x?x32x128xf16>
+}
+
+// CHECK-LABEL: func public @fuse_attention_with_broadcast
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]:
+// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d7)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG0]], %[[ARG3]], %[[ARG4]] :
+// CHECK: util.return %[[ATTENTION]]
+
+
+// -----
+
+util.func public @fuse_attention_with_broadcast_transpose(%arg0: tensor<4x?x8x128xf16>, %arg1: tensor<4x8x4x?x32x128xf16>, %arg2: tensor<4x8x4x?x128xf16>, %arg3: f16, %arg4: tensor<4x8x4x?x32x?xf16>, %arg5: tensor<4x8x4x?x32x128xf16>, %arg6: tensor<4x8x4x128x?xf16>) -> tensor<4x8x4x?x32x128xf16> {
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3, d4, d1)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x?x8x128xf16>) outs(%arg6 : tensor<4x8x4x128x?xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    linalg.yield %in : f16
+  } -> tensor<4x8x4x128x?xf16>
+  %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d5, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>]} ins(%arg1, %arg2, %0, %arg3, %arg4 : tensor<4x8x4x?x32x128xf16>, tensor<4x8x4x?x128xf16>, tensor<4x8x4x128x?xf16>, f16, tensor<4x8x4x?x32x?xf16>) outs(%arg5 : tensor<4x8x4x?x32x128xf16>) {
+  ^bb0(%arg7: f32):
+    iree_linalg_ext.yield %arg7 : f32
+  } -> tensor<4x8x4x?x32x128xf16>
+  util.return %1 : tensor<4x8x4x?x32x128xf16>
+}
+
+// CHECK-LABEL: func public @fuse_attention_with_broadcast_transpose
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]:
+// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d5)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG0]], %[[ARG3]], %[[ARG4]] :
+// CHECK: util.return %[[ATTENTION]]
```
