Skip to content

Commit ad7f009

Browse files
raikonenfnu authored and
Groverkss committed
[Codegen] Bubble up Transpose attention V and try fuse with others before attention (iree-org#19250)
The Flash Attention transpose_V variant is significantly faster than the non-transpose_V variant. This is because many matmul intrinsics are mmtb by default. Hence, doing FA transpose_V allows for better, more contiguous reads from shared memory to registers, improving attention performance quite a bit. This PR exposes the attention_transposeV form by generating a linalg.transpose on V during the bubbling up of transposes, so that we give the graph some opportunities to fuse the transpose-V into its producer. I have also confirmed that if we do not find any producer, the transpose will indeed fuse back into the attentionOp. Hence, in the worst case, we get the same perf as before this PR. Additionally, we modify elementwise op fusion to try to fuse the transpose with other ops before letting it get fused back into attention. --------- Signed-off-by: Stanley Winata <[email protected]>
1 parent f6e52d8 commit ad7f009

File tree

8 files changed

+314
-17
lines changed

8 files changed

+314
-17
lines changed

.github/workflows/pkgci_regression_test.yml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ jobs:
220220
--goldentime-rocm-unet-ms 419.0 \
221221
--goldentime-rocm-clip-ms 18.5 \
222222
--goldentime-rocm-vae-ms 337.0 \
223-
--goldendispatch-rocm-unet 1531 \
223+
--goldendispatch-rocm-unet 1602 \
224224
--goldendispatch-rocm-clip 1139 \
225225
--goldendispatch-rocm-vae 246 \
226226
--goldensize-rocm-unet-bytes 2280000 \
@@ -238,21 +238,21 @@ jobs:
238238
run: |
239239
source ${VENV_DIR}/bin/activate
240240
pytest ./experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py \
241-
--goldentime-rocm-e2e-ms 372.0 \
242-
--goldentime-rocm-unet-ms 95.0 \
241+
--goldentime-rocm-e2e-ms 330.0 \
242+
--goldentime-rocm-unet-ms 80.0 \
243243
--goldentime-rocm-clip-ms 15.5 \
244244
--goldentime-rocm-vae-ms 80.0 \
245-
--goldendispatch-rocm-unet 1531 \
245+
--goldendispatch-rocm-unet 1602 \
246246
--goldendispatch-rocm-clip 1139 \
247247
--goldendispatch-rocm-vae 246 \
248248
--goldensize-rocm-unet-bytes 2270000 \
249249
--goldensize-rocm-clip-bytes 860000 \
250250
--goldensize-rocm-vae-bytes 840000 \
251-
--goldentime-rocm-punet-int8-fp16-ms 55 \
252-
--goldendispatch-rocm-punet-int8-fp16 1284 \
251+
--goldentime-rocm-punet-int8-fp16-ms 53 \
252+
--goldendispatch-rocm-punet-int8-fp16 1424 \
253253
--goldensize-rocm-punet-int8-fp16-bytes 2560000 \
254-
--goldentime-rocm-punet-int8-fp8-ms 59 \
255-
--goldendispatch-rocm-punet-int8-fp8 1564 \
254+
--goldentime-rocm-punet-int8-fp8-ms 53 \
255+
--goldendispatch-rocm-punet-int8-fp8 1704 \
256256
--goldensize-rocm-punet-int8-fp8-bytes 2800000 \
257257
--rocm-chip gfx942 \
258258
--log-cli-level=info \

build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,41 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran
208208
transform.yield %cont, %config : !transform.any_op, !transform.any_param
209209
}
210210

211+
212+
// Variant of matmul_like_Bx20x1024x64x1280_i8xi8xi32 from Transposed-V.
213+
transform.named_sequence @match_matmul_like_Bx20x64x1024x1280_i8xi8xi32(%cont: !transform.any_op {transform.readonly})
214+
-> (!transform.any_op, !transform.any_param) {
215+
%ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {
216+
^bb0(%lhs: tensor<?x1024x1280xi8>, %rhs: tensor<20x64x1280xi8>, %out: tensor<?x20x64x1024xi32>):
217+
%16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
218+
affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d4)>,
219+
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
220+
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
221+
ins(%lhs, %rhs : tensor<?x1024x1280xi8>, tensor<20x64x1280xi8>)
222+
outs(%out : tensor<?x20x64x1024xi32>) {
223+
^bb0(%in: i8, %in_0: i8, %acc: i32):
224+
%18 = arith.extsi %in : i8 to i32
225+
%19 = arith.extsi %in_0 : i8 to i32
226+
%20 = arith.muli %18, %19 : i32
227+
%21 = arith.addi %acc, %20 : i32
228+
linalg.yield %21 : i32
229+
} -> tensor<?x20x64x1024xi32>
230+
} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
231+
%config = transform.param.constant #iree_codegen.compilation_info<
232+
lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
233+
mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
234+
subgroup_m_count = 2, subgroup_n_count = 2,
235+
reduction = [0, 0, 0, 0, 128],
236+
workgroup = [1, 1, 160, 64, 0]}>,
237+
translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
238+
workgroup_size = [256, 1, 1] subgroup_size = 64,
239+
{gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true,
240+
reorder_workgroups_strategy = <Transpose>>
241+
}>
242+
> -> !transform.any_param
243+
transform.yield %cont, %config : !transform.any_op, !transform.any_param
244+
}
245+
211246
transform.named_sequence @match_matmul_like_Bx20x64x64x2048_i8xi8xi32(%cont: !transform.any_op {transform.readonly})
212247
-> (!transform.any_op, !transform.any_param) {
213248
%ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {
@@ -239,6 +274,38 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran
239274
transform.yield %cont, %config : !transform.any_op, !transform.any_param
240275
}
241276

277+
// Variant of matmul_like_Bx20x64x64x2048_i8xi8xi32 from Transposed-V.
278+
transform.named_sequence @match_matmul_like_Bx20x64x64x2048_transposev_i8xi8xi32(%cont: !transform.any_op {transform.readonly})
279+
-> (!transform.any_op, !transform.any_param) {
280+
%ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {
281+
^bb0(%lhs: tensor<?x64x2048xi8>, %rhs: tensor<20x64x2048xi8>, %out: tensor<?x20x64x64xi32>):
282+
%16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
283+
affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d4)>,
284+
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
285+
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
286+
ins(%lhs, %rhs : tensor<?x64x2048xi8>, tensor<20x64x2048xi8>)
287+
outs(%out : tensor<?x20x64x64xi32>) {
288+
^bb0(%in: i8, %in_0: i8, %acc: i32):
289+
%18 = arith.extsi %in : i8 to i32
290+
%19 = arith.extsi %in_0 : i8 to i32
291+
%20 = arith.muli %18, %19 : i32
292+
%21 = arith.addi %acc, %20 : i32
293+
linalg.yield %21 : i32
294+
} -> tensor<?x20x64x64xi32>
295+
} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
296+
%config = transform.param.constant #iree_codegen.compilation_info<
297+
lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
298+
mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
299+
subgroup_m_count = 2, subgroup_n_count = 1,
300+
reduction = [0, 0, 0, 0, 128],
301+
workgroup = [1, 1, 320, 32, 0]}>,
302+
translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
303+
workgroup_size = [128, 1, 1] subgroup_size = 64,
304+
{gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>}>
305+
> -> !transform.any_param
306+
transform.yield %cont, %config : !transform.any_op, !transform.any_param
307+
}
308+
242309
transform.named_sequence @match_matmul_like_Bx10x4096x64x640_i8xi8xi32(%cont: !transform.any_op {transform.readonly})
243310
-> (!transform.any_op, !transform.any_param) {
244311
%ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {
@@ -302,6 +369,10 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran
302369
, @match_matmul_like_Bx10x4096x64x640_i8xi8xi32 -> @apply_op_config
303370
, @match_matmul_like_Bx20x64x64x2048_i8xi8xi32 -> @apply_op_config
304371

372+
// Transpose-V generated contraction.
373+
, @match_matmul_like_Bx20x64x1024x1280_i8xi8xi32 -> @apply_op_config
374+
, @match_matmul_like_Bx20x64x64x2048_transposev_i8xi8xi32 -> @apply_op_config
375+
305376
// TUNING_MATCH_END DO NOT REMOVE
306377
: (!transform.any_op) -> (!transform.any_op)
307378
transform.yield

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ void populateFuseLinalgExtOpsWithTransposes(
1919
RewritePatternSet &patterns,
2020
const linalg::ControlFusionFn &controlFusionFn);
2121

22+
/// Bubble up transpose-like ops from LinalgExt ops (only `AttentionOp`
23+
/// supported).
24+
void populateBubbleTransposeFromLinalgExtOps(
25+
RewritePatternSet &patterns,
26+
const linalg::ControlFusionFn &controlFusionFn);
27+
2228
/// Helper struct to hold the results of collapsing an operation.
2329
struct CollapseResult {
2430
SmallVector<Value> results;

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
88
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
99
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
10+
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
1011
#include "llvm/ADT/STLExtras.h"
1112
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1213
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -101,6 +102,103 @@ struct FuseTransposeWithAttentionOp final
101102
private:
102103
linalg::ControlFusionFn controlFn;
103104
};
105+
106+
// Bubbles transpose-V out of attention to expose the more performant
107+
// attention-transposeV.
108+
struct BubbleTransposeVFromAttentionOp
109+
: public OpRewritePattern<LinalgExt::AttentionOp> {
110+
BubbleTransposeVFromAttentionOp(MLIRContext *context,
111+
linalg::ControlFusionFn controlFn,
112+
PatternBenefit benefit = 1)
113+
: OpRewritePattern<LinalgExt::AttentionOp>(context, benefit),
114+
controlFn(controlFn) {}
115+
116+
LogicalResult matchAndRewrite(LinalgExt::AttentionOp attentionOp,
117+
PatternRewriter &rewriter) const override {
118+
// Only checking for V because we are only bubbling transpose-V.
119+
OpOperand *valueOpOperand = &attentionOp.getValueMutable();
120+
if (controlFn && !controlFn(valueOpOperand)) {
121+
return rewriter.notifyMatchFailure(
122+
attentionOp, "Expected attentionOp and producer of V to be non-null "
123+
"and outside dispatch.");
124+
}
125+
// Extract Attention indexing information.
126+
AffineMap qMap = attentionOp.getQueryMap();
127+
AffineMap kMap = attentionOp.getKeyMap();
128+
AffineMap vMap = attentionOp.getValueMap();
129+
AffineMap oMap = attentionOp.getOutputMap();
130+
FailureOr<AttentionOpDetail> maybeOpInfo =
131+
AttentionOpDetail::get(qMap, kMap, vMap, oMap);
132+
if (failed(maybeOpInfo)) {
133+
return failure();
134+
}
135+
136+
// Only handle single dim for K2 and N for now.
137+
if (maybeOpInfo->getK2Dims().size() != 1 ||
138+
maybeOpInfo->getNDims().size() != 1) {
139+
return failure();
140+
}
141+
// Check that V has standard map/non transposed V.
142+
AffineExpr k2Dim =
143+
rewriter.getAffineDimExpr(maybeOpInfo->getK2Dims().back());
144+
AffineExpr nDim = rewriter.getAffineDimExpr(maybeOpInfo->getNDims().back());
145+
int64_t vRank = vMap.getNumResults();
146+
// TODO: This check is quite conservative, in the future we should simply
147+
// do vMap.getResultPosition(k2Dim) > vMap.getResultPosition(nDim).
148+
if (vMap.getResult(vRank - 1) != nDim ||
149+
vMap.getResult(vRank - 2) != k2Dim) {
150+
return failure();
151+
}
152+
153+
// Get dimension positions to prepare for transpose.
154+
std::optional<int64_t> maybeK2Pos = vMap.getResultPosition(k2Dim);
155+
std::optional<int64_t> maybeNPos = vMap.getResultPosition(nDim);
156+
assert(maybeK2Pos.has_value() && maybeNPos.has_value() &&
157+
"Expected K2 dim and N dim to be in V-map.");
158+
int64_t k2Pos = maybeK2Pos.value();
159+
int64_t nPos = maybeNPos.value();
160+
SmallVector<int64_t> perm = llvm::to_vector(llvm::seq<int64_t>(0, vRank));
161+
std::swap(perm[k2Pos], perm[nPos]);
162+
163+
// Expose transposeOp for V.
164+
Location loc = attentionOp.getLoc();
165+
Value value = attentionOp.getValue();
166+
auto valueType = dyn_cast<ShapedType>(value.getType());
167+
auto valueElType = valueType.getElementType();
168+
SmallVector<OpFoldResult> transVShape =
169+
tensor::getMixedSizes(rewriter, loc, value);
170+
applyPermutationToVector(transVShape, perm);
171+
Value initTransV =
172+
rewriter.create<tensor::EmptyOp>(loc, transVShape, valueElType)
173+
.getResult();
174+
Value transposeV =
175+
rewriter.create<linalg::TransposeOp>(loc, value, initTransV, perm)
176+
->getResult(0);
177+
178+
// Generate transpose V map.
179+
SmallVector<AffineExpr> newExprs =
180+
applyPermutation(vMap.getResults(), perm);
181+
AffineMap transposedVMap =
182+
AffineMap::get(vMap.getNumDims(), vMap.getNumSymbols(), newExprs,
183+
rewriter.getContext());
184+
185+
// Modify attention to have transposed V inputs and mapping.
186+
int64_t valueIndex = valueOpOperand->getOperandNumber();
187+
rewriter.modifyOpInPlace(attentionOp, [&]() {
188+
SmallVector<AffineMap> newIndexingMaps =
189+
attentionOp.getIndexingMapsArray();
190+
newIndexingMaps[valueIndex] = transposedVMap;
191+
attentionOp.setIndexingMapsAttr(
192+
rewriter.getAffineMapArrayAttr(newIndexingMaps));
193+
attentionOp.setOperand(valueIndex, transposeV);
194+
});
195+
return success();
196+
}
197+
198+
private:
199+
linalg::ControlFusionFn controlFn;
200+
};
201+
104202
} // namespace
105203

106204
void populateFuseLinalgExtOpsWithTransposes(
@@ -110,4 +208,11 @@ void populateFuseLinalgExtOpsWithTransposes(
110208
controlFusionFn);
111209
}
112210

211+
void populateBubbleTransposeFromLinalgExtOps(
212+
RewritePatternSet &patterns,
213+
const linalg::ControlFusionFn &controlFusionFn) {
214+
patterns.add<BubbleTransposeVFromAttentionOp>(patterns.getContext(),
215+
controlFusionFn);
216+
}
217+
113218
} // namespace mlir::iree_compiler::IREE::LinalgExt

compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ struct GatherFusionPattern final : public OpRewritePattern<tensor::ExtractOp> {
104104
void ElementwiseOpFusionPass::runOnOperation() {
105105
MLIRContext *context = &getContext();
106106

107-
RewritePatternSet fusionPatterns(context);
108107
// Only fuse operations where all uses of the producer are generic
109108
// operations. If an operation is used in a named op, it will be computed
110109
// anyway, so the consumers can just use that value.
@@ -135,24 +134,35 @@ void ElementwiseOpFusionPass::runOnOperation() {
135134
return areFusableAsElementwiseOps(context, fusedOperand,
136135
fuseMultiReduction);
137136
};
138-
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
137+
138+
RewritePatternSet linalgFusionPatterns(context);
139+
linalg::populateElementwiseOpsFusionPatterns(linalgFusionPatterns,
139140
fuseElementwiseOpsControlFn);
140141

142+
GreedyRewriteConfig rewriteConfig;
143+
rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
144+
if (failed(applyPatternsAndFoldGreedily(
145+
getOperation(), std::move(linalgFusionPatterns), rewriteConfig))) {
146+
getOperation()->emitOpError(
147+
"Failed to fuse elementwise ops with upstream patterns.");
148+
return signalPassFailure();
149+
}
150+
151+
// Try to fuse with linalgExt patterns.
141152
linalg::ControlFusionFn foldTransposeControlFn = [](OpOperand *fusedOperand) {
142153
Operation *producer = fusedOperand->get().getDefiningOp();
143154
Operation *consumer = fusedOperand->getOwner();
144155

145156
return IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer});
146157
};
158+
RewritePatternSet linalgExtFusionPatterns(context);
147159
IREE::LinalgExt::populateFuseLinalgExtOpsWithTransposes(
148-
fusionPatterns, foldTransposeControlFn);
149-
fusionPatterns.insert<GatherFusionPattern>(context);
150-
151-
GreedyRewriteConfig rewriteConfig;
152-
rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
160+
linalgExtFusionPatterns, foldTransposeControlFn);
161+
linalgExtFusionPatterns.insert<GatherFusionPattern>(context);
153162
if (failed(applyPatternsAndFoldGreedily(
154-
getOperation(), std::move(fusionPatterns), rewriteConfig))) {
155-
getOperation()->emitOpError("Failed to perform elementwise operations");
163+
getOperation(), std::move(linalgExtFusionPatterns), rewriteConfig))) {
164+
getOperation()->emitOpError(
165+
"Failed to fuse elementwise ops with linalgExt patterns.");
156166
return signalPassFailure();
157167
}
158168
}

compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,54 @@ util.func public @fuse_generic_gather2(
207207
// CHECK-NEXT: %[[RES3:[a-zA-Z0-9]+]] = arith.mulf %[[RES]], %[[RES]] : f32
208208
// CHECK-NEXT: %[[RES4:[a-zA-Z0-9]+]] = arith.addf %[[RES2]], %[[RES3]] : f32
209209
// CHECK-NEXT: linalg.yield %[[RES4]] : f32
210+
211+
util.func public @fuse_transpose_attention_to_producer(%q: tensor<2x10x4096x64xf16>, %k: tensor<2x10x4096x64xf16>, %quantized_v: tensor<2x10x4096x64xi32>, %quant_offset: tensor<10x64xi32>, %quant_scale: tensor<10x64xf32>, %scale: f16) -> tensor<2x10x4096x64xf16> {
212+
// Dequantize int-quantization of V
213+
%init_dequant = tensor.empty() : tensor<2x10x4096x64xf16>
214+
%v = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%quantized_v, %quant_offset, %quant_scale : tensor<2x10x4096x64xi32>, tensor<10x64xi32>, tensor<10x64xf32>) outs(%init_dequant : tensor<2x10x4096x64xf16>) {
215+
^bb0(%in: i32, %in_0: i32, %in_1: f32, %out: f16):
216+
%19 = arith.addi %in, %in_0 : i32
217+
%20 = arith.sitofp %19 : i32 to f32
218+
%21 = arith.mulf %20, %in_1 : f32
219+
%22 = arith.truncf %21 : f32 to f16
220+
linalg.yield %22 : f16
221+
} -> tensor<2x10x4096x64xf16>
222+
223+
// Transpose-V
224+
%init_transpose = tensor.empty() : tensor<2x10x64x4096xf16>
225+
%transpose_v = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%v : tensor<2x10x4096x64xf16>) outs(%init_transpose : tensor<2x10x64x4096xf16>) {
226+
^bb0(%in: f16, %out: f16):
227+
linalg.yield %in : f16
228+
} -> tensor<2x10x64x4096xf16>
229+
230+
// Attention-Transpose-V
231+
%init_attention = tensor.empty() : tensor<2x10x4096x64xf16>
232+
%attention = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%q, %k, %transpose_v, %scale : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>, f16) outs(%init_attention : tensor<2x10x4096x64xf16>) {
233+
^bb0(%score: f16):
234+
iree_linalg_ext.yield %score: f16
235+
} -> tensor<2x10x4096x64xf16>
236+
util.return %attention : tensor<2x10x4096x64xf16>
237+
}
238+
239+
// CHECK-LABEL: util.func public @fuse_transpose_attention_to_producer
240+
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
241+
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
242+
// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]: tensor
243+
// CHECK-SAME: %[[ARG3:[A-Za-z0-9]+]]: tensor
244+
// CHECK-SAME: %[[ARG4:[A-Za-z0-9]+]]: tensor
245+
// CHECK-SAME: %[[ARG5:[A-Za-z0-9]+]]: f16
246+
// CHECK: %[[DEQUANT_V:.+]] = linalg.generic
247+
// CHECK-SAME: indexing_maps =
248+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
249+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1, d3)>
250+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1, d3)>
251+
// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>]
252+
// CHECK-SAME: ins(%[[ARG2]], %[[ARG3]], %[[ARG4]]
253+
// CHECK: %[[RESULT:.+]] = iree_linalg_ext.attention
254+
// CHECK-SAME: indexing_maps =
255+
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
256+
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
257+
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
258+
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
259+
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
260+
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[DEQUANT_V]], %[[ARG5]]

0 commit comments

Comments
 (0)