Skip to content

Commit 4477091

Browse files
authored
[python][tuner] Add bindings for iree_codegen.compilation_info (iree-org#19129)
1 parent bc23e59 commit 4477091

File tree

8 files changed

+125
-2
lines changed

8 files changed

+125
-2
lines changed

compiler/bindings/c/iree/compiler/dialects/iree_codegen.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,22 @@ MLIR_CAPI_EXPORTED MlirAttribute ireeCodegenTranslationInfoAttrGet(
5050
MLIR_CAPI_EXPORTED ireeCodegenTranslationInfoParameters
5151
ireeCodegenTranslationInfoAttrGetParameters(MlirAttribute attr);
5252

53+
MLIR_CAPI_EXPORTED bool
54+
ireeAttributeIsACodegenCompilationInfoAttr(MlirAttribute attr);
55+
56+
MLIR_CAPI_EXPORTED MlirTypeID ireeCodegenCompilationInfoAttrGetTypeID(void);
57+
58+
struct ireeCodegenCompilationInfoParameters {
59+
MlirAttribute loweringConfig;
60+
MlirAttribute translationInfo;
61+
};
62+
63+
MLIR_CAPI_EXPORTED MlirAttribute ireeCodegenCompilationInfoAttrGet(
64+
MlirContext mlirCtx, ireeCodegenCompilationInfoParameters parameters);
65+
66+
MLIR_CAPI_EXPORTED ireeCodegenCompilationInfoParameters
67+
ireeCodegenCompilationInfoAttrGetParameters(MlirAttribute attr);
68+
5369
#ifdef __cplusplus
5470
}
5571
#endif

compiler/bindings/python/IREECompilerDialectsModule.cpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,7 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
8383
"cls"_a, "pass_pipeline"_a, "codegen_spec"_a = py::none(),
8484
"workgroup_size"_a = py::none(), "subgroup_size"_a = py::none(),
8585
"configuration"_a = py::none(), py::kw_only(), "ctx"_a = py::none(),
86-
"Gets an #iree_codegen.translation_info from "
87-
"parameters.")
86+
"Gets an #iree_codegen.translation_info from parameters.")
8887
.def_property_readonly(
8988
"pass_pipeline",
9089
[](MlirAttribute self) -> MlirAttribute {
@@ -124,6 +123,37 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
124123
return parameters.configuration;
125124
});
126125

126+
//===-------------------------------------------------------------------===//
127+
// CodegenCompilationInfoAttr
128+
//===-------------------------------------------------------------------===//
129+
130+
mlir_attribute_subclass(iree_codegen_module, "CompilationInfoAttr",
131+
ireeAttributeIsACodegenCompilationInfoAttr,
132+
ireeCodegenCompilationInfoAttrGetTypeID)
133+
.def_classmethod(
134+
"get",
135+
[](const py::object &, MlirAttribute loweringConfig,
136+
MlirAttribute translationInfo, MlirContext ctx) {
137+
ireeCodegenCompilationInfoParameters parameters = {};
138+
parameters.loweringConfig = loweringConfig;
139+
parameters.translationInfo = translationInfo;
140+
return ireeCodegenCompilationInfoAttrGet(ctx, parameters);
141+
},
142+
"cls"_a, "lowering_config"_a, "translation_info"_a,
143+
"ctx"_a = py::none(),
144+
"Gets an #iree_codegen.compilation_info from parameters.")
145+
.def_property_readonly(
146+
"lowering_config",
147+
[](MlirAttribute self) -> MlirAttribute {
148+
auto parameters = ireeCodegenCompilationInfoAttrGetParameters(self);
149+
return parameters.loweringConfig;
150+
})
151+
.def_property_readonly(
152+
"translation_info", [](MlirAttribute self) -> MlirAttribute {
153+
auto parameters = ireeCodegenCompilationInfoAttrGetParameters(self);
154+
return parameters.translationInfo;
155+
});
156+
127157
//===--------------------------------------------------------------------===//
128158

129159
auto iree_gpu_module =

compiler/bindings/python/test/ir/dialects_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,20 @@ def lowering_config_attr():
215215
assert lowering_config is not None
216216

217217
assert lowering_config.attributes == attributes
218+
219+
220+
@run
221+
def compilation_info():
222+
attributes = ir.DictAttr.get({"reduction": ir.ArrayAttr.get([])})
223+
lowering_config = iree_gpu.LoweringConfigAttr.get(attributes)
224+
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
225+
iree_codegen.DispatchLoweringPassPipeline.None_
226+
)
227+
translation_info = iree_codegen.TranslationInfoAttr.get(pipeline_attr)
228+
229+
compilation_info = iree_codegen.CompilationInfoAttr.get(
230+
lowering_config, translation_info
231+
)
232+
assert compilation_info is not None
233+
assert compilation_info.lowering_config == lowering_config
234+
assert compilation_info.translation_info == translation_info

compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,20 @@
99
#include <optional>
1010
#include <type_traits>
1111
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
12+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
1213
#include "iree/compiler/dialects/iree_codegen.h"
1314
#include "mlir-c/BuiltinAttributes.h"
1415
#include "mlir-c/IR.h"
1516
#include "mlir/CAPI/IR.h"
1617
#include "mlir/CAPI/Support.h"
18+
#include "mlir/IR/Attributes.h"
1719
#include "mlir/IR/BuiltinAttributes.h"
20+
#include "mlir/IR/MLIRContext.h"
1821

22+
using mlir::iree_compiler::IREE::Codegen::CompilationInfoAttr;
1923
using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipeline;
2024
using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipelineAttr;
25+
using mlir::iree_compiler::IREE::Codegen::LoweringConfigAttrInterface;
2126
using mlir::iree_compiler::IREE::Codegen::TranslationInfoAttr;
2227

2328
bool ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr(
@@ -109,3 +114,38 @@ ireeCodegenTranslationInfoAttrGetParameters(MlirAttribute attr) {
109114

110115
return parameters;
111116
}
117+
118+
bool ireeAttributeIsACodegenCompilationInfoAttr(MlirAttribute attr) {
119+
return llvm::isa<CompilationInfoAttr>(unwrap(attr));
120+
}
121+
122+
MlirTypeID ireeCodegenCompilationInfoAttrGetTypeID() {
123+
return wrap(CompilationInfoAttr::getTypeID());
124+
}
125+
126+
MlirAttribute ireeCodegenCompilationInfoAttrGet(
127+
MlirContext mlirCtx, ireeCodegenCompilationInfoParameters parameters) {
128+
assert(!mlirAttributeIsNull(parameters.loweringConfig) &&
129+
"Invalid lowering config attr");
130+
assert(
131+
!mlirAttributeIsNull(parameters.translationInfo) &&
132+
ireeAttributeIsACodegenTranslationInfoAttr(parameters.translationInfo) &&
133+
"Invalid translation info attr");
134+
135+
auto loweringConfig = llvm::cast<LoweringConfigAttrInterface>(
136+
unwrap(parameters.loweringConfig));
137+
auto translationInfo =
138+
llvm::cast<TranslationInfoAttr>(unwrap(parameters.translationInfo));
139+
140+
mlir::MLIRContext *ctx = unwrap(mlirCtx);
141+
return wrap(CompilationInfoAttr::get(ctx, loweringConfig, translationInfo));
142+
}
143+
144+
ireeCodegenCompilationInfoParameters
145+
ireeCodegenCompilationInfoAttrGetParameters(MlirAttribute attr) {
146+
auto compilationInfo = llvm::cast<CompilationInfoAttr>(unwrap(attr));
147+
ireeCodegenCompilationInfoParameters parameters = {};
148+
parameters.loweringConfig = wrap(compilationInfo.getLoweringConfig());
149+
parameters.translationInfo = wrap(compilationInfo.getTranslationInfo());
150+
return parameters;
151+
}

compiler/src/iree/compiler/API/api_exports.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,17 @@
1010

1111
#include <stdint.h>
1212

13+
extern void ireeAttributeIsACodegenCompilationInfoAttr();
1314
extern void ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr();
1415
extern void ireeAttributeIsACodegenTranslationInfoAttr();
1516
extern void ireeAttributeIsAGPULoweringConfigAttr();
1617
extern void ireeAttributeIsAGPUMMAAttr();
1718
extern void ireeAttributeIsAGPUMMAIntrinsicAttr();
1819
extern void ireeAttributeIsAGPUPipelineOptionsAttr();
1920
extern void ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr();
21+
extern void ireeCodegenCompilationInfoAttrGet();
22+
extern void ireeCodegenCompilationInfoAttrGetParameters();
23+
extern void ireeCodegenCompilationInfoAttrGetTypeID();
2024
extern void ireeCodegenDispatchLoweringPassPipelineAttrGet();
2125
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID();
2226
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetValue();
@@ -868,13 +872,17 @@ extern void mlirVectorTypeIsScalable();
868872

869873
uintptr_t __iree_compiler_hidden_force_extern() {
870874
uintptr_t x = 0;
875+
x += (uintptr_t)&ireeAttributeIsACodegenCompilationInfoAttr;
871876
x += (uintptr_t)&ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr;
872877
x += (uintptr_t)&ireeAttributeIsACodegenTranslationInfoAttr;
873878
x += (uintptr_t)&ireeAttributeIsAGPULoweringConfigAttr;
874879
x += (uintptr_t)&ireeAttributeIsAGPUMMAAttr;
875880
x += (uintptr_t)&ireeAttributeIsAGPUMMAIntrinsicAttr;
876881
x += (uintptr_t)&ireeAttributeIsAGPUPipelineOptionsAttr;
877882
x += (uintptr_t)&ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr;
883+
x += (uintptr_t)&ireeCodegenCompilationInfoAttrGet;
884+
x += (uintptr_t)&ireeCodegenCompilationInfoAttrGetParameters;
885+
x += (uintptr_t)&ireeCodegenCompilationInfoAttrGetTypeID;
878886
x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGet;
879887
x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID;
880888
x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGetValue;

compiler/src/iree/compiler/API/api_exports.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
; Generated by generate_exports.py: Do not edit.
22
EXPORTS
3+
ireeAttributeIsACodegenCompilationInfoAttr
34
ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr
45
ireeAttributeIsACodegenTranslationInfoAttr
56
ireeAttributeIsAGPULoweringConfigAttr
67
ireeAttributeIsAGPUMMAAttr
78
ireeAttributeIsAGPUMMAIntrinsicAttr
89
ireeAttributeIsAGPUPipelineOptionsAttr
910
ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr
11+
ireeCodegenCompilationInfoAttrGet
12+
ireeCodegenCompilationInfoAttrGetParameters
13+
ireeCodegenCompilationInfoAttrGetTypeID
1014
ireeCodegenDispatchLoweringPassPipelineAttrGet
1115
ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
1216
ireeCodegenDispatchLoweringPassPipelineAttrGetValue

compiler/src/iree/compiler/API/api_exports.ld

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
# Generated by generate_exports.py: Do not edit.
22
VER_0 {
33
global:
4+
ireeAttributeIsACodegenCompilationInfoAttr;
45
ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr;
56
ireeAttributeIsACodegenTranslationInfoAttr;
67
ireeAttributeIsAGPULoweringConfigAttr;
78
ireeAttributeIsAGPUMMAAttr;
89
ireeAttributeIsAGPUMMAIntrinsicAttr;
910
ireeAttributeIsAGPUPipelineOptionsAttr;
1011
ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr;
12+
ireeCodegenCompilationInfoAttrGet;
13+
ireeCodegenCompilationInfoAttrGetParameters;
14+
ireeCodegenCompilationInfoAttrGetTypeID;
1115
ireeCodegenDispatchLoweringPassPipelineAttrGet;
1216
ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID;
1317
ireeCodegenDispatchLoweringPassPipelineAttrGetValue;

compiler/src/iree/compiler/API/api_exports.macos.lst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
# Generated by generate_exports.py: Do not edit.
2+
_ireeAttributeIsACodegenCompilationInfoAttr
23
_ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr
34
_ireeAttributeIsACodegenTranslationInfoAttr
45
_ireeAttributeIsAGPULoweringConfigAttr
56
_ireeAttributeIsAGPUMMAAttr
67
_ireeAttributeIsAGPUMMAIntrinsicAttr
78
_ireeAttributeIsAGPUPipelineOptionsAttr
89
_ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr
10+
_ireeCodegenCompilationInfoAttrGet
11+
_ireeCodegenCompilationInfoAttrGetParameters
12+
_ireeCodegenCompilationInfoAttrGetTypeID
913
_ireeCodegenDispatchLoweringPassPipelineAttrGet
1014
_ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
1115
_ireeCodegenDispatchLoweringPassPipelineAttrGetValue

0 commit comments

Comments
 (0)