Skip to content

Commit f4c3955

Browse files
committed
[mlir][LLVM] Add the ConvertToLLVMAttrInterface and ConvertToLLVMOpInterface interfaces
This patch adds the `ConvertToLLVMAttrInterface` and `ConvertToLLVMOpInterface` interfaces. It also modifies the `convert-to-llvm` pass to use these interfaces when available. The `ConvertToLLVMAttrInterface` interfaces allows attributes to configure conversion to LLVM, including the conversion target, LLVM type converter, and populating conversion patterns. See the `NVVMTargetAttr` implementation of this interface for an example of how this interface can be used to configure conversion to LLVM. The `ConvertToLLVMOpInterface` interface collects all convert to LLVM attributes stored in an operation. Finally, the `convert-to-llvm` pass was modified to use these interfaces when available. This allows applying `convert-to-llvm` to GPU modules and letting the `NVVMTargetAttr` decide which patterns to populate.
1 parent 14b9ca3 commit f4c3955

File tree

14 files changed

+335
-40
lines changed

14 files changed

+335
-40
lines changed

mlir/include/mlir/Conversion/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@ mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Conversion)
66
add_public_tablegen_target(MLIRConversionPassIncGen)
77

88
add_mlir_doc(Passes ConversionPasses ./ -gen-pass-doc)
9+
10+
add_subdirectory(ConvertToLLVM)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
set(LLVM_TARGET_DEFINITIONS ToLLVMInterface.td)
2+
mlir_tablegen(ToLLVMAttrInterface.h.inc -gen-attr-interface-decls)
3+
mlir_tablegen(ToLLVMAttrInterface.cpp.inc -gen-attr-interface-defs)
4+
mlir_tablegen(ToLLVMOpInterface.h.inc -gen-op-interface-decls)
5+
mlir_tablegen(ToLLVMOpInterface.cpp.inc -gen-op-interface-defs)
6+
add_public_tablegen_target(MLIRConvertToLLVMInterfaceIncGen)
7+
add_dependencies(mlir-generic-headers MLIRConvertToLLVMInterfaceIncGen)

mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "mlir/IR/DialectInterface.h"
1313
#include "mlir/IR/MLIRContext.h"
14+
#include "mlir/IR/OpDefinition.h"
1415

1516
namespace mlir {
1617
class ConversionTarget;
@@ -50,6 +51,18 @@ void populateConversionTargetFromOperation(Operation *op,
5051
LLVMTypeConverter &typeConverter,
5152
RewritePatternSet &patterns);
5253

54+
/// Helper function for populating LLVM conversion patterns. If `op` implements
55+
/// the `ConvertToLLVMOpInterface` interface, then the LLVM conversion pattern
56+
/// attributes provided by the interface will be used to configure the
57+
/// conversion target, type converter, and the pattern set.
58+
void populateOpConvertToLLVMConversionPatterns(Operation *op,
59+
ConversionTarget &target,
60+
LLVMTypeConverter &typeConverter,
61+
RewritePatternSet &patterns);
5362
} // namespace mlir
5463

64+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMAttrInterface.h.inc"
65+
66+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMOpInterface.h.inc"
67+
5568
#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_H
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
2+
//===- ToLLVMInterface.td - Conversion to LLVM interfaces -----*- tablegen -*-===//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Defines interfaces for managing transformations, including populating
10+
// pattern rewrites.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD
15+
#define MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD
16+
17+
include "mlir/IR/OpBase.td"
18+
19+
//===----------------------------------------------------------------------===//
20+
// Attribute interface
21+
//===----------------------------------------------------------------------===//
22+
23+
def ConvertToLLVMAttrInterface :
24+
AttrInterface<"ConvertToLLVMAttrInterface"> {
25+
let description = [{
26+
The `ConvertToLLVMAttrInterface` attribute interfaces allows using
27+
attributes to configure the convert to LLVM infrastructure, this includes:
28+
- The conversion target.
29+
- The LLVM type converter.
30+
- The pattern set.
31+
32+
This interface permits fined grained configuration of the `convert-to-llvm`
33+
process. For example, attributes with target information like
34+
`#nvvm.target` or `#rodcl.target` can leverage this interface for populating
35+
patterns specific to a particular target.
36+
}];
37+
let cppNamespace = "::mlir";
38+
let methods = [
39+
InterfaceMethod<
40+
/*desc=*/[{
41+
Populate the dialect conversion target, type converter and pattern set.
42+
}],
43+
/*retTy=*/"void",
44+
/*methodName=*/"populateConvertToLLVMConversionPatterns",
45+
/*args=*/(ins "::mlir::ConversionTarget&":$target,
46+
"::mlir::LLVMTypeConverter&":$typeConverter,
47+
"::mlir::RewritePatternSet&":$patternSet)>
48+
];
49+
}
50+
51+
//===----------------------------------------------------------------------===//
52+
// Op interface
53+
//===----------------------------------------------------------------------===//
54+
55+
def ConvertToLLVMOpInterface : OpInterface<"ConvertToLLVMOpInterface"> {
56+
let description = [{
57+
Interface for collecting all convert to LLVM attributes stored in an
58+
operation. See `ConvertToLLVMAttrInterface` for more information on these
59+
attributes.
60+
}];
61+
let cppNamespace = "::mlir";
62+
let methods = [
63+
InterfaceMethod<
64+
/*desc=*/[{
65+
Populate the provided vector with a list of convert to LLVM attributes
66+
to apply.
67+
}],
68+
/*retTy=*/"void",
69+
/*methodName=*/"getConvertToLLVMConversionAttrs",
70+
/*args=*/(ins
71+
"::llvm::SmallVectorImpl<::mlir::ConvertToLLVMAttrInterface>&":$attrs)
72+
>
73+
];
74+
}
75+
76+
#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//===- GPUToLLVM.h - Convert GPU to LLVM dialect ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This files declares registration functions for converting GPU to LLVM.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H
14+
#define MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H
15+
16+
namespace mlir {
17+
class DialectRegistry;
18+
namespace gpu {
19+
/// Registers the `ConvertToLLVMOpInterface` interface on the `gpu::GPUModuleOP`
20+
/// operation.
21+
void registerConvertGpuToLLVMInterface(DialectRegistry &registry);
22+
} // namespace gpu
23+
} // namespace mlir
24+
25+
#endif // MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- GPUToNVVM.h - Convert GPU to NVVM dialect ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This files declares registration functions for converting GPU to NVVM.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
14+
#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
15+
16+
namespace mlir {
17+
class DialectRegistry;
18+
namespace NVVM {
19+
/// Registers the `ConvertToLLVMAttrInterface` interface on the
20+
/// `NVVM::NVVMTargetAttr` attribute. This interface populates the conversion
21+
/// target, LLVM type converter, and pattern set for converting GPU operations
22+
/// to NVVM.
23+
void registerConvertGpuToNVVMInterface(DialectRegistry &registry);
24+
} // namespace NVVM
25+
} // namespace mlir
26+
27+
#endif // MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H

mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
3131
/// Configure target to convert from the GPU dialect to NVVM.
3232
void configureGpuToNVVMConversionLegality(ConversionTarget &target);
3333

34+
/// Configure the LLVM type convert to convert types and address spaces from the
35+
/// GPU dialect to NVVM.
36+
void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter);
37+
3438
/// Collect a set of patterns to convert from the GPU dialect to NVVM.
3539
void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter,
3640
RewritePatternSet &patterns);

mlir/include/mlir/InitAllExtensions.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
1919
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
2020
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
21+
#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
22+
#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
2123
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
2224
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
2325
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
@@ -72,6 +74,8 @@ inline void registerAllExtensions(DialectRegistry &registry) {
7274
registerConvertOpenMPToLLVMInterface(registry);
7375
ub::registerConvertUBToLLVMInterface(registry);
7476
registerConvertAMXToLLVMInterface(registry);
77+
gpu::registerConvertGpuToLLVMInterface(registry);
78+
NVVM::registerConvertGpuToNVVMInterface(registry);
7579

7680
// Register all transform dialect extensions.
7781
affine::registerTransformDialectExtension(registry);

mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_mlir_conversion_library(MLIRConvertToLLVMInterface
77
ToLLVMInterface.cpp
88

99
DEPENDS
10+
MLIRConvertToLLVMInterfaceIncGen
1011

1112
LINK_LIBS PUBLIC
1213
MLIRIR

mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,8 @@ class LoadDependentDialectExtension : public DialectExtensionBase {
6363
/// the injection of conversion patterns.
6464
class ConvertToLLVMPass
6565
: public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
66-
std::shared_ptr<const FrozenRewritePatternSet> patterns;
67-
std::shared_ptr<const ConversionTarget> target;
68-
std::shared_ptr<const LLVMTypeConverter> typeConverter;
66+
std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
67+
interfaces;
6968

7069
public:
7170
using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -75,11 +74,8 @@ class ConvertToLLVMPass
7574
}
7675

7776
LogicalResult initialize(MLIRContext *context) final {
78-
RewritePatternSet tempPatterns(context);
79-
auto target = std::make_shared<ConversionTarget>(*context);
80-
target->addLegalDialect<LLVM::LLVMDialect>();
81-
auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
82-
77+
auto interfaces =
78+
std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
8379
if (!filterDialects.empty()) {
8480
// Test mode: Populate only patterns from the specified dialects. Produce
8581
// an error if the dialect is not loaded or does not implement the
@@ -94,8 +90,7 @@ class ConvertToLLVMPass
9490
return emitError(UnknownLoc::get(context))
9591
<< "dialect does not implement ConvertToLLVMPatternInterface: "
9692
<< dialectName << "\n";
97-
iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
98-
tempPatterns);
93+
interfaces->push_back(iface);
9994
}
10095
} else {
10196
// Normal mode: Populate all patterns from all dialects that implement the
@@ -106,20 +101,33 @@ class ConvertToLLVMPass
106101
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
107102
if (!iface)
108103
continue;
109-
iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
110-
tempPatterns);
104+
interfaces->push_back(iface);
111105
}
112106
}
113107

114-
this->patterns =
115-
std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
116-
this->target = target;
117-
this->typeConverter = typeConverter;
108+
this->interfaces = interfaces;
118109
return success();
119110
}
120111

121112
void runOnOperation() final {
122-
if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
113+
MLIRContext *context = &getContext();
114+
RewritePatternSet patterns(context);
115+
ConversionTarget target(*context);
116+
target.addLegalDialect<LLVM::LLVMDialect>();
117+
LLVMTypeConverter typeConverter(context);
118+
119+
// Configure the conversion with dialect level interfaces.
120+
for (ConvertToLLVMPatternInterface *iface : *interfaces)
121+
iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
122+
patterns);
123+
124+
// Configure the conversion attribute interfaces.
125+
populateOpConvertToLLVMConversionPatterns(getOperation(), target,
126+
typeConverter, patterns);
127+
128+
// Apply the conversion.
129+
if (failed(applyPartialConversion(getOperation(), target,
130+
std::move(patterns))))
123131
signalPassFailure();
124132
}
125133
};

0 commit comments

Comments
 (0)