Skip to content

Commit 9eaa4ef

Browse files
authored
[tuner]: Add a utility function to query supported MMA intrinsics (#19124)
This PR addresses the task listed in nod-ai/shark-ai#453: add a utility function (`queryMMAIntrinsics`) to query supported MMA intrinsics. A new test pass `TestLLVMGPUQueryMMAPass` has been added to validate the correctness of this utility function, along with a corresponding test to ensure reliable functionality. TODO: The function will be exposed to both the C API and Python in a follow-up PR. --------- Signed-off-by: Bangtian Liu <[email protected]>
1 parent 4c0fd90 commit 9eaa4ef

File tree

9 files changed

+174
-0
lines changed

9 files changed

+174
-0
lines changed

compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ iree_compiler_cc_library(
112112
"ROCDLKernelConfig.cpp",
113113
"ROCDLLowerExecutableTarget.cpp",
114114
"ROCDLSelectLoweringStrategy.cpp",
115+
"TestLLVMGPUQueryMMAPass.cpp",
115116
"Verifiers.cpp",
116117
],
117118
hdrs = [

compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ iree_cc_library(
9797
"ROCDLKernelConfig.cpp"
9898
"ROCDLLowerExecutableTarget.cpp"
9999
"ROCDLSelectLoweringStrategy.cpp"
100+
"TestLLVMGPUQueryMMAPass.cpp"
100101
"Verifiers.cpp"
101102
DEPS
102103
::PassHeaders

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,4 +163,9 @@ def TestLLVMGPUScalarizeMathOpPass :
163163
let summary = "Test pass for several legalization patterns.";
164164
}
165165

166+
// Test-only pass declared on "ModuleOp" and exposed under the command-line
// argument "iree-test-llvmgpu-query-mma".
def TestLLVMGPUQueryMMAPass :
167+
Pass<"iree-test-llvmgpu-query-mma", "ModuleOp"> {
168+
let summary = "Test pass for querying the supported mma intrinsic instructions.";
169+
}
170+
166171
#endif // IREE_CODEGEN_LLVMGPU_PASSES
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Copyright 2024 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
8+
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
9+
#include "mlir/Dialect/Func/IR/FuncOps.h"
10+
11+
#include "llvm/Support/Debug.h"
12+
13+
#define DEBUG_TYPE "iree-test-llvmgpu-query-mma"
14+
15+
namespace mlir::iree_compiler {
16+
17+
#define GEN_PASS_DEF_TESTLLVMGPUQUERYMMAPASS
18+
#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
19+
20+
namespace {
21+
22+
/// Test-only pass that calls `queryMMAIntrinsics` on the module it runs on
/// and prints the result to `llvm::outs()`: for every entry in the returned
/// map it writes one "Executable Variant Name: <name>" line followed by one
/// "MMA Intrinsics: <space-separated intrinsics>" line. The IR itself is not
/// modified by this pass.
struct TestLLVMGPUQueryMMAPass final
23+
: impl::TestLLVMGPUQueryMMAPassBase<TestLLVMGPUQueryMMAPass> {
24+
void runOnOperation() override {
25+
ModuleOp moduleOp = getOperation();
26+
// Maps each executable variant that has a GPU target to its intrinsics.
llvm::SmallDenseMap<IREE::HAL::ExecutableVariantOp,
27+
SmallVector<IREE::GPU::MMAIntrinsic>>
28+
mmaMap = queryMMAIntrinsics(moduleOp);
29+
// NOTE(review): entries are visited in the map's own iteration order, which
// presumably is not guaranteed to match the order of variants in the
// module -- confirm before relying on the order of the printed lines.
for (const auto &[op, mmaAttrs] : mmaMap) {
30+
// NOTE(review): `op` is the map key and is already declared as an
// `ExecutableVariantOp`, so the cast below looks redundant -- confirm.
llvm::outs() << "Executable Variant Name: "
31+
<< cast<IREE::HAL::ExecutableVariantOp>(*op).getName()
32+
<< "\n";
33+
llvm::outs() << "MMA Intrinsics: ";
34+
// Intrinsics are written on a single line, separated by one space.
llvm::interleave(mmaAttrs, llvm::outs(), " ");
35+
llvm::outs() << "\n";
36+
}
37+
}
38+
};
39+
} // namespace
40+
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ iree_lit_test_suite(
5555
"promote_matmul_to_fit_mma.mlir",
5656
"tensor_pad.mlir",
5757
"tensorcore_vectorization.mlir",
58+
"test_query_mma.mlir",
5859
"transform_dialect_bufferize.mlir",
5960
"transform_dialect_eliminate_gpu_barriers.mlir",
6061
"transform_dialect_hoist_allocs.mlir",

compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ iree_lit_test_suite(
5252
"rocdl_pipeline_test.mlir"
5353
"tensor_pad.mlir"
5454
"tensorcore_vectorization.mlir"
55+
"test_query_mma.mlir"
5556
"transform_dialect_bufferize.mlir"
5657
"transform_dialect_eliminate_gpu_barriers.mlir"
5758
"transform_dialect_hoist_allocs.mlir"
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// RUN: iree-opt --split-input-file --iree-test-llvmgpu-query-mma %s | FileCheck %s
2+
3+
// Case 1: a single executable variant (@main) whose target attribute lists
// two MMA intrinsics (MFMA_F32_16x16x4_F32 and MFMA_F32_16x16x16_F16) in its
// `mma` field.
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
4+
{iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "",
5+
wgp = <compute = int32, storage = b32,
6+
subgroup = arithmetic, dot = dp4xi8toi32,
7+
mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>],
8+
subgroup_size_choices = [64], max_workgroup_sizes = [1024],
9+
max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
10+
max_workgroup_counts = [2147483647]>>}>
11+
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
12+
// The variant below references the target alias defined above; the inner
// function body is intentionally empty.
module {
13+
hal.executable private @main {
14+
hal.executable.variant public @main target(#executable_target_rocm_hsaco_fb) {
15+
hal.executable.export public @entry_point layout(#pipeline_layout)
16+
builtin.module {
17+
func.func @fn() {
18+
return
19+
}
20+
}
21+
}
22+
}
23+
}
24+
25+
// CHECK: Executable Variant Name
26+
// CHECK-SAME: main
27+
// CHECK: MMA Intrinsics
28+
// CHECK-SAME: MFMA_F32_16x16x4_F32
29+
// CHECK-SAME: MFMA_F32_16x16x16_F16
30+
// CHECK-LABEL: func.func @fn
31+
32+
// -----
33+
34+
// Case 2: two executables, each with one variant. Variant @main_0 uses a
// target whose `mma` list is [MFMA_F32_16x16x4_F32, MFMA_F32_16x16x16_F16];
// variant @main_1 uses a target whose `mma` list is
// [MFMA_F32_32x32x8_F16, MFMA_F32_16x16x16_BF16].
#executable_target_rocm_hsaco_fb0 = #hal.executable.target<"rocm", "rocm-hsaco-fb",
35+
{iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "",
36+
wgp = <compute = int32, storage = b32,
37+
subgroup = arithmetic, dot = dp4xi8toi32,
38+
mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>],
39+
subgroup_size_choices = [64], max_workgroup_sizes = [1024],
40+
max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
41+
max_workgroup_counts = [2147483647]>>}>
42+
#executable_target_rocm_hsaco_fb1 = #hal.executable.target<"rocm", "rocm-hsaco-fb",
43+
{iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "",
44+
wgp = <compute = int32, storage = b32,
45+
subgroup = arithmetic, dot = dp4xi8toi32,
46+
mma = [<MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x16_BF16>],
47+
subgroup_size_choices = [64], max_workgroup_sizes = [1024],
48+
max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
49+
max_workgroup_counts = [2147483647]>>}>
50+
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
51+
module {
52+
// First executable: variant @main_0 bound to the first target alias.
hal.executable private @main_0 {
53+
hal.executable.variant public @main_0 target(#executable_target_rocm_hsaco_fb0) {
54+
hal.executable.export public @entry_point_0 layout(#pipeline_layout)
55+
builtin.module {
56+
func.func @fn_0() {
57+
return
58+
}
59+
}
60+
}
61+
}
62+
// Second executable: variant @main_1 bound to the second target alias.
hal.executable private @main_1 {
63+
hal.executable.variant public @main_1 target(#executable_target_rocm_hsaco_fb1) {
64+
hal.executable.export public @entry_point layout(#pipeline_layout)
65+
builtin.module {
66+
func.func @fn_1() {
67+
return
68+
}
69+
}
70+
}
71+
}
72+
}
73+
74+
// CHECK-DAG: main_0
75+
// CHECK-DAG: MMA Intrinsics: MFMA_F32_16x16x4_F32 MFMA_F32_16x16x16_F16
76+
// CHECK-DAG: main_1
77+
// CHECK-DAG: MMA Intrinsics: MFMA_F32_32x32x8_F16 MFMA_F32_16x16x16_BF16
78+
79+
// -----
80+
81+
// Case 3: the executable target is declared without a configuration
// dictionary, so it carries no `iree.gpu.target` entry and no `mma` list.
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb">
82+
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
83+
module {
84+
hal.executable private @main {
85+
hal.executable.variant public @main target(#executable_target_rocm_hsaco_fb) {
86+
hal.executable.export public @entry_point layout(#pipeline_layout)
87+
builtin.module {
88+
func.func @fn_empty() {
89+
return
90+
}
91+
}
92+
}
93+
}
94+
}
95+
96+
// CHECK-NOT: Executable Variant Name
97+
// CHECK-NOT: MMA Intrinsics
98+
// CHECK-LABEL: func.func @fn

compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,4 +1028,22 @@ std::optional<int> getGPUSubgroupSize(mlir::FunctionOpInterface func) {
10281028
return std::nullopt;
10291029
}
10301030

1031+
// Collects the MMA intrinsics of every executable variant in `moduleOp`.
//
// Walks all `IREE::HAL::ExecutableVariantOp` ops nested in `moduleOp`. For
// each one for which `getGPUTargetAttr` yields a target, the target's
// `wgp.mma` list is converted to a vector of intrinsic values and stored in
// the result map keyed by that variant op. Variants without such a target
// get no map entry.
llvm::SmallDenseMap<IREE::HAL::ExecutableVariantOp,
1032+
SmallVector<IREE::GPU::MMAIntrinsic>>
1033+
queryMMAIntrinsics(mlir::ModuleOp moduleOp) {
1034+
llvm::SmallDenseMap<IREE::HAL::ExecutableVariantOp,
1035+
SmallVector<IREE::GPU::MMAIntrinsic>>
1036+
mmaAttributesMap;
1037+
moduleOp.walk([&](IREE::HAL::ExecutableVariantOp executableOp) {
1038+
// Skip variants for which no GPU target attribute is found.
if (IREE::GPU::TargetAttr target = getGPUTargetAttr(executableOp)) {
1039+
// Keep only the intrinsic value of each MMA attribute, in list order.
auto mmaIntrinsics = llvm::map_to_vector(
1040+
target.getWgp().getMma(), [](IREE::GPU::MMAAttr attr) {
1041+
return attr.getIntrinsic().getValue();
1042+
});
1043+
mmaAttributesMap[executableOp] = std::move(mmaIntrinsics);
1044+
}
1045+
});
1046+
return mmaAttributesMap;
1047+
}
1048+
10311049
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#define IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_
99

1010
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
11+
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
1112
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
1213
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1314
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -206,6 +207,14 @@ IREE::GPU::TargetAttr getGPUTargetAttr(Operation *op);
206207
/// Returns std::nullopt if none found.
207208
std::optional<int> getGPUSubgroupSize(mlir::FunctionOpInterface func);
208209

210+
/// Returns a map of supported MMA intrinsic instructions based on the
211+
/// GPU target descriptions in `moduleOp`. Each entry in the map associates
212+
/// an `IREE::HAL::ExecutableVariantOp` with a vector of
213+
/// `IREE::GPU::MMAIntrinsic` values, one per entry of that variant's target
/// `mma` list. Variants for which no GPU target is found are not added to
/// the map, so the result may be empty.
214+
llvm::SmallDenseMap<IREE::HAL::ExecutableVariantOp,
215+
SmallVector<IREE::GPU::MMAIntrinsic>>
216+
queryMMAIntrinsics(mlir::ModuleOp moduleOp);
217+
209218
} // namespace mlir::iree_compiler
210219

211220
#endif // IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_

0 commit comments

Comments
 (0)