Skip to content

Commit 97e7991

Browse files
committed
add option to control whether to use conversion attributes
1 parent f4c3955 commit 97e7991

File tree

4 files changed

+55
-6
lines changed

4 files changed

+55
-6
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
2222
This is a generic pass to convert to LLVM, it uses the
2323
`ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
2424
the injection of conversion patterns.
25+
26+
If `use-conversion-attrs` is set to `true`, the pass will look for
27+
`ConvertToLLVMAttrInterface` attributes and use them to further configure
28+
the conversion process. Enabling this option incurs in extra overhead.
2529
}];
2630

2731
let constructor = "mlir::createConvertToLLVMPass()";
2832
let options = [
2933
ListOption<"filterDialects", "filter-dialects", "std::string",
3034
"Test conversion patterns of only the specified dialects">,
35+
Option<"useConversionAttrs", "use-conversion-attrs", "bool", "false",
36+
"Use op conversion attributes to configure the conversion">,
3137
];
3238
}
3339

mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ class ConvertToLLVMPass
6565
: public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
6666
std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
6767
interfaces;
68+
std::shared_ptr<const FrozenRewritePatternSet> patterns;
69+
std::shared_ptr<const ConversionTarget> target;
70+
std::shared_ptr<const LLVMTypeConverter> typeConverter;
6871

6972
public:
7073
using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -74,8 +77,22 @@ class ConvertToLLVMPass
7477
}
7578

7679
LogicalResult initialize(MLIRContext *context) final {
77-
auto interfaces =
78-
std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
80+
std::shared_ptr<SmallVector<ConvertToLLVMPatternInterface *>> interfaces;
81+
std::shared_ptr<ConversionTarget> target;
82+
std::shared_ptr<LLVMTypeConverter> typeConverter;
83+
RewritePatternSet tempPatterns(context);
84+
85+
// Only collect the interfaces if `useConversionAttrs=true` as everything
86+
// else must be initialized in `runOnOperation`.
87+
if (useConversionAttrs) {
88+
interfaces =
89+
std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
90+
} else {
91+
target = std::make_shared<ConversionTarget>(*context);
92+
target->addLegalDialect<LLVM::LLVMDialect>();
93+
typeConverter = std::make_shared<LLVMTypeConverter>(context);
94+
}
95+
7996
if (!filterDialects.empty()) {
8097
// Test mode: Populate only patterns from the specified dialects. Produce
8198
// an error if the dialect is not loaded or does not implement the
@@ -90,7 +107,12 @@ class ConvertToLLVMPass
90107
return emitError(UnknownLoc::get(context))
91108
<< "dialect does not implement ConvertToLLVMPatternInterface: "
92109
<< dialectName << "\n";
93-
interfaces->push_back(iface);
110+
if (useConversionAttrs) {
111+
interfaces->push_back(iface);
112+
continue;
113+
}
114+
iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
115+
tempPatterns);
94116
}
95117
} else {
96118
// Normal mode: Populate all patterns from all dialects that implement the
@@ -101,15 +123,34 @@ class ConvertToLLVMPass
101123
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
102124
if (!iface)
103125
continue;
104-
interfaces->push_back(iface);
126+
if (useConversionAttrs) {
127+
interfaces->push_back(iface);
128+
continue;
129+
}
130+
iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
131+
tempPatterns);
105132
}
106133
}
107134

108-
this->interfaces = interfaces;
135+
if (useConversionAttrs) {
136+
this->interfaces = interfaces;
137+
} else {
138+
this->patterns =
139+
std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
140+
this->target = target;
141+
this->typeConverter = typeConverter;
142+
}
109143
return success();
110144
}
111145

112146
void runOnOperation() final {
147+
// Fast path:
148+
if (!useConversionAttrs) {
149+
if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
150+
signalPassFailure();
151+
return;
152+
}
153+
// Slow path with conversion attributes.
113154
MLIRContext *context = &getContext();
114155
RewritePatternSet patterns(context);
115156
ConversionTarget target(*context);

mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ void mlir::populateOpConvertToLLVMConversionPatterns(
3535
Operation *op, ConversionTarget &target, LLVMTypeConverter &typeConverter,
3636
RewritePatternSet &patterns) {
3737
auto iface = dyn_cast<ConvertToLLVMOpInterface>(op);
38+
if (!iface)
39+
iface = op->getParentOfType<ConvertToLLVMOpInterface>();
3840
if (!iface)
3941
return;
4042
SmallVector<ConvertToLLVMAttrInterface, 12> attrs;

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm))" | FileCheck %s
1+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm{use-conversion-attrs=true}))" | FileCheck %s
22

33
// CHECK-LABEL: gpu.module @nvvm_module
44
gpu.module @nvvm_module [#nvvm.target] {

0 commit comments

Comments
 (0)