1010#include " iree/compiler/dialects/iree_codegen.h"
1111#include " iree/compiler/dialects/iree_gpu.h"
1212#include " mlir-c/BuiltinAttributes.h"
13+ #include " mlir-c/BuiltinTypes.h"
1314#include " mlir-c/IR.h"
1415#include " mlir/Bindings/Python/Nanobind.h"
1516#include " mlir/Bindings/Python/NanobindAdaptors.h"
@@ -20,6 +21,9 @@ static const char *kCodegenModuleImportPath =
2021static const char *kGpuModuleImportPath =
2122 MAKE_MLIR_PYTHON_QUALNAME (" dialects.iree_gpu" );
2223
24+ static const char *kMMAIntrinsicEnumName = " MMAIntrinsic" ;
25+ static const char *kVirtualMMAIntrinsicEnumName = " VirtualMMAIntrinsic" ;
26+
2327namespace py = nanobind;
2428using namespace nanobind ::literals;
2529using namespace mlir ::python::nanobind_adaptors;
@@ -42,7 +46,7 @@ ireeCodegenQueryMMAIntrinsicsBinding(MlirOperation op) {
4246 ireeCodegenQueryMMAIntrinsics (op, &numMMAs, mmaIntrinsics.data ());
4347
4448 py::object mmaIntrinsicEnum =
45- py::module_::import_ (kGpuModuleImportPath ).attr (" MMAIntrinsic " );
49+ py::module_::import_ (kGpuModuleImportPath ).attr (kMMAIntrinsicEnumName );
4650 std::vector<py::object> mmaList (numMMAs);
4751 for (size_t i = 0 ; i < numMMAs; ++i) {
4852 mmaList[i] = mmaIntrinsicEnum (mmaIntrinsics[i]);
@@ -325,7 +329,7 @@ NB_MODULE(_ireeCompilerDialects, m) {
325329 uint32_t rawValue =
326330 ireeGPUMMAIntrinsicAttrGetValue (self);
327331 return py::module_::import_ (kGpuModuleImportPath )
328- .attr (" MMAIntrinsic " )(rawValue);
332+ .attr (kMMAIntrinsicEnumName )(rawValue);
329333 })
330334 .def_property_readonly (" mma" , [](MlirAttribute self) -> MlirAttribute {
331335 uint32_t value = ireeGPUMMAIntrinsicAttrGetValue (self);
@@ -377,7 +381,7 @@ NB_MODULE(_ireeCompilerDialects, m) {
377381
378382 static py::object virtualEnum =
379383 py::module_::import_ (kGpuModuleImportPath )
380- .attr (" VirtualMMAIntrinsic " );
384+ .attr (kVirtualMMAIntrinsicEnumName );
381385
382386 std::vector<py::object> result;
383387 for (int64_t val : getIntArrayAttrValues (rawArrayAttr)) {
@@ -404,13 +408,13 @@ NB_MODULE(_ireeCompilerDialects, m) {
404408 " Gets an #iree_gpu.virtual_mma_intrinsic from parameters." )
405409 .def_property_readonly (" raw_value" ,
406410 ireeGPUVirtualMMAIntrinsicAttrGetValue)
407- .def_property_readonly (" value " ,
408- [](MlirAttribute self) -> py::object {
409- uint32_t rawValue =
410- ireeGPUVirtualMMAIntrinsicAttrGetValue (self);
411- return py::module_::import_ (kGpuModuleImportPath )
412- .attr (" VirtualMMAIntrinsic " )(rawValue);
413- })
411+ .def_property_readonly (
412+ " value " ,
413+ [](MlirAttribute self) -> py::object {
414+ uint32_t rawValue = ireeGPUVirtualMMAIntrinsicAttrGetValue (self);
415+ return py::module_::import_ (kGpuModuleImportPath )
416+ .attr (kVirtualMMAIntrinsicEnumName )(rawValue);
417+ })
414418 .def_property_readonly (" mma" , [](MlirAttribute self) -> MlirAttribute {
415419 uint32_t value = ireeGPUVirtualMMAIntrinsicAttrGetValue (self);
416420 return ireeGPUVirtualMMAAttrGet (mlirAttributeGetContext (self), value);
@@ -509,6 +513,43 @@ NB_MODULE(_ireeCompilerDialects, m) {
509513 // ===-------------------------------------------------------------------===//
510514
511515 py::class_<ireeGPUTargetInfo>(iree_gpu_module, " TargetInfo" )
516+ .def (
517+ " __init__" ,
518+ [](ireeGPUTargetInfo *self, MlirContext context,
519+ const std::string &arch,
520+ const std::vector<int32_t > &subgroupChoices,
521+ const std::vector<int32_t > &workgroupSizes, int32_t threadCount,
522+ int32_t memoryBytes, const py::list &mmaIntrinsicObjs) {
523+ std::vector<mma_intrinsic_enum_t > mmaIntrinsicVals;
524+ py::module_ gpuModule = py::module_::import_ (kGpuModuleImportPath );
525+ py::object mmaIntrinsicClass =
526+ gpuModule.attr (kMMAIntrinsicEnumName );
527+ py::object virtualMmaIntrinsicClass =
528+ gpuModule.attr (kVirtualMMAIntrinsicEnumName );
529+
530+ for (py::handle item : mmaIntrinsicObjs) {
531+ if (!py::isinstance (item, mmaIntrinsicClass) &&
532+ !py::isinstance (item, virtualMmaIntrinsicClass)) {
533+ throw py::type_error (" All items must be MMA atributes" );
534+ }
535+ mmaIntrinsicVals.push_back (
536+ py::cast<mma_intrinsic_enum_t >(item.attr (" value" )));
537+ }
538+
539+ *self = ireeGPUTargetInfoGet (
540+ context, arch.c_str (), subgroupChoices.data (),
541+ subgroupChoices.size (), workgroupSizes.data (),
542+ workgroupSizes.size (), threadCount, memoryBytes,
543+ mmaIntrinsicVals.data (), mmaIntrinsicVals.size ());
544+ },
545+ " context" _a, " arch" _a, " subgroup_size_choices" _a,
546+ " max_workgroup_sizes" _a, " max_thread_count_per_workgroup" _a,
547+ " max_workgroup_memory_bytes" _a, " mma_intrinsics" _a = py::list{},
548+ " Create a GPUTargetInfo with the given parameters" )
549+ .def_static (
550+ " get_gpu_target_info" , &ireeHALExecutableTargetAttrGetGPUTargetInfo,
551+ " executable_target_attr" _a,
552+ " Get GPU target information from an executable target attribute" )
512553 .def_prop_ro (" arch" ,
513554 [](const ireeGPUTargetInfo &self) -> std::string {
514555 MlirStringRef strRef = mlirIdentifierStr (self.arch );
@@ -529,12 +570,44 @@ NB_MODULE(_ireeCompilerDialects, m) {
529570 .def_prop_ro (" max_workgroup_memory_bytes" ,
530571 [](const ireeGPUTargetInfo &self) -> int64_t {
531572 return self.maxWorkgroupMemoryBytes ;
532- });
573+ })
574+ .def_prop_ro (
575+ " mma_intrinsics" , [](const ireeGPUTargetInfo &self) -> py::list {
576+ if (mlirAttributeIsNull (self.mmaIntrinsics ) ||
577+ !mlirAttributeIsAArray (self.mmaIntrinsics )) {
578+ return py::list ();
579+ }
533580
534- iree_gpu_module.def (
535- " get_gpu_target_info" , &ireeHALExecutableTargetAttrGetGPUTargetInfo,
536- " Extracts GPU target information from an executable target attribute." ,
537- py::arg (" executable_target_attr" ));
581+ size_t numElements =
582+ mlirArrayAttrGetNumElements (self.mmaIntrinsics );
583+
584+ std::vector<mma_intrinsic_enum_t > mmaIntrinsicVals (numElements);
585+ // Use uint8_t instead of bool because std::vector<bool> is a
586+ // specialized template that doesn't provide .data() method.
587+ std::vector<uint8_t > virtualMmaIntrinsicTags (numElements);
588+ ireeGPUTargetInfoGetMMAIntrinsics (self.mmaIntrinsics ,
589+ mmaIntrinsicVals.data (),
590+ virtualMmaIntrinsicTags.data ());
591+
592+ py::list intrinsics;
593+ py::module_ gpuModule = py::module_::import_ (kGpuModuleImportPath );
594+ py::object mmaIntrinsicEnum = gpuModule.attr (kMMAIntrinsicEnumName );
595+ py::object virtualMmaIntrinsicEnum =
596+ gpuModule.attr (kVirtualMMAIntrinsicEnumName );
597+
598+ for (size_t i = 0 ; i < numElements; ++i) {
599+ if (virtualMmaIntrinsicTags[i]) {
600+ py::object virtualMmaIntrinsic =
601+ virtualMmaIntrinsicEnum (mmaIntrinsicVals[i]);
602+ intrinsics.append (virtualMmaIntrinsic);
603+ continue ;
604+ }
605+ py::object mmaIntrinsic = mmaIntrinsicEnum (mmaIntrinsicVals[i]);
606+ intrinsics.append (mmaIntrinsic);
607+ }
608+
609+ return intrinsics;
610+ });
538611
539612 // ===-------------------------------------------------------------------===//
540613 // Binding to utility function getSingleSubgroupLayout
0 commit comments