Skip to content

Commit 7460fcd

Browse files
authored
[Codegen][Tuner] expose python binding to query target info (iree-org#21782)
Motivated by issue iree-org#2048, this PR exposes Python bindings for querying the relevant GPU target info, which the tuner will use for constraint generation. --------- Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
1 parent 639c7cf commit 7460fcd

File tree

10 files changed

+152
-6
lines changed

10 files changed

+152
-6
lines changed

compiler/bindings/c/iree/compiler/dialects/iree_gpu.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,18 @@ struct ireeGPUMMASingleSubgroupLayout {
144144
MLIR_CAPI_EXPORTED ireeGPUMMASingleSubgroupLayout
145145
ireeGPUGetSingleSubgroupLayout(MlirAttribute attr, uint32_t fragment);
146146

147+
struct ireeGPUTargetInfo {
148+
MlirIdentifier arch; // Target architecture name, e.g., "gfx942"; null if no GPU target was found.
149+
MlirAttribute subgroupSizeChoices; // I32 array attribute of subgroup size choices.
150+
MlirAttribute maxWorkgroupSizes; // I32 array attribute: max threads per X/Y/Z dimension.
151+
int64_t maxThreadCountPerWorkgroup; // Max threads per workgroup.
152+
int64_t maxWorkgroupMemoryBytes; // Max workgroup memory, in bytes.
153+
};
154+
155+
// Queries GPU target info from the given non-null `ExecutableTargetAttr` attribute; the result is zero-initialized (null `arch`/attributes, zero limits) when no GPU target is found.
156+
MLIR_CAPI_EXPORTED ireeGPUTargetInfo
157+
ireeHALExecutableTargetAttrGetGPUTargetInfo(MlirAttribute attr);
158+
147159
#ifdef __cplusplus
148160
}
149161
#endif

compiler/bindings/python/IREECompilerDialectsModule.cpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,38 @@ NB_MODULE(_ireeCompilerDialects, m) {
504504
return std::nullopt;
505505
});
506506

507+
//===-------------------------------------------------------------------===//
508+
// Binding to query target info
509+
//===-------------------------------------------------------------------===//
510+
511+
py::class_<ireeGPUTargetInfo>(iree_gpu_module, "TargetInfo")
512+
.def_prop_ro("arch",
513+
[](const ireeGPUTargetInfo &self) -> std::string {
514+
MlirStringRef strRef = mlirIdentifierStr(self.arch);
515+
return std::string(strRef.data, strRef.length);
516+
})
517+
.def_prop_ro("subgroup_size_choices",
518+
[](const ireeGPUTargetInfo &self) -> std::vector<int64_t> {
519+
return getIntArrayAttrValues(self.subgroupSizeChoices);
520+
})
521+
.def_prop_ro("max_thread_count_per_workgroup",
522+
[](const ireeGPUTargetInfo &self) -> int64_t {
523+
return self.maxThreadCountPerWorkgroup;
524+
})
525+
.def_prop_ro("max_workgroup_sizes",
526+
[](const ireeGPUTargetInfo &self) -> std::vector<int64_t> {
527+
return getIntArrayAttrValues(self.maxWorkgroupSizes);
528+
})
529+
.def_prop_ro("max_workgroup_memory_bytes",
530+
[](const ireeGPUTargetInfo &self) -> int64_t {
531+
return self.maxWorkgroupMemoryBytes;
532+
});
533+
534+
iree_gpu_module.def(
535+
"get_gpu_target_info", &ireeHALExecutableTargetAttrGetGPUTargetInfo,
536+
"Extracts GPU target information from an executable target attribute.",
537+
py::arg("executable_target_attr"));
538+
507539
//===-------------------------------------------------------------------===//
508540
// Binding to utility function getSingleSubgroupLayout
509541
//===-------------------------------------------------------------------===//
@@ -592,12 +624,7 @@ NB_MODULE(_ireeCompilerDialects, m) {
592624
});
593625

594626
iree_codegen_module.def(
595-
"get_attention_op_detail",
596-
[](MlirAffineMap q, MlirAffineMap k, MlirAffineMap v, MlirAffineMap o) {
597-
ireeCodegenAttentionOpDetail result =
598-
ireeCodegenGetAttentionOpDetail(q, k, v, o);
599-
return result;
600-
},
627+
"get_attention_op_detail", &ireeCodegenGetAttentionOpDetail,
601628
"Infers the structure of an attention operation from affine indexing "
602629
"maps.",
603630
py::arg("q"), py::arg("k"), py::arg("v"), py::arg("o"));

compiler/bindings/python/test/ir/dialects_test.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,68 @@ def compilation_info():
391391
assert compilation_info is not None
392392
assert compilation_info.lowering_config == lowering_config
393393
assert compilation_info.translation_info == translation_info
394+
395+
396+
@run
def gpu_target_info_attribute_parsing():
    """Checks that `iree_gpu.get_gpu_target_info` reflects the parsed target.

    Parses a module holding a single `hal.executable.variant` whose target
    config carries an `#iree_gpu.target` attribute, queries the target info
    through the Python binding, and compares each exposed property with the
    value written in the MLIR text.
    """
    mlir_string = """
hal.executable private @main_dispatch_0 {
hal.executable.variant public @rocm_hsaco_fb
target(<"rocm", "rocm-hsaco-fb",
{
abi = "hip",
iree_codegen.target_info = #iree_gpu.target<
arch = "gfx942",
features = "",
wgp = <
compute = fp64,
storage = b64,
subgroup = none,
dot = none,
mma = [<MFMA_F32_16x16x4_F32>],
subgroup_size_choices = [32, 64],
max_workgroup_sizes = [256, 512, 1024],
max_thread_count_per_workgroup = 1024,
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [256, 512, 1024]
>
>
}>
) {
}
}
"""

    module = ir.Module.parse(mlir_string)
    variant_op_list = iree_codegen.get_executable_variant_ops(module)
    assert len(variant_op_list) == 1, "Expect one executable variant op"
    variant_op = variant_op_list[0]
    executable_variant_op = variant_op.opview
    target = executable_variant_op.target
    gpu_target_info = iree_gpu.get_gpu_target_info(target)

    arch = gpu_target_info.arch
    assert arch == "gfx942", f"Expected arch 'gfx942', got '{arch}'"

    subgroup_size_choices = gpu_target_info.subgroup_size_choices
    assert subgroup_size_choices == [
        32,
        64,
    ], f"Expected subgroup_size_choices [32, 64], got {subgroup_size_choices}"

    max_thread_count = gpu_target_info.max_thread_count_per_workgroup
    assert (
        max_thread_count == 1024
    ), f"Expected max_thread_count_per_workgroup 1024, got {max_thread_count}"

    max_memory_bytes = gpu_target_info.max_workgroup_memory_bytes
    assert (
        max_memory_bytes == 65536
    ), f"Expected max_workgroup_memory_bytes 65536, got {max_memory_bytes}"

    max_workgroup_sizes = gpu_target_info.max_workgroup_sizes
    assert max_workgroup_sizes == [
        256,
        512,
        1024,
    ], f"Expected max_workgroup_sizes [256, 512, 1024], got {max_workgroup_sizes}"

compiler/src/iree/compiler/API/Internal/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ iree_compiler_cc_library(
156156
deps = [
157157
"//compiler/bindings/c:headers",
158158
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
159+
"//compiler/src/iree/compiler/Codegen/Utils",
159160
"@llvm-project//mlir:CAPIIR",
160161
"@llvm-project//mlir:CAPIIRHeaders",
161162
"@llvm-project//mlir:IR",

compiler/src/iree/compiler/API/Internal/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ iree_cc_library(
135135
MLIRCAPIIR
136136
MLIRIR
137137
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
138+
iree::compiler::Codegen::Utils
138139
iree::compiler::bindings::c::headers
139140
PUBLIC
140141
)

compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
1111
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
1212
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
13+
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
1314
#include "iree/compiler/dialects/iree_gpu.h"
1415
#include "mlir-c/BuiltinAttributes.h"
1516
#include "mlir-c/IR.h"
@@ -387,3 +388,37 @@ ireeGPUGetSingleSubgroupLayout(MlirAttribute attr, uint32_t fragment) {
387388
result.element = wrap(builder.getI64ArrayAttr(layout.element));
388389
return result;
389390
}
391+
392+
// Extracts GPU target information from a HAL executable target attribute.
//
// `attr` must be non-null (asserted) and must be an `ExecutableTargetAttr`
// (enforced by `llvm::cast`). If no GPU target attribute can be resolved from
// it, the returned struct is left zero-initialized: `arch` and both array
// attributes are null and the numeric limits are 0. Otherwise the subgroup
// size choices and max workgroup sizes are returned as I32 array attributes
// created in the attribute's context.
ireeGPUTargetInfo
ireeHALExecutableTargetAttrGetGPUTargetInfo(MlirAttribute attr) {
  assert(!mlirAttributeIsNull(attr) && "attr cannot be null");
  auto executableTargetAttr =
      llvm::cast<mlir::iree_compiler::IREE::HAL::ExecutableTargetAttr>(
          unwrap(attr));

  ireeGPUTargetInfo targetInfo = {};
  mlir::MLIRContext *context = executableTargetAttr.getContext();
  mlir::iree_compiler::IREE::GPU::TargetAttr gpuTargetAttr =
      mlir::iree_compiler::getGPUTargetAttr(context, executableTargetAttr);

  // No GPU target available: hand back the zero-initialized struct.
  if (!gpuTargetAttr) {
    return targetInfo;
  }

  targetInfo.arch =
      wrap(mlir::StringAttr::get(context, gpuTargetAttr.getArch()));
  mlir::iree_compiler::IREE::GPU::TargetWgpAttr wgpAttr =
      gpuTargetAttr.getWgp();

  // Only attribute construction is needed, so a plain `Builder` suffices;
  // there is no reason to create an `OpBuilder` and slice it.
  mlir::Builder builder(context);

  targetInfo.subgroupSizeChoices =
      wrap(builder.getI32ArrayAttr(wgpAttr.getSubgroupSizeChoices()));
  targetInfo.maxWorkgroupSizes =
      wrap(builder.getI32ArrayAttr(wgpAttr.getMaxWorkgroupSizes()));

  targetInfo.maxThreadCountPerWorkgroup =
      wgpAttr.getMaxThreadCountPerWorkgroup();
  targetInfo.maxWorkgroupMemoryBytes = wgpAttr.getMaxWorkgroupMemoryBytes();

  return targetInfo;
}

compiler/src/iree/compiler/API/api_exports.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ extern void ireeGPUPipelineOptionsAttrGetUseIgemmConvolution();
113113
extern void ireeGPUReorderWorkgroupsStrategyAttrGet();
114114
extern void ireeGPUReorderWorkgroupsStrategyAttrGetTypeID();
115115
extern void ireeGPUReorderWorkgroupsStrategyAttrGetValue();
116+
extern void ireeHALExecutableTargetAttrGetGPUTargetInfo();
116117
extern void ireeMlirLspServerRunMain();
117118
extern void ireeOptRunMain();
118119
extern void ireeReduceRunMain();
@@ -1027,6 +1028,7 @@ uintptr_t __iree_compiler_hidden_force_extern() {
10271028
x += (uintptr_t)&ireeGPUReorderWorkgroupsStrategyAttrGet;
10281029
x += (uintptr_t)&ireeGPUReorderWorkgroupsStrategyAttrGetTypeID;
10291030
x += (uintptr_t)&ireeGPUReorderWorkgroupsStrategyAttrGetValue;
1031+
x += (uintptr_t)&ireeHALExecutableTargetAttrGetGPUTargetInfo;
10301032
x += (uintptr_t)&ireeMlirLspServerRunMain;
10311033
x += (uintptr_t)&ireeOptRunMain;
10321034
x += (uintptr_t)&ireeReduceRunMain;

compiler/src/iree/compiler/API/api_exports.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ EXPORTS
103103
ireeGPUReorderWorkgroupsStrategyAttrGet
104104
ireeGPUReorderWorkgroupsStrategyAttrGetTypeID
105105
ireeGPUReorderWorkgroupsStrategyAttrGetValue
106+
ireeHALExecutableTargetAttrGetGPUTargetInfo
106107
ireeMlirLspServerRunMain
107108
ireeOptRunMain
108109
ireeReduceRunMain

compiler/src/iree/compiler/API/api_exports.ld

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ VER_0 {
104104
ireeGPUReorderWorkgroupsStrategyAttrGet;
105105
ireeGPUReorderWorkgroupsStrategyAttrGetTypeID;
106106
ireeGPUReorderWorkgroupsStrategyAttrGetValue;
107+
ireeHALExecutableTargetAttrGetGPUTargetInfo;
107108
ireeMlirLspServerRunMain;
108109
ireeOptRunMain;
109110
ireeReduceRunMain;

compiler/src/iree/compiler/API/api_exports.macos.lst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ _ireeGPUPipelineOptionsAttrGetUseIgemmConvolution
102102
_ireeGPUReorderWorkgroupsStrategyAttrGet
103103
_ireeGPUReorderWorkgroupsStrategyAttrGetTypeID
104104
_ireeGPUReorderWorkgroupsStrategyAttrGetValue
105+
_ireeHALExecutableTargetAttrGetGPUTargetInfo
105106
_ireeMlirLspServerRunMain
106107
_ireeOptRunMain
107108
_ireeReduceRunMain

0 commit comments

Comments
 (0)