Commit e6fb1e1

[Codegen] PV and QK matmuls must have same acc layout (iree-org#21729)

Fixes issue iree-org#21602, where vector distribute failed due to an unresolvable layout change in attention. Check that the two matmuls have the same accumulator layout. With this change, the reproducer in iree-org#21602 compiles down to a .vmfb. I have not checked numerics or looked at any performance benchmarks.

Signed-off-by: James Newling <[email protected]>
1 parent 9bb1a2b commit e6fb1e1
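
The added constraint boils down to an equality check on the accumulator (Acc) fragment layouts of the QK and PV intrinsics: the element counts, thread counts, and thread strides of the two layouts must all match. A minimal standalone sketch of that check, using a simplified stand-in for IREE::GPU::MMASingleSubgroupLayout and made-up layout values (illustrative only, not the real MFMA layouts):

#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-in for IREE::GPU::MMASingleSubgroupLayout: per-dimension
// element counts, thread counts, and thread strides of one MMA fragment.
struct SubgroupLayout {
  std::vector<int64_t> element;
  std::vector<int64_t> thread;
  std::vector<int64_t> tstrides;
};

// Mirrors the matchLayout helper added in this commit: two accumulator
// layouts are compatible only if all three components are identical.
static bool matchLayout(const SubgroupLayout &a, const SubgroupLayout &b) {
  return a.element == b.element && a.thread == b.thread &&
         a.tstrides == b.tstrides;
}

int main() {
  // Hypothetical Acc layouts of two candidate intrinsics (made-up values).
  SubgroupLayout accA{{4, 1}, {4, 16}, {16, 1}};
  SubgroupLayout accB{{1, 1}, {32, 2}, {1, 32}};

  std::cout << "same intrinsic pair kept:  " << matchLayout(accA, accA) << "\n"; // 1
  std::cout << "mixed intrinsic pair kept: " << matchLayout(accA, accB) << "\n"; // 0
}

Only QK/PV intrinsic pairs for which this check succeeds are considered when deducing the attention schedule; previously a pair with differing Acc layouts could be selected, leading to the unresolvable layout change reported in iree-org#21602.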

File tree

3 files changed: +59, -18 lines


compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp

Lines changed: 29 additions & 16 deletions
@@ -9,6 +9,7 @@

 #include <cstdint>

+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/Support/DebugLog.h"
@@ -666,6 +667,13 @@ struct ChainedMMAIntrinsics {
   bool canReuseAOutputForB;
 };

+static bool matchLayout(IREE::GPU::MMASingleSubgroupLayout layoutA,
+                        IREE::GPU::MMASingleSubgroupLayout layoutB) {
+  return (layoutA.element == layoutB.element) &&
+         (layoutA.thread == layoutB.thread) &&
+         (layoutA.tstrides == layoutB.tstrides);
+};
+
 FailureOr<std::pair<GPUMMASchedule, GPUMMASchedule>> deduceAttentionSchedule(
     const GPUMatmulShapeType &qkMatmul, const GPUMatmulShapeType &pvMatmul,
     ArrayRef<GPUIntrinsicType> intrinsics,
@@ -677,28 +685,33 @@ FailureOr<std::pair<GPUMMASchedule, GPUMMASchedule>> deduceAttentionSchedule(
       qkMatmul.nSizes.size() == 1 && qkMatmul.kSizes.size() == 1 &&
       "unimplemented: multi M/N/K attention schedule");

-  std::vector<ChainedMMAIntrinsics> intrinsicPairs;
+  SmallVector<uint64_t> qkViableIntrinsicIndices;
+  SmallVector<uint64_t> pvViableIntrinsicIndices;
+  for (const auto &[index, intrinsic] : llvm::enumerate(intrinsics)) {
+    if (!failed(canTargetIntrinsic(qkMatmul, intrinsic, subgroupSize,
+                                   canUpcastAcc, mustBeAligned))) {
+      qkViableIntrinsicIndices.push_back(index);
+    }
+    if (!failed(canTargetIntrinsic(pvMatmul, intrinsic, subgroupSize,
+                                   canUpcastAcc, mustBeAligned))) {
+      pvViableIntrinsicIndices.push_back(index);
+    }
+  }

-  for (const GPUIntrinsicType &intrinsicA : intrinsics) {
-    for (const GPUIntrinsicType &intrinsicB : intrinsics) {
-      if (failed(canTargetIntrinsic(qkMatmul, intrinsicA, subgroupSize,
-                                    canUpcastAcc, mustBeAligned))) {
+  std::vector<ChainedMMAIntrinsics> intrinsicPairs;
+  for (unsigned qkIndex : qkViableIntrinsicIndices) {
+    for (unsigned pvIndex : pvViableIntrinsicIndices) {
+      const GPUIntrinsicType &intrinsicA = intrinsics[qkIndex];
+      const GPUIntrinsicType &intrinsicB = intrinsics[pvIndex];
+      if (!matchLayout(getSingleSubgroupLayout(intrinsicA.mmaKind,
+                                               IREE::GPU::MMAFragment::Acc),
+                       getSingleSubgroupLayout(intrinsicB.mmaKind,
+                                               IREE::GPU::MMAFragment::Acc))) {
         continue;
       }

-      if (failed(canTargetIntrinsic(pvMatmul, intrinsicB, subgroupSize,
-                                    canUpcastAcc, mustBeAligned))) {
-        continue;
-      }
       // Check if we can reuse the output of intrinsicA for lhs/rhs of
       // intrinsicB.
-      auto matchLayout =
-          [](IREE::GPU::MMASingleSubgroupLayout layoutA,
-             IREE::GPU::MMASingleSubgroupLayout layoutB) -> bool {
-        return (layoutA.element == layoutB.element) &&
-               (layoutA.thread == layoutB.thread) &&
-               (layoutA.tstrides == layoutB.tstrides);
-      };
       bool canReuseAOutForBLhs =
           matchLayout(getSingleSubgroupLayout(intrinsicA.mmaKind,
                                               IREE::GPU::MMAFragment::Acc),
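
As a design note on the diff above: the viability checks (canTargetIntrinsic) are now evaluated once per intrinsic for each matmul, and only the surviving candidates are paired and then filtered by the new accumulator-layout constraint. A simplified, self-contained sketch of that filter-then-pair shape, with placeholder types and predicates rather than IREE's actual APIs:

#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Placeholder for a candidate MMA intrinsic (IREE's real type is GPUIntrinsicType).
struct Intrinsic {
  int id;
};

// Pass 1: record which intrinsics are viable for each matmul.
// Pass 2: pair viable candidates, keeping only pairs with matching Acc layouts.
std::vector<std::pair<std::size_t, std::size_t>> collectPairs(
    const std::vector<Intrinsic> &intrinsics,
    const std::function<bool(const Intrinsic &)> &viableForQK,
    const std::function<bool(const Intrinsic &)> &viableForPV,
    const std::function<bool(const Intrinsic &, const Intrinsic &)> &sameAccLayout) {
  std::vector<std::size_t> qkOk, pvOk;
  for (std::size_t i = 0; i < intrinsics.size(); ++i) {
    if (viableForQK(intrinsics[i]))
      qkOk.push_back(i);
    if (viableForPV(intrinsics[i]))
      pvOk.push_back(i);
  }
  std::vector<std::pair<std::size_t, std::size_t>> pairs;
  for (std::size_t qk : qkOk)
    for (std::size_t pv : pvOk)
      if (sameAccLayout(intrinsics[qk], intrinsics[pv]))
        pairs.emplace_back(qk, pv);
  return pairs;
}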

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h

Lines changed: 0 additions & 2 deletions
@@ -17,8 +17,6 @@
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"

 namespace mlir::iree_compiler::IREE::GPU {


compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx950.mlir

Lines changed: 30 additions & 0 deletions
@@ -282,3 +282,33 @@ func.func @attention_large_head_dim_shared_mem() {
 // CHECK-SAME: subgroup_n_count = 1
 // CHECK-SAME: reduction = [0, 0, 64, 0]
 // CHECK-SAME: workgroup = [64, 0, 0, 64]
+
+// -----
+
+// The fix introduced for bug https://github.com/iree-org/iree/issues/21602 was
+// to constrain the MMA layout to be the same for the 2 matmuls inside
+// attention. Before this fix, the PV matmul used MFMA_F32_16x16x128_F8E4M3FN
+// and the QK matmul used MFMA_F32_32x32x64_F8E4M3FN. Vector distribution failed
+// to distribute these layouts to threads.
+
+// CHECK: #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {}>
+// CHECK-LABEL: func.func @attention_check_mma_accs_compatable
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+func.func @attention_check_mma_accs_compatable(%arg0: f32, %arg1: tensor<960x4096x64xf8E4M3FN>, %arg2: tensor<960x4096x64xf8E4M3FN>, %arg3: tensor<960x4096x64xf8E4M3FN>, %arg4: tensor<960x4096x64xf32>, %arg5: !iree_tensor_ext.dispatch.tensor<writeonly:tensor<960x4096x64xf32>>) {
+  %0 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%arg1, %arg2, %arg3, %arg0 : tensor<960x4096x64xf8E4M3FN>, tensor<960x4096x64xf8E4M3FN>, tensor<960x4096x64xf8E4M3FN>, f32) outs(%arg4 : tensor<960x4096x64xf32>) {
+  ^bb0(%arg6: f32):
+    iree_linalg_ext.yield %arg6 : f32
+  } -> tensor<960x4096x64xf32>
+  iree_tensor_ext.dispatch.tensor.store %0, %arg5, offsets = [0, 0, 0], sizes = [960, 4096, 64], strides = [1, 1, 1] : tensor<960x4096x64xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<960x4096x64xf32>>
+  return
+}
+// CHECK: decomposition_config =
+// CHECK-SAME: attention_pv_matmul
+// CHECK-SAME: #iree_gpu.mma_layout<MFMA_F32_32x32x64_F8E4M3FN>
+// CHECK-SAME: attention_qk_matmul
+// CHECK-SAME: #iree_gpu.mma_layout<MFMA_F32_32x32x64_F8E4M3FN>
