Skip to content

Commit e1ce3fa

Browse files
authored
[tuner]: add c/python binding for querying mma intrinsic (#19218)
After this PR: #19199 add Python bindings to these two utility functions to querying mma intrinsic instructions from input module. --------- Signed-off-by: Bangtian Liu <[email protected]>
1 parent 1654ce6 commit e1ce3fa

File tree

10 files changed

+114
-2
lines changed

10 files changed

+114
-2
lines changed

compiler/bindings/c/iree/compiler/dialects/iree_codegen.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ MLIR_CAPI_EXPORTED MlirAttribute ireeCodegenCompilationInfoAttrGet(
6666
MLIR_CAPI_EXPORTED ireeCodegenCompilationInfoParameters
6767
ireeCodegenCompilationInfoAttrGetParameters(MlirAttribute attr);
6868

69+
MLIR_CAPI_EXPORTED void
70+
ireeCodegenGetExecutableVariantOps(MlirModule module, size_t *numOps,
71+
MlirOperation *executableOps);
72+
73+
MLIR_CAPI_EXPORTED void ireeCodegenQueryMMAIntrinsics(MlirOperation op,
74+
size_t *numIntrinsics,
75+
uint32_t *mmaIntrinsics);
76+
6977
#ifdef __cplusplus
7078
}
7179
#endif

compiler/bindings/python/IREECompilerDialectsModule.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,33 @@ static const char *kGpuModuleImportPath =
2121
namespace py = pybind11;
2222
using namespace mlir::python::adaptors;
2323

24+
static std::vector<MlirOperation>
25+
ireeCodegenGetExecutableVariantOpsBinding(MlirModule module) {
26+
size_t numOps = 0;
27+
ireeCodegenGetExecutableVariantOps(module, &numOps, nullptr);
28+
std::vector<MlirOperation> ops(numOps);
29+
ireeCodegenGetExecutableVariantOps(module, &numOps, ops.data());
30+
31+
return ops;
32+
}
33+
34+
static std::vector<py::object>
35+
ireeCodegenQueryMMAIntrinsicsBinding(MlirOperation op) {
36+
size_t numMMAs = 0;
37+
ireeCodegenQueryMMAIntrinsics(op, &numMMAs, nullptr);
38+
std::vector<uint32_t> mmaIntrinsics(numMMAs);
39+
ireeCodegenQueryMMAIntrinsics(op, &numMMAs, mmaIntrinsics.data());
40+
41+
py::object mmaIntrinsicEnum =
42+
py::module_::import(kGpuModuleImportPath).attr("MMAIntrinsic");
43+
std::vector<py::object> mmaList(numMMAs);
44+
for (size_t i = 0; i < numMMAs; ++i) {
45+
mmaList[i] = mmaIntrinsicEnum(mmaIntrinsics[i]);
46+
}
47+
48+
return mmaList;
49+
}
50+
2451
PYBIND11_MODULE(_ireeCompilerDialects, m) {
2552
m.doc() = "iree-compiler dialects python extension";
2653

@@ -326,4 +353,22 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
326353
"Gets an #iree_gpu.lowering_config from parameters.")
327354
.def_property_readonly("attributes",
328355
ireeGPULoweringConfigAttrGetAttributes);
356+
357+
//===-------------------------------------------------------------------===//
358+
// Binding to utility function getExecutableVariantOps
359+
//===-------------------------------------------------------------------===//
360+
361+
iree_codegen_module.def(
362+
"get_executable_variant_ops", &ireeCodegenGetExecutableVariantOpsBinding,
363+
"Gets the executable variant operations from a module.",
364+
py::arg("module"));
365+
366+
//===-------------------------------------------------------------------===//
367+
// Binding to utility function queryMMAIntrinsics
368+
//===-------------------------------------------------------------------===//
369+
370+
iree_codegen_module.def(
371+
"query_mma_intrinsics", &ireeCodegenQueryMMAIntrinsicsBinding,
372+
"Queries the MMA intrinsics from an executable variant op.",
373+
py::arg("op"));
329374
}

compiler/src/iree/compiler/API/Internal/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ iree_compiler_cc_library(
137137
deps = [
138138
"//compiler/bindings/c:headers",
139139
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
140+
"//compiler/src/iree/compiler/Codegen/Utils",
140141
"@llvm-project//mlir:CAPIIR",
141142
"@llvm-project//mlir:CAPIIRHeaders",
142143
"@llvm-project//mlir:IR",

compiler/src/iree/compiler/API/Internal/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ iree_cc_library(
116116
MLIRCAPIIR
117117
MLIRIR
118118
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
119+
iree::compiler::Codegen::Utils
119120
iree::compiler::bindings::c::headers
120121
PUBLIC
121122
)

compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <type_traits>
1111
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
1212
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
13+
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
1314
#include "iree/compiler/dialects/iree_codegen.h"
1415
#include "mlir-c/BuiltinAttributes.h"
1516
#include "mlir-c/IR.h"
@@ -24,6 +25,8 @@ using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipeline;
2425
using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipelineAttr;
2526
using mlir::iree_compiler::IREE::Codegen::LoweringConfigAttrInterface;
2627
using mlir::iree_compiler::IREE::Codegen::TranslationInfoAttr;
28+
using mlir::iree_compiler::IREE::GPU::MMAIntrinsic;
29+
using mlir::iree_compiler::IREE::HAL::ExecutableVariantOp;
2730

2831
bool ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr(
2932
MlirAttribute attr) {
@@ -149,3 +152,49 @@ ireeCodegenCompilationInfoAttrGetParameters(MlirAttribute attr) {
149152
parameters.translationInfo = wrap(compilationInfo.getTranslationInfo());
150153
return parameters;
151154
}
155+
156+
void ireeCodegenGetExecutableVariantOps(MlirModule module, size_t *numOps,
157+
MlirOperation *executableOps) {
158+
assert(!mlirModuleIsNull(module) && "module cannot be nullptr");
159+
assert(numOps && "numOps cannot be nullptr");
160+
161+
mlir::ModuleOp moduleOp = unwrap(module);
162+
llvm::SmallVector<ExecutableVariantOp> executableVariantOps =
163+
mlir::iree_compiler::getExecutableVariantOps(moduleOp);
164+
165+
if (!executableOps) {
166+
*numOps = executableVariantOps.size();
167+
return;
168+
}
169+
170+
assert(
171+
*numOps == executableVariantOps.size() &&
172+
"*numOps must match the number of elements in the executableVariantOps");
173+
174+
for (size_t i = 0, e = executableVariantOps.size(); i < e; ++i) {
175+
executableOps[i] = wrap(executableVariantOps[i]);
176+
}
177+
}
178+
179+
void ireeCodegenQueryMMAIntrinsics(MlirOperation op, size_t *numIntrinsics,
180+
uint32_t *mmaIntrinsics) {
181+
assert(numIntrinsics && "numIntrinsics cannot be nullptr");
182+
183+
mlir::Operation *mlirOp = unwrap(op);
184+
auto variantOp = llvm::dyn_cast_if_present<ExecutableVariantOp>(mlirOp);
185+
assert(variantOp && "operation is not a ExecutableVariantOp");
186+
187+
llvm::SmallVector<MMAIntrinsic> intrinsics =
188+
mlir::iree_compiler::queryMMAIntrinsics(variantOp);
189+
if (!mmaIntrinsics) {
190+
*numIntrinsics = intrinsics.size();
191+
return;
192+
}
193+
194+
assert(*numIntrinsics == intrinsics.size() &&
195+
"*numIntrinsics must match the number of elements in the intrinsics");
196+
197+
for (size_t i = 0, e = intrinsics.size(); i < e; ++i) {
198+
mmaIntrinsics[i] = static_cast<uint32_t>(intrinsics[i]);
199+
}
200+
}

compiler/src/iree/compiler/API/api_exports.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ extern void ireeCodegenCompilationInfoAttrGetTypeID();
2424
extern void ireeCodegenDispatchLoweringPassPipelineAttrGet();
2525
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID();
2626
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetValue();
27+
extern void ireeCodegenGetExecutableVariantOps();
28+
extern void ireeCodegenQueryMMAIntrinsics();
2729
extern void ireeCodegenTranslationInfoAttrGet();
2830
extern void ireeCodegenTranslationInfoAttrGetParameters();
2931
extern void ireeCodegenTranslationInfoAttrGetTypeID();

compiler/src/iree/compiler/API/api_exports.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ EXPORTS
1414
ireeCodegenDispatchLoweringPassPipelineAttrGet
1515
ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
1616
ireeCodegenDispatchLoweringPassPipelineAttrGetValue
17+
ireeCodegenGetExecutableVariantOps
18+
ireeCodegenQueryMMAIntrinsics
1719
ireeCodegenTranslationInfoAttrGet
1820
ireeCodegenTranslationInfoAttrGetParameters
1921
ireeCodegenTranslationInfoAttrGetTypeID

compiler/src/iree/compiler/API/api_exports.ld

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ VER_0 {
1515
ireeCodegenDispatchLoweringPassPipelineAttrGet;
1616
ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID;
1717
ireeCodegenDispatchLoweringPassPipelineAttrGetValue;
18+
ireeCodegenGetExecutableVariantOps;
19+
ireeCodegenQueryMMAIntrinsics;
1820
ireeCodegenTranslationInfoAttrGet;
1921
ireeCodegenTranslationInfoAttrGetParameters;
2022
ireeCodegenTranslationInfoAttrGetTypeID;

compiler/src/iree/compiler/API/api_exports.macos.lst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ _ireeCodegenCompilationInfoAttrGetTypeID
1313
_ireeCodegenDispatchLoweringPassPipelineAttrGet
1414
_ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
1515
_ireeCodegenDispatchLoweringPassPipelineAttrGetValue
16+
_ireeCodegenGetExecutableVariantOps
17+
_ireeCodegenQueryMMAIntrinsics
1618
_ireeCodegenTranslationInfoAttrGet
1719
_ireeCodegenTranslationInfoAttrGetParameters
1820
_ireeCodegenTranslationInfoAttrGetTypeID

compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,7 @@ std::optional<int> getGPUSubgroupSize(mlir::FunctionOpInterface func) {
10301030

10311031
SmallVector<IREE::HAL::ExecutableVariantOp>
10321032
getExecutableVariantOps(mlir::ModuleOp moduleOp) {
1033-
llvm::SmallVector<IREE::HAL::ExecutableVariantOp> executableVariantOps;
1033+
SmallVector<IREE::HAL::ExecutableVariantOp> executableVariantOps;
10341034
moduleOp.walk([&](IREE::HAL::ExecutableVariantOp executableOp) {
10351035
executableVariantOps.push_back(executableOp);
10361036
});
@@ -1039,7 +1039,7 @@ getExecutableVariantOps(mlir::ModuleOp moduleOp) {
10391039

10401040
SmallVector<IREE::GPU::MMAIntrinsic>
10411041
queryMMAIntrinsics(IREE::HAL::ExecutableVariantOp executableOp) {
1042-
llvm::SmallVector<IREE::GPU::MMAIntrinsic> mmaIntrinsics;
1042+
SmallVector<IREE::GPU::MMAIntrinsic> mmaIntrinsics;
10431043
if (IREE::GPU::TargetAttr target = getGPUTargetAttr(executableOp)) {
10441044
mmaIntrinsics = llvm::map_to_vector(
10451045
target.getWgp().getMma(),

0 commit comments

Comments
 (0)