Skip to content

Commit 65463e9

Browse files
committed
[mlir][LLVM] Add the ConvertToLLVMAttrInterface and ConvertToLLVMOpInterface interfaces
This patch adds the `ConvertToLLVMAttrInterface` and `ConvertToLLVMOpInterface` interfaces. It also modifies the `convert-to-llvm` pass to use these interfaces when available. The `ConvertToLLVMAttrInterface` interfaces allows attributes to configure conversion to LLVM, including the conversion target, LLVM type converter, and populating conversion patterns. See the `NVVMTargetAttr` implementation of this interface for an example of how this interface can be used to configure conversion to LLVM. The `ConvertToLLVMOpInterface` interface collects all convert to LLVM attributes stored in an operation. Finally, the `convert-to-llvm` pass was modified to use these interfaces when available. This allows applying `convert-to-llvm` to GPU modules and letting the `NVVMTargetAttr` decide which patterns to populate.
1 parent ded35c0 commit 65463e9

File tree

14 files changed

+335
-40
lines changed

14 files changed

+335
-40
lines changed

mlir/include/mlir/Conversion/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@ mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Conversion)
66
add_public_tablegen_target(MLIRConversionPassIncGen)
77

88
add_mlir_doc(Passes ConversionPasses ./ -gen-pass-doc)
9+
10+
add_subdirectory(ConvertToLLVM)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
set(LLVM_TARGET_DEFINITIONS ToLLVMInterface.td)
2+
mlir_tablegen(ToLLVMAttrInterface.h.inc -gen-attr-interface-decls)
3+
mlir_tablegen(ToLLVMAttrInterface.cpp.inc -gen-attr-interface-defs)
4+
mlir_tablegen(ToLLVMOpInterface.h.inc -gen-op-interface-decls)
5+
mlir_tablegen(ToLLVMOpInterface.cpp.inc -gen-op-interface-defs)
6+
add_public_tablegen_target(MLIRConvertToLLVMInterfaceIncGen)
7+
add_dependencies(mlir-generic-headers MLIRConvertToLLVMInterfaceIncGen)

mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "mlir/IR/DialectInterface.h"
1313
#include "mlir/IR/MLIRContext.h"
14+
#include "mlir/IR/OpDefinition.h"
1415

1516
namespace mlir {
1617
class ConversionTarget;
@@ -50,6 +51,18 @@ void populateConversionTargetFromOperation(Operation *op,
5051
LLVMTypeConverter &typeConverter,
5152
RewritePatternSet &patterns);
5253

54+
/// Helper function for populating LLVM conversion patterns. If `op` implements
55+
/// the `ConvertToLLVMOpInterface` interface, then the LLVM conversion pattern
56+
/// attributes provided by the interface will be used to configure the
57+
/// conversion target, type converter, and the pattern set.
58+
void populateOpConvertToLLVMConversionPatterns(Operation *op,
59+
ConversionTarget &target,
60+
LLVMTypeConverter &typeConverter,
61+
RewritePatternSet &patterns);
5362
} // namespace mlir
5463

64+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMAttrInterface.h.inc"
65+
66+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMOpInterface.h.inc"
67+
5568
#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_H
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
2+
//===- ToLLVMInterface.td - Conversion to LLVM interfaces -----*- tablegen -*-===//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Defines interfaces for managing transformations, including populating
10+
// pattern rewrites.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD
15+
#define MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD
16+
17+
include "mlir/IR/OpBase.td"
18+
19+
//===----------------------------------------------------------------------===//
20+
// Attribute interface
21+
//===----------------------------------------------------------------------===//
22+
23+
def ConvertToLLVMAttrInterface :
24+
AttrInterface<"ConvertToLLVMAttrInterface"> {
25+
let description = [{
26+
The `ConvertToLLVMAttrInterface` attribute interfaces allows using
27+
attributes to configure the convert to LLVM infrastructure, this includes:
28+
- The conversion target.
29+
- The LLVM type converter.
30+
- The pattern set.
31+
32+
This interface permits fined grained configuration of the `convert-to-llvm`
33+
process. For example, attributes with target information like
34+
`#nvvm.target` or `#rodcl.target` can leverage this interface for populating
35+
patterns specific to a particular target.
36+
}];
37+
let cppNamespace = "::mlir";
38+
let methods = [
39+
InterfaceMethod<
40+
/*desc=*/[{
41+
Populate the dialect conversion target, type converter and pattern set.
42+
}],
43+
/*retTy=*/"void",
44+
/*methodName=*/"populateConvertToLLVMConversionPatterns",
45+
/*args=*/(ins "::mlir::ConversionTarget&":$target,
46+
"::mlir::LLVMTypeConverter&":$typeConverter,
47+
"::mlir::RewritePatternSet&":$patternSet)>
48+
];
49+
}
50+
51+
//===----------------------------------------------------------------------===//
52+
// Op interface
53+
//===----------------------------------------------------------------------===//
54+
55+
def ConvertToLLVMOpInterface : OpInterface<"ConvertToLLVMOpInterface"> {
56+
let description = [{
57+
Interface for collecting all convert to LLVM attributes stored in an
58+
operation. See `ConvertToLLVMAttrInterface` for more information on these
59+
attributes.
60+
}];
61+
let cppNamespace = "::mlir";
62+
let methods = [
63+
InterfaceMethod<
64+
/*desc=*/[{
65+
Populate the provided vector with a list of convert to LLVM attributes
66+
to apply.
67+
}],
68+
/*retTy=*/"void",
69+
/*methodName=*/"getConvertToLLVMConversionAttrs",
70+
/*args=*/(ins
71+
"::llvm::SmallVectorImpl<::mlir::ConvertToLLVMAttrInterface>&":$attrs)
72+
>
73+
];
74+
}
75+
76+
#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//===- GPUToLLVM.h - Convert GPU to LLVM dialect ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This files declares registration functions for converting GPU to LLVM.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H
14+
#define MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H
15+
16+
namespace mlir {
17+
class DialectRegistry;
18+
namespace gpu {
19+
/// Registers the `ConvertToLLVMOpInterface` interface on the `gpu::GPUModuleOP`
20+
/// operation.
21+
void registerConvertGpuToLLVMInterface(DialectRegistry &registry);
22+
} // namespace gpu
23+
} // namespace mlir
24+
25+
#endif // MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- GPUToNVVM.h - Convert GPU to NVVM dialect ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This files declares registration functions for converting GPU to NVVM.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
14+
#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
15+
16+
namespace mlir {
17+
class DialectRegistry;
18+
namespace NVVM {
19+
/// Registers the `ConvertToLLVMAttrInterface` interface on the
20+
/// `NVVM::NVVMTargetAttr` attribute. This interface populates the conversion
21+
/// target, LLVM type converter, and pattern set for converting GPU operations
22+
/// to NVVM.
23+
void registerConvertGpuToNVVMInterface(DialectRegistry &registry);
24+
} // namespace NVVM
25+
} // namespace mlir
26+
27+
#endif // MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H

mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
3131
/// Configure target to convert from the GPU dialect to NVVM.
3232
void configureGpuToNVVMConversionLegality(ConversionTarget &target);
3333

34+
/// Configure the LLVM type convert to convert types and address spaces from the
35+
/// GPU dialect to NVVM.
36+
void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter);
37+
3438
/// Collect a set of patterns to convert from the GPU dialect to NVVM.
3539
void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
3640
RewritePatternSet &patterns);

mlir/include/mlir/InitAllExtensions.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
1919
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
2020
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
21+
#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
22+
#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
2123
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
2224
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
2325
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
@@ -65,6 +67,8 @@ inline void registerAllExtensions(DialectRegistry &registry) {
6567
registerConvertMemRefToLLVMInterface(registry);
6668
registerConvertNVVMToLLVMInterface(registry);
6769
ub::registerConvertUBToLLVMInterface(registry);
70+
gpu::registerConvertGpuToLLVMInterface(registry);
71+
NVVM::registerConvertGpuToNVVMInterface(registry);
6872

6973
// Register all transform dialect extensions.
7074
affine::registerTransformDialectExtension(registry);

mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_mlir_conversion_library(MLIRConvertToLLVMInterface
77
ToLLVMInterface.cpp
88

99
DEPENDS
10+
MLIRConvertToLLVMInterfaceIncGen
1011

1112
LINK_LIBS PUBLIC
1213
MLIRIR

mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,8 @@ class LoadDependentDialectExtension : public DialectExtensionBase {
6161
/// the injection of conversion patterns.
6262
class ConvertToLLVMPass
6363
: public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
64-
std::shared_ptr<const FrozenRewritePatternSet> patterns;
65-
std::shared_ptr<const ConversionTarget> target;
66-
std::shared_ptr<const LLVMTypeConverter> typeConverter;
64+
std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
65+
interfaces;
6766

6867
public:
6968
using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -73,11 +72,8 @@ class ConvertToLLVMPass
7372
}
7473

7574
LogicalResult initialize(MLIRContext *context) final {
76-
RewritePatternSet tempPatterns(context);
77-
auto target = std::make_shared<ConversionTarget>(*context);
78-
target->addLegalDialect<LLVM::LLVMDialect>();
79-
auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
80-
75+
auto interfaces =
76+
std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
8177
if (!filterDialects.empty()) {
8278
// Test mode: Populate only patterns from the specified dialects. Produce
8379
// an error if the dialect is not loaded or does not implement the
@@ -92,8 +88,7 @@ class ConvertToLLVMPass
9288
return emitError(UnknownLoc::get(context))
9389
<< "dialect does not implement ConvertToLLVMPatternInterface: "
9490
<< dialectName << "\n";
95-
iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
96-
tempPatterns);
91+
interfaces->push_back(iface);
9792
}
9893
} else {
9994
// Normal mode: Populate all patterns from all dialects that implement the
@@ -104,20 +99,33 @@ class ConvertToLLVMPass
10499
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
105100
if (!iface)
106101
continue;
107-
iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
108-
tempPatterns);
102+
interfaces->push_back(iface);
109103
}
110104
}
111105

112-
this->patterns =
113-
std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
114-
this->target = target;
115-
this->typeConverter = typeConverter;
106+
this->interfaces = interfaces;
116107
return success();
117108
}
118109

119110
void runOnOperation() final {
120-
if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
111+
MLIRContext *context = &getContext();
112+
RewritePatternSet patterns(context);
113+
ConversionTarget target(*context);
114+
target.addLegalDialect<LLVM::LLVMDialect>();
115+
LLVMTypeConverter typeConverter(context);
116+
117+
// Configure the conversion with dialect level interfaces.
118+
for (ConvertToLLVMPatternInterface *iface : *interfaces)
119+
iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
120+
patterns);
121+
122+
// Configure the conversion attribute interfaces.
123+
populateOpConvertToLLVMConversionPatterns(getOperation(), target,
124+
typeConverter, patterns);
125+
126+
// Apply the conversion.
127+
if (failed(applyPartialConversion(getOperation(), target,
128+
std::move(patterns))))
121129
signalPassFailure();
122130
}
123131
};

0 commit comments

Comments
 (0)