1010#include < type_traits>
1111#include " iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
1212#include " iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
13+ #include " iree/compiler/Codegen/Utils/GPUUtils.h"
1314#include " iree/compiler/dialects/iree_codegen.h"
1415#include " mlir-c/BuiltinAttributes.h"
1516#include " mlir-c/IR.h"
@@ -24,6 +25,8 @@ using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipeline;
2425using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipelineAttr;
2526using mlir::iree_compiler::IREE::Codegen::LoweringConfigAttrInterface;
2627using mlir::iree_compiler::IREE::Codegen::TranslationInfoAttr;
28+ using mlir::iree_compiler::IREE::GPU::MMAIntrinsic;
29+ using mlir::iree_compiler::IREE::HAL::ExecutableVariantOp;
2730
2831bool ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr (
2932 MlirAttribute attr) {
@@ -149,3 +152,49 @@ ireeCodegenCompilationInfoAttrGetParameters(MlirAttribute attr) {
149152 parameters.translationInfo = wrap (compilationInfo.getTranslationInfo ());
150153 return parameters;
151154}
155+
156+ void ireeCodegenGetExecutableVariantOps (MlirModule module , size_t *numOps,
157+ MlirOperation *executableOps) {
158+ assert (!mlirModuleIsNull (module ) && " module cannot be nullptr" );
159+ assert (numOps && " numOps cannot be nullptr" );
160+
161+ mlir::ModuleOp moduleOp = unwrap (module );
162+ llvm::SmallVector<ExecutableVariantOp> executableVariantOps =
163+ mlir::iree_compiler::getExecutableVariantOps (moduleOp);
164+
165+ if (!executableOps) {
166+ *numOps = executableVariantOps.size ();
167+ return ;
168+ }
169+
170+ assert (
171+ *numOps == executableVariantOps.size () &&
172+ " *numOps must match the number of elements in the executableVariantOps" );
173+
174+ for (size_t i = 0 , e = executableVariantOps.size (); i < e; ++i) {
175+ executableOps[i] = wrap (executableVariantOps[i]);
176+ }
177+ }
178+
179+ void ireeCodegenQueryMMAIntrinsics (MlirOperation op, size_t *numIntrinsics,
180+ uint32_t *mmaIntrinsics) {
181+ assert (numIntrinsics && " numIntrinsics cannot be nullptr" );
182+
183+ mlir::Operation *mlirOp = unwrap (op);
184+ auto variantOp = llvm::dyn_cast_if_present<ExecutableVariantOp>(mlirOp);
185+ assert (variantOp && " operation is not a ExecutableVariantOp" );
186+
187+ llvm::SmallVector<MMAIntrinsic> intrinsics =
188+ mlir::iree_compiler::queryMMAIntrinsics (variantOp);
189+ if (!mmaIntrinsics) {
190+ *numIntrinsics = intrinsics.size ();
191+ return ;
192+ }
193+
194+ assert (*numIntrinsics == intrinsics.size () &&
195+ " *numIntrinsics must match the number of elements in the intrinsics" );
196+
197+ for (size_t i = 0 , e = intrinsics.size (); i < e; ++i) {
198+ mmaIntrinsics[i] = static_cast <uint32_t >(intrinsics[i]);
199+ }
200+ }
0 commit comments