Skip to content

Commit f550639

Browse files
committed
address reviewer comments
1 parent 97e7991 commit f550639

File tree

4 files changed

+198
-94
lines changed

4 files changed

+198
-94
lines changed

mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class LLVMTypeConverter;
1919
class MLIRContext;
2020
class Operation;
2121
class RewritePatternSet;
22+
class AnalysisManager;
2223

2324
/// Base class for dialect interfaces providing translation to LLVM IR.
2425
/// Dialects that can be translated should provide an implementation of this
@@ -59,6 +60,42 @@ void populateOpConvertToLLVMConversionPatterns(Operation *op,
5960
ConversionTarget &target,
6061
LLVMTypeConverter &typeConverter,
6162
RewritePatternSet &patterns);
63+
64+
/// Base class for creating the internal implementation of `convert-to-llvm`
65+
/// passes.
66+
class ConvertToLLVMPassInterface {
67+
public:
68+
ConvertToLLVMPassInterface(MLIRContext *context,
69+
ArrayRef<std::string> filterDialects);
70+
virtual ~ConvertToLLVMPassInterface() = default;
71+
72+
/// Get the dependent dialects used by `convert-to-llvm`.
73+
static void getDependentDialects(DialectRegistry &registry);
74+
75+
/// Initialize the internal state of the `convert-to-llvm` pass
76+
/// implementation. This method is invoked by `ConvertToLLVMPass::initialize`.
77+
/// This method returns whether the initialization process failed.
78+
virtual LogicalResult initialize() = 0;
79+
80+
/// Transform `op` to LLVM with the conversions available in the pass. The
81+
/// analysis manager can be used to query analyzes like `DataLayoutAnalysis`
82+
/// to further configure the conversion process. This method is invoked by
83+
/// `ConvertToLLVMPass::runOnOperation`. This method returns whether the
84+
/// transformation process failed.
85+
virtual LogicalResult transform(Operation *op,
86+
AnalysisManager manager) const = 0;
87+
88+
protected:
89+
/// Visit the `ConvertToLLVMPatternInterface` dialect interfaces and call
90+
/// `visitor` with each of the interfaces. If `filterDialects` is non-empty,
91+
/// then `visitor` is invoked only with the dialects in the `filterDialects`
92+
/// list.
93+
LogicalResult visitInterfaces(
94+
llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor);
95+
MLIRContext *context;
96+
/// List of dialects names to use as filters.
97+
ArrayRef<std::string> filterDialects;
98+
};
6299
} // namespace mlir
63100

64101
#include "mlir/Conversion/ConvertToLLVM/ToLLVMAttrInterface.h.inc"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
2323
`ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
2424
the injection of conversion patterns.
2525

26-
If `use-conversion-attrs` is set to `true`, the pass will look for
26+
If `dynamic` is set to `true`, the pass will look for
2727
`ConvertToLLVMAttrInterface` attributes and use them to further configure
28-
the conversion process. Enabling this option incurs in extra overhead.
28+
the conversion process. This option also uses the `DataLayoutAnalysis`
29+
analysis to configure the type converter. Enabling this option incurs in
30+
extra overhead.
2931
}];
3032

3133
let constructor = "mlir::createConvertToLLVMPass()";
3234
let options = [
3335
ListOption<"filterDialects", "filter-dialects", "std::string",
3436
"Test conversion patterns of only the specified dialects">,
35-
Option<"useConversionAttrs", "use-conversion-attrs", "bool", "false",
37+
Option<"useDynamic", "dynamic", "bool", "false",
3638
"Use op conversion attributes to configure the conversion">,
3739
];
3840
}

mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp

Lines changed: 154 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Analysis/DataLayoutAnalysis.h"
910
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1011
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
1112
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
@@ -27,7 +28,6 @@ namespace mlir {
2728
using namespace mlir;
2829

2930
namespace {
30-
3131
/// This DialectExtension can be attached to the context, which will invoke the
3232
/// `apply()` method for every loaded dialect. If a dialect implements the
3333
/// `ConvertToLLVMPatternInterface` interface, we load dependent dialects
@@ -58,123 +58,188 @@ class LoadDependentDialectExtension : public DialectExtensionBase {
5858
}
5959
};
6060

61-
/// This is a generic pass to convert to LLVM, it uses the
62-
/// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
63-
/// the injection of conversion patterns.
64-
class ConvertToLLVMPass
65-
: public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
66-
std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
67-
interfaces;
61+
//===----------------------------------------------------------------------===//
62+
// StaticConvertToLLVM
63+
//===----------------------------------------------------------------------===//
64+
65+
/// Static implementation of the `convert-to-llvm` pass. This version only looks
66+
/// at dialect interfaces to configure the conversion process.
67+
struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
68+
/// Pattern set with conversions to LLVM.
6869
std::shared_ptr<const FrozenRewritePatternSet> patterns;
70+
/// The conversion target.
6971
std::shared_ptr<const ConversionTarget> target;
72+
/// The LLVM type converter.
7073
std::shared_ptr<const LLVMTypeConverter> typeConverter;
74+
using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
7175

72-
public:
73-
using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
74-
void getDependentDialects(DialectRegistry &registry) const final {
75-
registry.insert<LLVM::LLVMDialect>();
76-
registry.addExtensions<LoadDependentDialectExtension>();
76+
/// Configure the conversion to LLVM at pass initialization.
77+
LogicalResult initialize() final {
78+
auto target = std::make_shared<ConversionTarget>(*context);
79+
auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
80+
RewritePatternSet tempPatterns(context);
81+
target->addLegalDialect<LLVM::LLVMDialect>();
82+
// Populate the patterns with the dialect interface.
83+
if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
84+
iface->populateConvertToLLVMConversionPatterns(
85+
*target, *typeConverter, tempPatterns);
86+
})))
87+
return failure();
88+
this->patterns =
89+
std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
90+
this->target = target;
91+
this->typeConverter = typeConverter;
92+
return success();
7793
}
7894

79-
LogicalResult initialize(MLIRContext *context) final {
80-
std::shared_ptr<SmallVector<ConvertToLLVMPatternInterface *>> interfaces;
81-
std::shared_ptr<ConversionTarget> target;
82-
std::shared_ptr<LLVMTypeConverter> typeConverter;
83-
RewritePatternSet tempPatterns(context);
95+
/// Apply the conversion driver.
96+
LogicalResult transform(Operation *op, AnalysisManager manager) const final {
97+
if (failed(applyPartialConversion(op, *target, *patterns)))
98+
return failure();
99+
return success();
100+
}
101+
};
84102

85-
// Only collect the interfaces if `useConversionAttrs=true` as everything
86-
// else must be initialized in `runOnOperation`.
87-
if (useConversionAttrs) {
88-
interfaces =
89-
std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
90-
} else {
91-
target = std::make_shared<ConversionTarget>(*context);
92-
target->addLegalDialect<LLVM::LLVMDialect>();
93-
typeConverter = std::make_shared<LLVMTypeConverter>(context);
94-
}
103+
//===----------------------------------------------------------------------===//
104+
// DynamicConvertToLLVM
105+
//===----------------------------------------------------------------------===//
95106

96-
if (!filterDialects.empty()) {
97-
// Test mode: Populate only patterns from the specified dialects. Produce
98-
// an error if the dialect is not loaded or does not implement the
99-
// interface.
100-
for (std::string &dialectName : filterDialects) {
101-
Dialect *dialect = context->getLoadedDialect(dialectName);
102-
if (!dialect)
103-
return emitError(UnknownLoc::get(context))
104-
<< "dialect not loaded: " << dialectName << "\n";
105-
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
106-
if (!iface)
107-
return emitError(UnknownLoc::get(context))
108-
<< "dialect does not implement ConvertToLLVMPatternInterface: "
109-
<< dialectName << "\n";
110-
if (useConversionAttrs) {
111-
interfaces->push_back(iface);
112-
continue;
113-
}
114-
iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
115-
tempPatterns);
116-
}
117-
} else {
118-
// Normal mode: Populate all patterns from all dialects that implement the
119-
// interface.
120-
for (Dialect *dialect : context->getLoadedDialects()) {
121-
// First time we encounter this dialect: if it implements the interface,
122-
// let's populate patterns !
123-
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
124-
if (!iface)
125-
continue;
126-
if (useConversionAttrs) {
107+
/// Dynamic implementation of the `convert-to-llvm` pass. This version inspects
108+
/// the IR to configure the conversion to LLVM.
109+
struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
110+
/// A list of all the `ConvertToLLVMPatternInterface` dialect interfaces used
111+
/// to partially configure the conversion process.
112+
std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
113+
interfaces;
114+
using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
115+
116+
/// Collect the dialect interfaces used to configure the conversion process.
117+
LogicalResult initialize() final {
118+
auto interfaces =
119+
std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
120+
// Collect the interfaces.
121+
if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
127122
interfaces->push_back(iface);
128-
continue;
129-
}
130-
iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
131-
tempPatterns);
132-
}
133-
}
134-
135-
if (useConversionAttrs) {
136-
this->interfaces = interfaces;
137-
} else {
138-
this->patterns =
139-
std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
140-
this->target = target;
141-
this->typeConverter = typeConverter;
142-
}
123+
})))
124+
return failure();
125+
this->interfaces = interfaces;
143126
return success();
144127
}
145128

146-
void runOnOperation() final {
147-
// Fast path:
148-
if (!useConversionAttrs) {
149-
if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
150-
signalPassFailure();
151-
return;
152-
}
153-
// Slow path with conversion attributes.
154-
MLIRContext *context = &getContext();
129+
/// Configure the conversion process and apply the conversion driver.
130+
LogicalResult transform(Operation *op, AnalysisManager manager) const final {
155131
RewritePatternSet patterns(context);
156132
ConversionTarget target(*context);
157133
target.addLegalDialect<LLVM::LLVMDialect>();
158-
LLVMTypeConverter typeConverter(context);
134+
// Get the data layout analysis.
135+
const auto &dlAnalysis = manager.getAnalysis<DataLayoutAnalysis>();
136+
LLVMTypeConverter typeConverter(context, &dlAnalysis);
159137

160138
// Configure the conversion with dialect level interfaces.
161139
for (ConvertToLLVMPatternInterface *iface : *interfaces)
162140
iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
163141
patterns);
164142

165143
// Configure the conversion attribute interfaces.
166-
populateOpConvertToLLVMConversionPatterns(getOperation(), target,
167-
typeConverter, patterns);
144+
populateOpConvertToLLVMConversionPatterns(op, target, typeConverter,
145+
patterns);
168146

169147
// Apply the conversion.
170-
if (failed(applyPartialConversion(getOperation(), target,
171-
std::move(patterns))))
172-
signalPassFailure();
148+
if (failed(applyPartialConversion(op, target, std::move(patterns))))
149+
return failure();
150+
return success();
151+
}
152+
};
153+
154+
//===----------------------------------------------------------------------===//
155+
// ConvertToLLVMPass
156+
//===----------------------------------------------------------------------===//
157+
158+
/// This is a generic pass to convert to LLVM, it uses the
159+
/// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
160+
/// the injection of conversion patterns.
161+
class ConvertToLLVMPass
162+
: public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
163+
std::shared_ptr<const ConvertToLLVMPassInterface> impl;
164+
165+
public:
166+
using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
167+
void getDependentDialects(DialectRegistry &registry) const final {
168+
ConvertToLLVMPassInterface::getDependentDialects(registry);
169+
}
170+
171+
LogicalResult initialize(MLIRContext *context) final {
172+
std::shared_ptr<ConvertToLLVMPassInterface> impl;
173+
// Choose the pass implementation.
174+
if (useDynamic)
175+
impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
176+
else
177+
impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
178+
if (failed(impl->initialize()))
179+
return failure();
180+
this->impl = impl;
181+
return success();
182+
}
183+
184+
void runOnOperation() final {
185+
if (failed(impl->transform(getOperation(), getAnalysisManager())))
186+
return signalPassFailure();
173187
}
174188
};
175189

176190
} // namespace
177191

192+
//===----------------------------------------------------------------------===//
193+
// ConvertToLLVMPassInterface
194+
//===----------------------------------------------------------------------===//
195+
196+
ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
197+
MLIRContext *context, ArrayRef<std::string> filterDialects)
198+
: context(context), filterDialects(filterDialects) {}
199+
200+
void ConvertToLLVMPassInterface::getDependentDialects(
201+
DialectRegistry &registry) {
202+
registry.insert<LLVM::LLVMDialect>();
203+
registry.addExtensions<LoadDependentDialectExtension>();
204+
}
205+
206+
LogicalResult ConvertToLLVMPassInterface::visitInterfaces(
207+
llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor) {
208+
if (!filterDialects.empty()) {
209+
// Test mode: Populate only patterns from the specified dialects. Produce
210+
// an error if the dialect is not loaded or does not implement the
211+
// interface.
212+
for (StringRef dialectName : filterDialects) {
213+
Dialect *dialect = context->getLoadedDialect(dialectName);
214+
if (!dialect)
215+
return emitError(UnknownLoc::get(context))
216+
<< "dialect not loaded: " << dialectName << "\n";
217+
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
218+
if (!iface)
219+
return emitError(UnknownLoc::get(context))
220+
<< "dialect does not implement ConvertToLLVMPatternInterface: "
221+
<< dialectName << "\n";
222+
visitor(iface);
223+
}
224+
} else {
225+
// Normal mode: Populate all patterns from all dialects that implement the
226+
// interface.
227+
for (Dialect *dialect : context->getLoadedDialects()) {
228+
// First time we encounter this dialect: if it implements the interface,
229+
// let's populate patterns !
230+
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
231+
if (!iface)
232+
continue;
233+
visitor(iface);
234+
}
235+
}
236+
return success();
237+
}
238+
239+
//===----------------------------------------------------------------------===//
240+
// API
241+
//===----------------------------------------------------------------------===//
242+
178243
void mlir::registerConvertToLLVMDependentDialectLoading(
179244
DialectRegistry &registry) {
180245
registry.addExtensions<LoadDependentDialectExtension>();

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm{use-conversion-attrs=true}))" | FileCheck %s
1+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm{dynamic=true}))" | FileCheck %s
22

33
// CHECK-LABEL: gpu.module @nvvm_module
44
gpu.module @nvvm_module [#nvvm.target] {
@@ -7,7 +7,7 @@ gpu.module @nvvm_module [#nvvm.target] {
77
// CHECK: = nvvm.read.ptx.sreg.tid.x : i32
88
// CHECK: = llvm.sext %{{.*}} : i32 to i64
99
%tIdX = gpu.thread_id x
10-
// CHECK: = nvvm.read.ptx.sreg.laneid : i32
10+
// CHECK: = nvvm.read.ptx.sreg.laneid range <i32, 0, 32> : i32
1111
// CHECK: = llvm.sext %{{.*}} : i32 to i64
1212
%laneId = gpu.lane_id
1313
%sum = index.add %tIdX, %laneId

0 commit comments

Comments
 (0)