Skip to content

Commit 72d98bc

Browse files
authored
GPU ukernel lowering config for data-tiled multi_mma, and a simple ukernel. (iree-org#19504)
This PR adds the KernelConfig logic to generate a lowering_config selecting a ukernel for multi_mma. In order to be able to test it, this PR also adds a very simple `multi_mma` ukernel, but it isn't actually exercised yet, other than successfully compiling to bitcode. The compiler logic only cares about the existence of the resulting bitcode file. The actual lowering to ukernel op will come in the next PR. --------- Signed-off-by: Benoit Jacob <[email protected]>
1 parent a31da1f commit 72d98bc

File tree

12 files changed

+177
-43
lines changed

12 files changed

+177
-43
lines changed

compiler/plugins/target/ROCM/builtins/ukernel/BUILD.bazel

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ argmax_types = [
4646
[iree_amdgpu_bitcode_library(
4747
name = "iree_uk_amdgpu_argmax_%s_%s" % (type, gpu_arch),
4848
srcs = [
49-
"iree_uk_amdgpu_argmax_%s.c" % type,
5049
"common.h",
50+
"iree_uk_amdgpu_argmax_%s.c" % type,
5151
],
5252
out = "iree_uk_amdgpu_argmax_%s.%s.bc" % (type, gpu_arch),
5353
gpu_arch = gpu_arch,
@@ -59,9 +59,21 @@ argmax_bc_files = [
5959
for gpu_arch in gpu_archs
6060
]
6161

62+
iree_amdgpu_bitcode_library(
63+
name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4_gfx942",
64+
srcs = [
65+
"common.h",
66+
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.c",
67+
],
68+
out = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc",
69+
gpu_arch = "gfx942",
70+
)
71+
6272
iree_c_embed_data(
6373
name = "iree_uk_amdgpu_bitcode",
64-
srcs = argmax_bc_files,
74+
srcs = argmax_bc_files + [
75+
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc",
76+
],
6577
c_file_output = "iree_uk_amdgpu_bitcode.c",
6678
flatten = True,
6779
h_file_output = "iree_uk_amdgpu_bitcode.h",

compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,18 @@ iree_amdgpu_bitcode_library(
206206
"iree_uk_amdgpu_argmax_f32i64.gfx1100.bc"
207207
)
208208

209+
iree_amdgpu_bitcode_library(
210+
NAME
211+
iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4_gfx942
212+
GPU_ARCH
213+
gfx942
214+
SRCS
215+
"common.h"
216+
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.c"
217+
OUT
218+
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc"
219+
)
220+
209221
iree_c_embed_data(
210222
NAME
211223
iree_uk_amdgpu_bitcode
@@ -226,6 +238,7 @@ iree_c_embed_data(
226238
"iree_uk_amdgpu_argmax_f32i64.gfx1100.bc"
227239
"iree_uk_amdgpu_argmax_f32i64.gfx90a.bc"
228240
"iree_uk_amdgpu_argmax_f32i64.gfx942.bc"
241+
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc"
229242
C_FILE_OUTPUT
230243
"iree_uk_amdgpu_bitcode.c"
231244
H_FILE_OUTPUT

compiler/plugins/target/ROCM/builtins/ukernel/common.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ typedef __UINT64_TYPE__ uint64_t;
5757
#define FLT_MIN __FLT_MIN__
5858
#define FLT_MAX __FLT_MAX__
5959

60+
//===----------------------------------------------------------------------===//
61+
// Vector typedefs
62+
//===----------------------------------------------------------------------===//
63+
64+
typedef __attribute__((__vector_size__(8 * 2))) int64_t int64x2_t;
65+
typedef __attribute__((__vector_size__(4 * 4))) int32_t int32x4_t;
66+
6067
//===----------------------------------------------------------------------===//
6168
// Declarations for Clangd, which may be slightly older than actual clang.
6269
// Drop these as clangd versions used in practice gain these builtins.
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright 2024 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"
8+
9+
// Very naive kernel. TODO(bjacob):
10+
// 1. Shared memory: can't allocate it within the microkernel (which is just a
11+
// helper device function, not the actual amdgpu_kernel). Need to get it
12+
// passed down here as a `T [[clang::address_space(3)]] *` parameter.
13+
// 2. Better scheduling via either barrier intrinsics or inline assemby.
14+
// 3. Subgroups1x4 being asymmetric is a historical accident... should be 2x2.
15+
[[clang::always_inline]] void
16+
iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4(
17+
const int8_t *a_buffer, int64_t a_offset, const int8_t *b_buffer,
18+
int64_t b_offset, int32_t *c_buffer, int64_t c_offset, int64_t k_size) {
19+
int tid = __builtin_amdgcn_workitem_id_x();
20+
21+
// Load existing accumulators.
22+
int32x4_t acc[8][2] = {{0}};
23+
int32x4_t *c_global = (int32x4_t *)(c_buffer + c_offset);
24+
for (int i = 0; i < 8; ++i) {
25+
for (int j = 0; j < 2; ++j) {
26+
acc[i][j] = c_global[256 * (2 * i + j) + tid];
27+
}
28+
}
29+
30+
// Arithmetic loop.
31+
const int64x2_t *a_global =
32+
(const int64x2_t *)(a_buffer + a_offset) + (tid % 64);
33+
const int64x2_t *b_global = (const int64x2_t *)(b_buffer + b_offset) + tid;
34+
for (int k_outer = 0; k_outer < k_size; ++k_outer) {
35+
for (int i = 0; i < 8; ++i) {
36+
for (int j = 0; j < 2; ++j) {
37+
for (int k = 0; k < 2; ++k) {
38+
acc[i][j] = __builtin_amdgcn_mfma_i32_16x16x32_i8(
39+
a_global[64 * i][k], b_global[256 * j][k], acc[i][j], 0, 0, 0);
40+
}
41+
}
42+
}
43+
a_global += 512;
44+
b_global += 512;
45+
}
46+
47+
// Store accumulators.
48+
for (int i = 0; i < 8; ++i) {
49+
for (int j = 0; j < 2; ++j) {
50+
c_global[256 * (2 * i + j) + tid] = acc[i][j];
51+
}
52+
}
53+
}

compiler/plugins/target/ROCM/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ iree_lit_test_suite(
1717
srcs = [
1818
"config_ukernel_argmax_gfx908.mlir",
1919
"config_ukernel_argmax_gfx942.mlir",
20+
"config_ukernel_multi_mma_gfx942.mlir",
2021
"default_tuning_specs_amdgpu.mlir",
2122
"lowering_strategy_from_tuning_spec.mlir",
2223
"ukernel_pipeline_transform.mlir",

compiler/plugins/target/ROCM/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ iree_lit_test_suite(
1616
SRCS
1717
"config_ukernel_argmax_gfx908.mlir"
1818
"config_ukernel_argmax_gfx942.mlir"
19+
"config_ukernel_multi_mma_gfx942.mlir"
1920
"default_tuning_specs_amdgpu.mlir"
2021
"lowering_strategy_from_tuning_spec.mlir"
2122
"ukernel_pipeline_transform.mlir"
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline='builtin.module(iree-llvmgpu-select-lowering-strategy)' %s | FileCheck %s
2+
3+
func.func @multi_mma_mfma_i32_16x16x32_i8(%a : tensor<1x2x8x4x16x2x8xi8>,
4+
%b : tensor<1x2x4x2x4x16x2x8xi8>,
5+
%c : tensor<1x1x8x4x2x4x16x4xi32>)
6+
-> tensor<1x1x8x4x2x4x16x4xi32> attributes {
7+
hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "multi_mma"}>
8+
} {
9+
%d = iree_gpu.multi_mma %a, %b, %c {indexing_maps = [
10+
affine_map<(d0, d1, d2) -> (d0, d2)>,
11+
affine_map<(d0, d1, d2) -> (d1, d2)>,
12+
affine_map<(d0, d1, d2) -> (d0, d1)>
13+
], iterator_types = [
14+
#iree_gpu.iterator_type<parallel>,
15+
#iree_gpu.iterator_type<parallel>,
16+
#iree_gpu.iterator_type<reduction>
17+
], kind = #iree_gpu.data_tiled_mma_layout<
18+
intrinsic = MFMA_I32_16x16x32_I8,
19+
unroll_m = 8, unroll_n = 2, subgroups_n = 4, unroll_k = 2
20+
>} : tensor<1x2x8x4x16x2x8xi8>, tensor<1x2x4x2x4x16x2x8xi8> into tensor<1x1x8x4x2x4x16x4xi32>
21+
return %d : tensor<1x1x8x4x2x4x16x4xi32>
22+
}
23+
24+
// CHECK-LABEL: @multi_mma_mfma_i32_16x16x32_i8
25+
// CHECK: iree_gpu.multi_mma
26+
// CHECK-SAME: #hal.executable.object<{path = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc"
27+
// CHECK-NOT: promote_operands
28+
// CHECK-SAME: reduction = [0, 0, 0]
29+
// CHECK-SAME: #iree_gpu.ukernel_config<name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4"

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,9 @@ namespace mlir::iree_compiler::IREE::GPU {
3333

3434
constexpr int64_t kCacheLineSizeBits = 128 * 8;
3535

36-
LogicalResult
37-
setDataTiledMultiMmaLoweringConfig(IREE::GPU::TargetAttr target,
38-
mlir::FunctionOpInterface entryPoint,
39-
Operation *op) {
36+
LogicalResult setDataTiledMultiMmaLoweringConfig(
37+
IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint,
38+
Operation *op, IREE::GPU::UKernelConfigAttr ukernelConfig) {
4039
auto multiMmaOp = dyn_cast<IREE::GPU::MultiMmaOp>(op);
4140
if (!multiMmaOp) {
4241
return failure();
@@ -70,7 +69,7 @@ setDataTiledMultiMmaLoweringConfig(IREE::GPU::TargetAttr target,
7069
SmallVector<int64_t> reductionTileSizes(iterationRank, 0);
7170
for (int64_t kDim : contractionDims.k) {
7271
workgroupTileSizes[kDim] = 0;
73-
reductionTileSizes[kDim] = 1;
72+
reductionTileSizes[kDim] = ukernelConfig ? 0 : 1;
7473
}
7574

7675
// Set tile sizes.
@@ -81,8 +80,16 @@ setDataTiledMultiMmaLoweringConfig(IREE::GPU::TargetAttr target,
8180
b.getI64ArrayAttr(workgroupTileSizes));
8281
attrs.emplace_back(b.getStringAttr("reduction"),
8382
b.getI64ArrayAttr(reductionTileSizes));
84-
// Promote operands to use shared memory for LHS and RHS.
85-
GPU::setPromotedOperandList(context, attrs, {0, 1});
83+
if (ukernelConfig) {
84+
attrs.emplace_back(b.getStringAttr("ukernel"), ukernelConfig);
85+
} else {
86+
// Promote operands to use shared memory for LHS and RHS.
87+
// Don't do that with ukernels: their untiled reduction dimension is too
88+
// large to fit in shared memory, so they just want global memory and they
89+
// will take care of moving small chunks at a time into a shared memory
90+
// operand that will be created together with the ukernel op.
91+
GPU::setPromotedOperandList(context, attrs, {0, 1});
92+
}
8693
auto configDict = b.getDictionaryAttr(attrs);
8794
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
8895

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@ namespace mlir::iree_compiler::IREE::GPU {
1616

1717
/// Helper for setting up a data tiled multi_mma config based on the specified
1818
/// target.
19-
LogicalResult
20-
setDataTiledMultiMmaLoweringConfig(IREE::GPU::TargetAttr target,
21-
mlir::FunctionOpInterface entryPoint,
22-
Operation *op);
19+
LogicalResult setDataTiledMultiMmaLoweringConfig(
20+
IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint,
21+
Operation *op, IREE::GPU::UKernelConfigAttr ukernelConfig);
2322

2423
/// Helper for setting up a convolution config using IGEMM based on the
2524
/// specified target.

compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2099,15 +2099,9 @@ static LogicalResult setTransposeConfig(mlir::FunctionOpInterface entryPoint,
20992099
/// Set the configuration for argmax when ukernels are enabled.
21002100
/// Distribute all parallel dim across different workgroups, and only use single
21012101
/// subgroup per workgroup.
2102-
static LogicalResult
2103-
setArgmaxUkernelConfig(IREE::GPU::TargetAttr target,
2104-
mlir::FunctionOpInterface entryPoint,
2105-
linalg::GenericOp op) {
2106-
IREE::GPU::UKernelConfigAttr ukernelConfig = selectUKernel(op);
2107-
if (!ukernelConfig) {
2108-
return failure();
2109-
}
2110-
2102+
static LogicalResult setArgmaxUkernelConfig(
2103+
IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint,
2104+
linalg::GenericOp op, IREE::GPU::UKernelConfigAttr ukernelConfig) {
21112105
SmallVector<unsigned> parallelDims;
21122106
SmallVector<unsigned> reductionDims;
21132107
op.getParallelDims(parallelDims);
@@ -2170,15 +2164,6 @@ setArgmaxUkernelConfig(IREE::GPU::TargetAttr target,
21702164
return success();
21712165
}
21722166

2173-
/// Make UKernels take the LLVMGPUDefault lowering pipeline.
2174-
static LogicalResult
2175-
setUKernelConfig(mlir::FunctionOpInterface entryPoint,
2176-
IREE::Codegen::UKernelOpInterface ukernelOp) {
2177-
auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
2178-
entryPoint->getContext(), CodeGenPipeline::LLVMGPUDefault);
2179-
return setTranslationInfo(entryPoint, translationInfo);
2180-
}
2181-
21822167
/// Decides the tiling and distribution parameters for one convolution
21832168
/// dimension. Returns true if we can succesfully deduce.
21842169
///
@@ -2358,13 +2343,14 @@ static LogicalResult setConvolutionConfig(
23582343
static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
23592344
mlir::FunctionOpInterface entryPointFn,
23602345
Operation *computeOp) {
2346+
IREE::GPU::UKernelConfigAttr ukernelConfig = selectUKernel(computeOp);
23612347
LLVM_DEBUG({
23622348
DBGS() << "Selecting root config for: ";
23632349
computeOp->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
23642350
llvm::dbgs() << "\n";
23652351
});
23662352
if (succeeded(setDataTiledMultiMmaLoweringConfig(target, entryPointFn,
2367-
computeOp))) {
2353+
computeOp, ukernelConfig))) {
23682354
LDBG("Tile and fuse data tiled multi_mma config");
23692355
return success();
23702356
}
@@ -2410,8 +2396,9 @@ static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
24102396
if (genericOp && succeeded(setTransposeConfig(entryPointFn, genericOp))) {
24112397
LDBG("Transpose Config");
24122398
return success();
2413-
} else if (genericOp && succeeded(setArgmaxUkernelConfig(
2414-
target, entryPointFn, genericOp))) {
2399+
} else if (genericOp && ukernelConfig &&
2400+
succeeded(setArgmaxUkernelConfig(target, entryPointFn, genericOp,
2401+
ukernelConfig))) {
24152402
LDBG("Argmax Ukernel Config");
24162403
return success();
24172404
}
@@ -2435,10 +2422,6 @@ static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
24352422
LDBG("Pack Config");
24362423
return setPackConfig(target, entryPointFn, packOp);
24372424
})
2438-
.Case<IREE::Codegen::UKernelOpInterface>([&](auto ukernelOp) {
2439-
LDBG("Ukernel Config");
2440-
return setUKernelConfig(entryPointFn, ukernelOp);
2441-
})
24422425
.Case<IREE::LinalgExt::CustomOp>([&](auto customOp) {
24432426
LDBG("CustomOp Config");
24442427
return setDefaultCustomOpLoweringConfig(entryPointFn, customOp,

0 commit comments

Comments
 (0)