Skip to content

Commit c4095e5

Browse files
bangtianliuhhkit
authored and committed
[Codegen][Tuner]: improve python binding to query target info (iree-org#21812)
On top of iree-org#21782, this PR enhances the IREE GPU Python bindings by adding constructor support to the `TargetInfo` class. It also adds support for querying the MMA intrinsics supported by the target architecture. Issue: nod-ai/shark-ai#2048 --------- Signed-off-by: Bangtian Liu <[email protected]> Signed-off-by: Ivan Ho <[email protected]>
1 parent a1ffcc1 commit c4095e5

File tree

8 files changed

+406
-42
lines changed

8 files changed

+406
-42
lines changed

compiler/bindings/c/iree/compiler/dialects/iree_gpu.h

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
extern "C" {
1515
#endif
1616

17+
// This typedef ensures consistency between the C API, C++ implementation, and
18+
// Python bindings. Update both this typedef and the static assertions if the
19+
// enum underlying types change.
20+
typedef uint32_t mma_intrinsic_enum_t;
21+
1722
// The following C API is **NOT STABLE** and likely to change in the future.
1823
// It mirrors the IREE GPU Dialect which is not stable itself.
1924

@@ -56,35 +61,37 @@ MLIR_CAPI_EXPORTED bool ireeAttributeIsAGPUMMAIntrinsicAttr(MlirAttribute attr);
5661

5762
MLIR_CAPI_EXPORTED MlirTypeID ireeGPUMMAIntrinsicAttrGetTypeID(void);
5863

59-
MLIR_CAPI_EXPORTED MlirAttribute ireeGPUMMAIntrinsicAttrGet(MlirContext mlirCtx,
60-
uint32_t value);
64+
MLIR_CAPI_EXPORTED MlirAttribute
65+
ireeGPUMMAIntrinsicAttrGet(MlirContext mlirCtx, mma_intrinsic_enum_t value);
6166

62-
MLIR_CAPI_EXPORTED uint32_t ireeGPUMMAIntrinsicAttrGetValue(MlirAttribute attr);
67+
MLIR_CAPI_EXPORTED mma_intrinsic_enum_t
68+
ireeGPUMMAIntrinsicAttrGetValue(MlirAttribute attr);
6369

6470
MLIR_CAPI_EXPORTED bool ireeAttributeIsAGPUMMAAttr(MlirAttribute attr);
6571

6672
MLIR_CAPI_EXPORTED MlirTypeID ireeGPUMMAAttrGetTypeID(void);
6773

6874
MLIR_CAPI_EXPORTED MlirAttribute ireeGPUMMAAttrGet(MlirContext mlirCtx,
69-
uint32_t value);
75+
mma_intrinsic_enum_t value);
7076

7177
MLIR_CAPI_EXPORTED bool
7278
ireeAttributeIsAGPUVirtualMMAIntrinsicAttr(MlirAttribute attr);
7379

7480
MLIR_CAPI_EXPORTED MlirTypeID ireeGPUVirtualMMAIntrinsicAttrGetTypeID(void);
7581

76-
MLIR_CAPI_EXPORTED MlirAttribute
77-
ireeGPUVirtualMMAIntrinsicAttrGet(MlirContext mlirCtx, uint32_t value);
82+
MLIR_CAPI_EXPORTED MlirAttribute ireeGPUVirtualMMAIntrinsicAttrGet(
83+
MlirContext mlirCtx, mma_intrinsic_enum_t value);
7884

79-
MLIR_CAPI_EXPORTED uint32_t
85+
MLIR_CAPI_EXPORTED mma_intrinsic_enum_t
8086
ireeGPUVirtualMMAIntrinsicAttrGetValue(MlirAttribute attr);
8187

8288
MLIR_CAPI_EXPORTED bool ireeAttributeIsAGPUVirtualMMAAttr(MlirAttribute attr);
8389

8490
MLIR_CAPI_EXPORTED MlirTypeID ireeGPUVirtualMMAAttrGetTypeID(void);
8591

86-
MLIR_CAPI_EXPORTED MlirAttribute ireeGPUVirtualMMAAttrGet(MlirContext mlirCtx,
87-
uint32_t value);
92+
MLIR_CAPI_EXPORTED MlirAttribute
93+
ireeGPUVirtualMMAAttrGet(MlirContext mlirCtx, mma_intrinsic_enum_t value);
94+
8895
struct ireeGPUMMAInfo {
8996
MlirType aElementType;
9097
MlirType bElementType;
@@ -148,14 +155,32 @@ struct ireeGPUTargetInfo {
148155
MlirIdentifier arch; // E.g., "gfx942".
149156
MlirAttribute subgroupSizeChoices; // Subgroup size choices.
150157
MlirAttribute maxWorkgroupSizes; // Max threads per X/Y/Z dimension.
151-
int64_t maxThreadCountPerWorkgroup; // Max threads per workgroup.
152-
int64_t maxWorkgroupMemoryBytes; // Max workgroup memory.
158+
int32_t maxThreadCountPerWorkgroup; // Max threads per workgroup.
159+
int32_t maxWorkgroupMemoryBytes; // Max workgroup memory.
160+
MlirAttribute mmaIntrinsics; // MMA Intrinsics.
153161
};
154162

155163
// Queries GPU target info from the given `ExecutableTargetAttr` attribute.
156164
MLIR_CAPI_EXPORTED ireeGPUTargetInfo
157165
ireeHALExecutableTargetAttrGetGPUTargetInfo(MlirAttribute attr);
158166

167+
MLIR_CAPI_EXPORTED ireeGPUTargetInfo ireeGPUTargetInfoGet(
168+
MlirContext mlirCtx, const char *arch, const int32_t *subgroupChoices,
169+
size_t numSubgroupChoices, const int32_t *workgroupSizes,
170+
size_t numWorkgroupSizes, int32_t threadCount, int32_t memoryBytes,
171+
const mma_intrinsic_enum_t *mmaIntrinsics, size_t numMmaIntrinsics);
172+
173+
// Extracts MMA intrinsic values and their virtual status from an ArrayAttr.
174+
//
175+
// mmaIntrinsics: Array attribute containing MMA intrinsic attributes.
176+
// mmaIntrinsicVals: Output array for MMA intrinsic enum values.
177+
// virtualMmaIntrinsicTags: Output array - 1 if VirtualMMAIntrinsic, 0 if
178+
// MMAIntrinsic.
179+
MLIR_CAPI_EXPORTED void
180+
ireeGPUTargetInfoGetMMAIntrinsics(MlirAttribute mmaIntrinsics,
181+
mma_intrinsic_enum_t *mmaIntrinsicVals,
182+
uint8_t *virtualMmaIntrinsicTags);
183+
159184
#ifdef __cplusplus
160185
}
161186
#endif

compiler/bindings/python/IREECompilerDialectsModule.cpp

Lines changed: 88 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "iree/compiler/dialects/iree_codegen.h"
1111
#include "iree/compiler/dialects/iree_gpu.h"
1212
#include "mlir-c/BuiltinAttributes.h"
13+
#include "mlir-c/BuiltinTypes.h"
1314
#include "mlir-c/IR.h"
1415
#include "mlir/Bindings/Python/Nanobind.h"
1516
#include "mlir/Bindings/Python/NanobindAdaptors.h"
@@ -20,6 +21,9 @@ static const char *kCodegenModuleImportPath =
2021
static const char *kGpuModuleImportPath =
2122
MAKE_MLIR_PYTHON_QUALNAME("dialects.iree_gpu");
2223

24+
static const char *kMMAIntrinsicEnumName = "MMAIntrinsic";
25+
static const char *kVirtualMMAIntrinsicEnumName = "VirtualMMAIntrinsic";
26+
2327
namespace py = nanobind;
2428
using namespace nanobind::literals;
2529
using namespace mlir::python::nanobind_adaptors;
@@ -42,7 +46,7 @@ ireeCodegenQueryMMAIntrinsicsBinding(MlirOperation op) {
4246
ireeCodegenQueryMMAIntrinsics(op, &numMMAs, mmaIntrinsics.data());
4347

4448
py::object mmaIntrinsicEnum =
45-
py::module_::import_(kGpuModuleImportPath).attr("MMAIntrinsic");
49+
py::module_::import_(kGpuModuleImportPath).attr(kMMAIntrinsicEnumName);
4650
std::vector<py::object> mmaList(numMMAs);
4751
for (size_t i = 0; i < numMMAs; ++i) {
4852
mmaList[i] = mmaIntrinsicEnum(mmaIntrinsics[i]);
@@ -325,7 +329,7 @@ NB_MODULE(_ireeCompilerDialects, m) {
325329
uint32_t rawValue =
326330
ireeGPUMMAIntrinsicAttrGetValue(self);
327331
return py::module_::import_(kGpuModuleImportPath)
328-
.attr("MMAIntrinsic")(rawValue);
332+
.attr(kMMAIntrinsicEnumName)(rawValue);
329333
})
330334
.def_property_readonly("mma", [](MlirAttribute self) -> MlirAttribute {
331335
uint32_t value = ireeGPUMMAIntrinsicAttrGetValue(self);
@@ -377,7 +381,7 @@ NB_MODULE(_ireeCompilerDialects, m) {
377381

378382
static py::object virtualEnum =
379383
py::module_::import_(kGpuModuleImportPath)
380-
.attr("VirtualMMAIntrinsic");
384+
.attr(kVirtualMMAIntrinsicEnumName);
381385

382386
std::vector<py::object> result;
383387
for (int64_t val : getIntArrayAttrValues(rawArrayAttr)) {
@@ -404,13 +408,13 @@ NB_MODULE(_ireeCompilerDialects, m) {
404408
"Gets an #iree_gpu.virtual_mma_intrinsic from parameters.")
405409
.def_property_readonly("raw_value",
406410
ireeGPUVirtualMMAIntrinsicAttrGetValue)
407-
.def_property_readonly("value",
408-
[](MlirAttribute self) -> py::object {
409-
uint32_t rawValue =
410-
ireeGPUVirtualMMAIntrinsicAttrGetValue(self);
411-
return py::module_::import_(kGpuModuleImportPath)
412-
.attr("VirtualMMAIntrinsic")(rawValue);
413-
})
411+
.def_property_readonly(
412+
"value",
413+
[](MlirAttribute self) -> py::object {
414+
uint32_t rawValue = ireeGPUVirtualMMAIntrinsicAttrGetValue(self);
415+
return py::module_::import_(kGpuModuleImportPath)
416+
.attr(kVirtualMMAIntrinsicEnumName)(rawValue);
417+
})
414418
.def_property_readonly("mma", [](MlirAttribute self) -> MlirAttribute {
415419
uint32_t value = ireeGPUVirtualMMAIntrinsicAttrGetValue(self);
416420
return ireeGPUVirtualMMAAttrGet(mlirAttributeGetContext(self), value);
@@ -509,6 +513,43 @@ NB_MODULE(_ireeCompilerDialects, m) {
509513
//===-------------------------------------------------------------------===//
510514

511515
py::class_<ireeGPUTargetInfo>(iree_gpu_module, "TargetInfo")
516+
.def(
517+
"__init__",
518+
[](ireeGPUTargetInfo *self, MlirContext context,
519+
const std::string &arch,
520+
const std::vector<int32_t> &subgroupChoices,
521+
const std::vector<int32_t> &workgroupSizes, int32_t threadCount,
522+
int32_t memoryBytes, const py::list &mmaIntrinsicObjs) {
523+
std::vector<mma_intrinsic_enum_t> mmaIntrinsicVals;
524+
py::module_ gpuModule = py::module_::import_(kGpuModuleImportPath);
525+
py::object mmaIntrinsicClass =
526+
gpuModule.attr(kMMAIntrinsicEnumName);
527+
py::object virtualMmaIntrinsicClass =
528+
gpuModule.attr(kVirtualMMAIntrinsicEnumName);
529+
530+
for (py::handle item : mmaIntrinsicObjs) {
531+
if (!py::isinstance(item, mmaIntrinsicClass) &&
532+
!py::isinstance(item, virtualMmaIntrinsicClass)) {
533+
throw py::type_error("All items must be MMA atributes");
534+
}
535+
mmaIntrinsicVals.push_back(
536+
py::cast<mma_intrinsic_enum_t>(item.attr("value")));
537+
}
538+
539+
*self = ireeGPUTargetInfoGet(
540+
context, arch.c_str(), subgroupChoices.data(),
541+
subgroupChoices.size(), workgroupSizes.data(),
542+
workgroupSizes.size(), threadCount, memoryBytes,
543+
mmaIntrinsicVals.data(), mmaIntrinsicVals.size());
544+
},
545+
"context"_a, "arch"_a, "subgroup_size_choices"_a,
546+
"max_workgroup_sizes"_a, "max_thread_count_per_workgroup"_a,
547+
"max_workgroup_memory_bytes"_a, "mma_intrinsics"_a = py::list{},
548+
"Create a GPUTargetInfo with the given parameters")
549+
.def_static(
550+
"get_gpu_target_info", &ireeHALExecutableTargetAttrGetGPUTargetInfo,
551+
"executable_target_attr"_a,
552+
"Get GPU target information from an executable target attribute")
512553
.def_prop_ro("arch",
513554
[](const ireeGPUTargetInfo &self) -> std::string {
514555
MlirStringRef strRef = mlirIdentifierStr(self.arch);
@@ -529,12 +570,44 @@ NB_MODULE(_ireeCompilerDialects, m) {
529570
.def_prop_ro("max_workgroup_memory_bytes",
530571
[](const ireeGPUTargetInfo &self) -> int64_t {
531572
return self.maxWorkgroupMemoryBytes;
532-
});
573+
})
574+
.def_prop_ro(
575+
"mma_intrinsics", [](const ireeGPUTargetInfo &self) -> py::list {
576+
if (mlirAttributeIsNull(self.mmaIntrinsics) ||
577+
!mlirAttributeIsAArray(self.mmaIntrinsics)) {
578+
return py::list();
579+
}
533580

534-
iree_gpu_module.def(
535-
"get_gpu_target_info", &ireeHALExecutableTargetAttrGetGPUTargetInfo,
536-
"Extracts GPU target information from an executable target attribute.",
537-
py::arg("executable_target_attr"));
581+
size_t numElements =
582+
mlirArrayAttrGetNumElements(self.mmaIntrinsics);
583+
584+
std::vector<mma_intrinsic_enum_t> mmaIntrinsicVals(numElements);
585+
// Use uint8_t instead of bool because std::vector<bool> is a
586+
// specialized template that doesn't provide .data() method.
587+
std::vector<uint8_t> virtualMmaIntrinsicTags(numElements);
588+
ireeGPUTargetInfoGetMMAIntrinsics(self.mmaIntrinsics,
589+
mmaIntrinsicVals.data(),
590+
virtualMmaIntrinsicTags.data());
591+
592+
py::list intrinsics;
593+
py::module_ gpuModule = py::module_::import_(kGpuModuleImportPath);
594+
py::object mmaIntrinsicEnum = gpuModule.attr(kMMAIntrinsicEnumName);
595+
py::object virtualMmaIntrinsicEnum =
596+
gpuModule.attr(kVirtualMMAIntrinsicEnumName);
597+
598+
for (size_t i = 0; i < numElements; ++i) {
599+
if (virtualMmaIntrinsicTags[i]) {
600+
py::object virtualMmaIntrinsic =
601+
virtualMmaIntrinsicEnum(mmaIntrinsicVals[i]);
602+
intrinsics.append(virtualMmaIntrinsic);
603+
continue;
604+
}
605+
py::object mmaIntrinsic = mmaIntrinsicEnum(mmaIntrinsicVals[i]);
606+
intrinsics.append(mmaIntrinsic);
607+
}
608+
609+
return intrinsics;
610+
});
538611

539612
//===-------------------------------------------------------------------===//
540613
// Binding to utility function getSingleSubgroupLayout

0 commit comments

Comments
 (0)