Skip to content

Commit ac29bbb

Browse files
committed
add option to control whether to use conversion attributes
1 parent 78c7b40 commit ac29bbb

File tree

4 files changed

+55
-6
lines changed

4 files changed

+55
-6
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
2222
This is a generic pass to convert to LLVM, it uses the
2323
`ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
2424
the injection of conversion patterns.
25+
26+
If `use-conversion-attrs` is set to `true`, the pass will look for
27+
`ConvertToLLVMAttrInterface` attributes and use them to further configure
28+
the conversion process. Enabling this option incurs in extra overhead.
2529
}];
2630

2731
let constructor = "mlir::createConvertToLLVMPass()";
2832
let options = [
2933
ListOption<"filterDialects", "filter-dialects", "std::string",
3034
"Test conversion patterns of only the specified dialects">,
35+
Option<"useConversionAttrs", "use-conversion-attrs", "bool", "false",
36+
"Use op conversion attributes to configure the conversion">,
3137
];
3238
}
3339

mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ class ConvertToLLVMPass
6363
: public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
6464
std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
6565
interfaces;
66+
std::shared_ptr<const FrozenRewritePatternSet> patterns;
67+
std::shared_ptr<const ConversionTarget> target;
68+
std::shared_ptr<const LLVMTypeConverter> typeConverter;
6669

6770
public:
6871
using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -72,8 +75,22 @@ class ConvertToLLVMPass
7275
}
7376

7477
LogicalResult initialize(MLIRContext *context) final {
75-
auto interfaces =
76-
std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
78+
std::shared_ptr<SmallVector<ConvertToLLVMPatternInterface *>> interfaces;
79+
std::shared_ptr<ConversionTarget> target;
80+
std::shared_ptr<LLVMTypeConverter> typeConverter;
81+
RewritePatternSet tempPatterns(context);
82+
83+
// Only collect the interfaces if `useConversionAttrs=true` as everything
84+
// else must be initialized in `runOnOperation`.
85+
if (useConversionAttrs) {
86+
interfaces =
87+
std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
88+
} else {
89+
target = std::make_shared<ConversionTarget>(*context);
90+
target->addLegalDialect<LLVM::LLVMDialect>();
91+
typeConverter = std::make_shared<LLVMTypeConverter>(context);
92+
}
93+
7794
if (!filterDialects.empty()) {
7895
// Test mode: Populate only patterns from the specified dialects. Produce
7996
// an error if the dialect is not loaded or does not implement the
@@ -88,7 +105,12 @@ class ConvertToLLVMPass
88105
return emitError(UnknownLoc::get(context))
89106
<< "dialect does not implement ConvertToLLVMPatternInterface: "
90107
<< dialectName << "\n";
91-
interfaces->push_back(iface);
108+
if (useConversionAttrs) {
109+
interfaces->push_back(iface);
110+
continue;
111+
}
112+
iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
113+
tempPatterns);
92114
}
93115
} else {
94116
// Normal mode: Populate all patterns from all dialects that implement the
@@ -99,15 +121,34 @@ class ConvertToLLVMPass
99121
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
100122
if (!iface)
101123
continue;
102-
interfaces->push_back(iface);
124+
if (useConversionAttrs) {
125+
interfaces->push_back(iface);
126+
continue;
127+
}
128+
iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
129+
tempPatterns);
103130
}
104131
}
105132

106-
this->interfaces = interfaces;
133+
if (useConversionAttrs) {
134+
this->interfaces = interfaces;
135+
} else {
136+
this->patterns =
137+
std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
138+
this->target = target;
139+
this->typeConverter = typeConverter;
140+
}
107141
return success();
108142
}
109143

110144
void runOnOperation() final {
145+
// Fast path:
146+
if (!useConversionAttrs) {
147+
if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
148+
signalPassFailure();
149+
return;
150+
}
151+
// Slow path with conversion attributes.
111152
MLIRContext *context = &getContext();
112153
RewritePatternSet patterns(context);
113154
ConversionTarget target(*context);

mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ void mlir::populateOpConvertToLLVMConversionPatterns(
3535
Operation *op, ConversionTarget &target, LLVMTypeConverter &typeConverter,
3636
RewritePatternSet &patterns) {
3737
auto iface = dyn_cast<ConvertToLLVMOpInterface>(op);
38+
if (!iface)
39+
iface = op->getParentOfType<ConvertToLLVMOpInterface>();
3840
if (!iface)
3941
return;
4042
SmallVector<ConvertToLLVMAttrInterface, 12> attrs;

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm))" | FileCheck %s
1+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm{use-conversion-attrs=true}))" | FileCheck %s
22

33
// CHECK-LABEL: gpu.module @nvvm_module
44
gpu.module @nvvm_module [#nvvm.target] {

0 commit comments

Comments
 (0)