66//
77// ===----------------------------------------------------------------------===//
88
9+ #include " mlir/Analysis/DataLayoutAnalysis.h"
910#include " mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1011#include " mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
1112#include " mlir/Conversion/LLVMCommon/ConversionTarget.h"
@@ -27,7 +28,6 @@ namespace mlir {
2728using namespace mlir ;
2829
2930namespace {
30-
3131// / This DialectExtension can be attached to the context, which will invoke the
3232// / `apply()` method for every loaded dialect. If a dialect implements the
3333// / `ConvertToLLVMPatternInterface` interface, we load dependent dialects
@@ -58,123 +58,188 @@ class LoadDependentDialectExtension : public DialectExtensionBase {
5858 }
5959};
6060
61- // / This is a generic pass to convert to LLVM, it uses the
62- // / `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
63- // / the injection of conversion patterns.
64- class ConvertToLLVMPass
65- : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
66- std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
67- interfaces;
61+ // ===----------------------------------------------------------------------===//
62+ // StaticConvertToLLVM
63+ // ===----------------------------------------------------------------------===//
64+
65+ // / Static implementation of the `convert-to-llvm` pass. This version only looks
66+ // / at dialect interfaces to configure the conversion process.
67+ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
68+ // / Pattern set with conversions to LLVM.
6869 std::shared_ptr<const FrozenRewritePatternSet> patterns;
70+ // / The conversion target.
6971 std::shared_ptr<const ConversionTarget> target;
72+ // / The LLVM type converter.
7073 std::shared_ptr<const LLVMTypeConverter> typeConverter;
74+ using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
7175
72- public:
73- using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
74- void getDependentDialects (DialectRegistry ®istry) const final {
75- registry.insert <LLVM::LLVMDialect>();
76- registry.addExtensions <LoadDependentDialectExtension>();
76+ // / Configure the conversion to LLVM at pass initialization.
77+ LogicalResult initialize () final {
78+ auto target = std::make_shared<ConversionTarget>(*context);
79+ auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
80+ RewritePatternSet tempPatterns (context);
81+ target->addLegalDialect <LLVM::LLVMDialect>();
82+ // Populate the patterns with the dialect interface.
83+ if (failed (visitInterfaces ([&](ConvertToLLVMPatternInterface *iface) {
84+ iface->populateConvertToLLVMConversionPatterns (
85+ *target, *typeConverter, tempPatterns);
86+ })))
87+ return failure ();
88+ this ->patterns =
89+ std::make_unique<FrozenRewritePatternSet>(std::move (tempPatterns));
90+ this ->target = target;
91+ this ->typeConverter = typeConverter;
92+ return success ();
7793 }
7894
79- LogicalResult initialize (MLIRContext *context) final {
80- std::shared_ptr<SmallVector<ConvertToLLVMPatternInterface *>> interfaces;
81- std::shared_ptr<ConversionTarget> target;
82- std::shared_ptr<LLVMTypeConverter> typeConverter;
83- RewritePatternSet tempPatterns (context);
95+ // / Apply the conversion driver.
96+ LogicalResult transform (Operation *op, AnalysisManager manager) const final {
97+ if (failed (applyPartialConversion (op, *target, *patterns)))
98+ return failure ();
99+ return success ();
100+ }
101+ };
84102
85- // Only collect the interfaces if `useConversionAttrs=true` as everything
86- // else must be initialized in `runOnOperation`.
87- if (useConversionAttrs) {
88- interfaces =
89- std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
90- } else {
91- target = std::make_shared<ConversionTarget>(*context);
92- target->addLegalDialect <LLVM::LLVMDialect>();
93- typeConverter = std::make_shared<LLVMTypeConverter>(context);
94- }
103+ // ===----------------------------------------------------------------------===//
104+ // DynamicConvertToLLVM
105+ // ===----------------------------------------------------------------------===//
95106
96- if (!filterDialects.empty ()) {
97- // Test mode: Populate only patterns from the specified dialects. Produce
98- // an error if the dialect is not loaded or does not implement the
99- // interface.
100- for (std::string &dialectName : filterDialects) {
101- Dialect *dialect = context->getLoadedDialect (dialectName);
102- if (!dialect)
103- return emitError (UnknownLoc::get (context))
104- << " dialect not loaded: " << dialectName << " \n " ;
105- auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
106- if (!iface)
107- return emitError (UnknownLoc::get (context))
108- << " dialect does not implement ConvertToLLVMPatternInterface: "
109- << dialectName << " \n " ;
110- if (useConversionAttrs) {
111- interfaces->push_back (iface);
112- continue ;
113- }
114- iface->populateConvertToLLVMConversionPatterns (*target, *typeConverter,
115- tempPatterns);
116- }
117- } else {
118- // Normal mode: Populate all patterns from all dialects that implement the
119- // interface.
120- for (Dialect *dialect : context->getLoadedDialects ()) {
121- // First time we encounter this dialect: if it implements the interface,
122- // let's populate patterns !
123- auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
124- if (!iface)
125- continue ;
126- if (useConversionAttrs) {
107+ // / Dynamic implementation of the `convert-to-llvm` pass. This version inspects
108+ // / the IR to configure the conversion to LLVM.
109+ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
110+ // / A list of all the `ConvertToLLVMPatternInterface` dialect interfaces used
111+ // / to partially configure the conversion process.
112+ std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
113+ interfaces;
114+ using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
115+
116+ // / Collect the dialect interfaces used to configure the conversion process.
117+ LogicalResult initialize () final {
118+ auto interfaces =
119+ std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
120+ // Collect the interfaces.
121+ if (failed (visitInterfaces ([&](ConvertToLLVMPatternInterface *iface) {
127122 interfaces->push_back (iface);
128- continue ;
129- }
130- iface->populateConvertToLLVMConversionPatterns (*target, *typeConverter,
131- tempPatterns);
132- }
133- }
134-
135- if (useConversionAttrs) {
136- this ->interfaces = interfaces;
137- } else {
138- this ->patterns =
139- std::make_unique<FrozenRewritePatternSet>(std::move (tempPatterns));
140- this ->target = target;
141- this ->typeConverter = typeConverter;
142- }
123+ })))
124+ return failure ();
125+ this ->interfaces = interfaces;
143126 return success ();
144127 }
145128
146- void runOnOperation () final {
147- // Fast path:
148- if (!useConversionAttrs) {
149- if (failed (applyPartialConversion (getOperation (), *target, *patterns)))
150- signalPassFailure ();
151- return ;
152- }
153- // Slow path with conversion attributes.
154- MLIRContext *context = &getContext ();
129+ // / Configure the conversion process and apply the conversion driver.
130+ LogicalResult transform (Operation *op, AnalysisManager manager) const final {
155131 RewritePatternSet patterns (context);
156132 ConversionTarget target (*context);
157133 target.addLegalDialect <LLVM::LLVMDialect>();
158- LLVMTypeConverter typeConverter (context);
134+ // Get the data layout analysis.
135+ const auto &dlAnalysis = manager.getAnalysis <DataLayoutAnalysis>();
136+ LLVMTypeConverter typeConverter (context, &dlAnalysis);
159137
160138 // Configure the conversion with dialect level interfaces.
161139 for (ConvertToLLVMPatternInterface *iface : *interfaces)
162140 iface->populateConvertToLLVMConversionPatterns (target, typeConverter,
163141 patterns);
164142
165143 // Configure the conversion attribute interfaces.
166- populateOpConvertToLLVMConversionPatterns (getOperation () , target,
167- typeConverter, patterns);
144+ populateOpConvertToLLVMConversionPatterns (op , target, typeConverter ,
145+ patterns);
168146
169147 // Apply the conversion.
170- if (failed (applyPartialConversion (getOperation (), target,
171- std::move (patterns))))
172- signalPassFailure ();
148+ if (failed (applyPartialConversion (op, target, std::move (patterns))))
149+ return failure ();
150+ return success ();
151+ }
152+ };
153+
154+ // ===----------------------------------------------------------------------===//
155+ // ConvertToLLVMPass
156+ // ===----------------------------------------------------------------------===//
157+
158+ // / This is a generic pass to convert to LLVM, it uses the
159+ // / `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
160+ // / the injection of conversion patterns.
161+ class ConvertToLLVMPass
162+ : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
163+ std::shared_ptr<const ConvertToLLVMPassInterface> impl;
164+
165+ public:
166+ using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
167+ void getDependentDialects (DialectRegistry ®istry) const final {
168+ ConvertToLLVMPassInterface::getDependentDialects (registry);
169+ }
170+
171+ LogicalResult initialize (MLIRContext *context) final {
172+ std::shared_ptr<ConvertToLLVMPassInterface> impl;
173+ // Choose the pass implementation.
174+ if (useDynamic)
175+ impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
176+ else
177+ impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
178+ if (failed (impl->initialize ()))
179+ return failure ();
180+ this ->impl = impl;
181+ return success ();
182+ }
183+
184+ void runOnOperation () final {
185+ if (failed (impl->transform (getOperation (), getAnalysisManager ())))
186+ return signalPassFailure ();
173187 }
174188};
175189
176190} // namespace
177191
192+ // ===----------------------------------------------------------------------===//
193+ // ConvertToLLVMPassInterface
194+ // ===----------------------------------------------------------------------===//
195+
196+ ConvertToLLVMPassInterface::ConvertToLLVMPassInterface (
197+ MLIRContext *context, ArrayRef<std::string> filterDialects)
198+ : context(context), filterDialects(filterDialects) {}
199+
200+ void ConvertToLLVMPassInterface::getDependentDialects (
201+ DialectRegistry ®istry) {
202+ registry.insert <LLVM::LLVMDialect>();
203+ registry.addExtensions <LoadDependentDialectExtension>();
204+ }
205+
206+ LogicalResult ConvertToLLVMPassInterface::visitInterfaces (
207+ llvm::function_ref<void (ConvertToLLVMPatternInterface *)> visitor) {
208+ if (!filterDialects.empty ()) {
209+ // Test mode: Populate only patterns from the specified dialects. Produce
210+ // an error if the dialect is not loaded or does not implement the
211+ // interface.
212+ for (StringRef dialectName : filterDialects) {
213+ Dialect *dialect = context->getLoadedDialect (dialectName);
214+ if (!dialect)
215+ return emitError (UnknownLoc::get (context))
216+ << " dialect not loaded: " << dialectName << " \n " ;
217+ auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
218+ if (!iface)
219+ return emitError (UnknownLoc::get (context))
220+ << " dialect does not implement ConvertToLLVMPatternInterface: "
221+ << dialectName << " \n " ;
222+ visitor (iface);
223+ }
224+ } else {
225+ // Normal mode: Populate all patterns from all dialects that implement the
226+ // interface.
227+ for (Dialect *dialect : context->getLoadedDialects ()) {
228+ // First time we encounter this dialect: if it implements the interface,
229+ // let's populate patterns !
230+ auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
231+ if (!iface)
232+ continue ;
233+ visitor (iface);
234+ }
235+ }
236+ return success ();
237+ }
238+
239+ // ===----------------------------------------------------------------------===//
240+ // API
241+ // ===----------------------------------------------------------------------===//
242+
178243void mlir::registerConvertToLLVMDependentDialectLoading (
179244 DialectRegistry ®istry) {
180245 registry.addExtensions <LoadDependentDialectExtension>();
0 commit comments