Skip to content

Commit ece4a8f

Browse files
committed
[mlir] Add the TransformsInterfaces for configuring transformations
This patch adds the `ConversionPatternsAttrInterface` and `OpWithTransformAttrsOpInterface` interfaces. It also modifies the `convert-to-llvm` pass to use these interfaces when available. The `ConversionPatternsAttrInterface` allows attributes to configure the dialect conversion infrastructure, including the conversion target, type converter, and populating conversion patterns. See the `NVVMTargetAttr` implementation of this interface for an example of how this interface can be used to configure dialect conversion. The `OpWithTransformAttrsOpInterface` allows interacting with transforms attributes. These attributes allow configuring transformations like dialect conversion with information present in the IR. Finally, the `convert-to-llvm` pass was modified to use these interfaces when available. This allows applying `convert-to-llvm` to GPU modules and letting the `NVVMTargetAttr` decide which patterns to populate.
1 parent ded35c0 commit ece4a8f

File tree

18 files changed

+460
-41
lines changed

18 files changed

+460
-41
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- GPUToNVVM.h - Convert GPU to NVVM dialect ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This files declares registration functions for converting GPU to NVVM.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
14+
#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H
15+
16+
namespace mlir {
17+
class DialectRegistry;
18+
namespace NVVM {
19+
/// Registers the `ConversionPatternsAttrInterface` interface on the
20+
/// `NVVM::NVVMTargetAttr`. This interface populates the conversion target,
21+
/// LLVM type converter, and pattern set for converting GPU operations to NVVM.
22+
void registerConvertGpuToNVVMAttrInterface(DialectRegistry &registry);
23+
} // namespace NVVM
24+
} // namespace mlir
25+
26+
#endif // MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H

mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
3131
/// Configure target to convert from the GPU dialect to NVVM.
3232
void configureGpuToNVVMConversionLegality(ConversionTarget &target);
3333

34+
/// Configure the LLVM type convert to convert types and address spaces from the
35+
/// GPU dialect to NVVM.
36+
void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter);
37+
3438
/// Collect a set of patterns to convert from the GPU dialect to NVVM.
3539
void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
3640
RewritePatternSet &patterns);
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//===- ConversionAttrOptions.h - LLVM conversion options --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares convert to LLVM options for `ConversionPatternAttr`.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_LLVMCOMMON_CONVERSIONATTROPTIONS_H
14+
#define MLIR_CONVERSION_LLVMCOMMON_CONVERSIONATTROPTIONS_H
15+
16+
#include "mlir/Interfaces/TransformsInterfaces.h"
17+
18+
namespace mlir {
19+
class LLVMTypeConverter;
20+
21+
/// Class for passing convert to LLVM options to `ConversionPatternAttr`
22+
/// attributes.
23+
class LLVMConversionPatternAttrOptions : public ConversionPatternAttrOptions {
24+
public:
25+
LLVMConversionPatternAttrOptions(ConversionTarget &target,
26+
LLVMTypeConverter &converter);
27+
28+
static bool classof(ConversionPatternAttrOptions const *opts) {
29+
return opts->getTypeID() == TypeID::get<LLVMConversionPatternAttrOptions>();
30+
}
31+
32+
/// Get the LLVM type converter.
33+
LLVMTypeConverter &getLLVMTypeConverter();
34+
};
35+
} // namespace mlir
36+
37+
MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::LLVMConversionPatternAttrOptions)
38+
39+
#endif // MLIR_CONVERSION_LLVMCOMMON_CONVERSIONATTROPTIONS_H

mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "mlir/Interfaces/InferIntRangeInterface.h"
2929
#include "mlir/Interfaces/InferTypeOpInterface.h"
3030
#include "mlir/Interfaces/SideEffectInterfaces.h"
31+
#include "mlir/Interfaces/TransformsInterfaces.h"
3132
#include "llvm/ADT/STLExtras.h"
3233

3334
namespace mlir {

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ include "mlir/Interfaces/FunctionInterfaces.td"
2929
include "mlir/Interfaces/InferIntRangeInterface.td"
3030
include "mlir/Interfaces/InferTypeOpInterface.td"
3131
include "mlir/Interfaces/SideEffectInterfaces.td"
32+
include "mlir/Interfaces/TransformsInterfaces.td"
3233

3334
//===----------------------------------------------------------------------===//
3435
// GPU Dialect operations.
@@ -1347,7 +1348,8 @@ def GPU_BarrierOp : GPU_Op<"barrier"> {
13471348

13481349
def GPU_GPUModuleOp : GPU_Op<"module", [
13491350
DataLayoutOpInterface, HasDefaultDLTIDataLayout, IsolatedFromAbove,
1350-
SymbolTable, Symbol, SingleBlockImplicitTerminator<"ModuleEndOp">
1351+
DeclareOpInterfaceMethods<OpWithTransformAttrsOpInterface>, SymbolTable,
1352+
Symbol, SingleBlockImplicitTerminator<"ModuleEndOp">
13511353
]>, Arguments<(ins SymbolNameAttr:$sym_name,
13521354
OptionalAttr<GPUNonEmptyTargetArrayAttr>:$targets,
13531355
OptionalAttr<OffloadingTranslationAttr>:$offloadingHandler)> {

mlir/include/mlir/InitAllExtensions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
1919
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
2020
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
21+
#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
2122
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
2223
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
2324
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
@@ -65,6 +66,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
6566
registerConvertMemRefToLLVMInterface(registry);
6667
registerConvertNVVMToLLVMInterface(registry);
6768
ub::registerConvertUBToLLVMInterface(registry);
69+
NVVM::registerConvertGpuToNVVMAttrInterface(registry);
6870

6971
// Register all transform dialect extensions.
7072
affine::registerTransformDialectExtension(registry);

mlir/include/mlir/Interfaces/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@ mlir_tablegen(DataLayoutTypeInterface.cpp.inc -gen-type-interface-defs)
3636
add_public_tablegen_target(MLIRDataLayoutInterfacesIncGen)
3737
add_dependencies(mlir-generic-headers MLIRDataLayoutInterfacesIncGen)
3838

39+
set(LLVM_TARGET_DEFINITIONS TransformsInterfaces.td)
40+
mlir_tablegen(TransformsAttrInterfaces.h.inc -gen-attr-interface-decls)
41+
mlir_tablegen(TransformsAttrInterfaces.cpp.inc -gen-attr-interface-defs)
42+
mlir_tablegen(TransformsOpInterfaces.h.inc -gen-op-interface-decls)
43+
mlir_tablegen(TransformsOpInterfaces.cpp.inc -gen-op-interface-defs)
44+
add_public_tablegen_target(MLIRTransformsInterfacesIncGen)
45+
add_dependencies(mlir-generic-headers MLIRTransformsInterfacesIncGen)
46+
3947
add_mlir_doc(DataLayoutInterfaces
4048
DataLayoutAttrInterface
4149
Interfaces/
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
//===- TransformsInterfaces.h - Transforms interfaces -----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares interfaces for managing transformations, including
10+
// populating pattern rewrites.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_INTERFACES_TRANSFORMSINTERFACES_H
15+
#define MLIR_INTERFACES_TRANSFORMSINTERFACES_H
16+
17+
#include "mlir/IR/OpDefinition.h"
18+
19+
namespace mlir {
20+
class ConversionTarget;
21+
class RewritePatternSet;
22+
class TypeConverter;
23+
24+
/// This class serves as an opaque interface for passing options to the
25+
/// `ConversionPatternsAttrInterface` methods. Users of this class must
26+
/// implement the `classof` method as well as using the macros
27+
/// `MLIR_*_EXPLICIT_TYPE_ID` toensure type safeness.
28+
class ConversionPatternAttrOptions {
29+
public:
30+
ConversionPatternAttrOptions(ConversionTarget &target,
31+
TypeConverter &converter);
32+
33+
/// Returns the typeID.
34+
TypeID getTypeID() const { return typeID; }
35+
36+
/// Returns a reference to the conversion target to configure.
37+
ConversionTarget &getConversionTarget() { return target; }
38+
39+
/// Returns a reference to the type converter to configure.
40+
TypeConverter &getTypeConverter() { return converter; }
41+
42+
protected:
43+
/// Derived classes must use this constructor to initialize `typeID` to the
44+
/// appropiate value.
45+
ConversionPatternAttrOptions(TypeID typeID, ConversionTarget &target,
46+
TypeConverter &converter);
47+
// Conversion target.
48+
ConversionTarget &target;
49+
// Type converter.
50+
TypeConverter &converter;
51+
52+
private:
53+
TypeID typeID;
54+
};
55+
56+
/// Helper function for populating dialect conversion patterns. If `op`
57+
/// implements the `OpWithTransformAttrsOpInterface` interface, then the
58+
/// conversion pattern attributes provided by the interface will be used to
59+
/// configure the conversion target, type converter, and the pattern set.
60+
void populateOpConversionPatterns(Operation *op,
61+
ConversionPatternAttrOptions &options,
62+
RewritePatternSet &patterns);
63+
} // namespace mlir
64+
65+
#include "mlir/Interfaces/TransformsAttrInterfaces.h.inc"
66+
67+
#include "mlir/Interfaces/TransformsOpInterfaces.h.inc"
68+
69+
MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::ConversionPatternAttrOptions)
70+
71+
#endif // MLIR_INTERFACES_TRANSFORMSINTERFACES_H
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
//===- TransformsInterfaces.td - Transforms interfaces -----*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Defines interfaces for managing transformations, including populating
10+
// pattern rewrites.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_INTERFACES_TRANSFORMSINTERFACES_TD
15+
#define MLIR_INTERFACES_TRANSFORMSINTERFACES_TD
16+
17+
include "mlir/IR/OpBase.td"
18+
19+
//===----------------------------------------------------------------------===//
20+
// Conversion patterns attribute interface
21+
//===----------------------------------------------------------------------===//
22+
23+
def ConversionPatternsAttrInterface :
24+
AttrInterface<"ConversionPatternsAttrInterface"> {
25+
let description = [{
26+
This interfaces allows using attributes to configure the dialect conversion
27+
infrastructure, this includes:
28+
- The conversion target.
29+
- The type converter.
30+
- The pattern set.
31+
32+
The conversion target and type converter are passed through the
33+
`ConversionPatternAttrOptions` class. Passing them through this class
34+
and by reference allows sub-classing the base option class, allowing
35+
specializations like `LLVMConversionPatternAttrOptions` for converting to
36+
LLVM.
37+
}];
38+
let cppNamespace = "::mlir";
39+
let methods = [
40+
InterfaceMethod<
41+
/*desc=*/[{
42+
Populate the dialect conversion target, type converter and pattern set.
43+
}],
44+
/*retTy=*/"void",
45+
/*methodName=*/"populateConversionPatterns",
46+
/*args=*/(ins "::mlir::ConversionPatternAttrOptions&":$options,
47+
"::mlir::RewritePatternSet&":$patternSet)>
48+
];
49+
}
50+
51+
//===----------------------------------------------------------------------===//
52+
// Operation with patterns interface
53+
//===----------------------------------------------------------------------===//
54+
55+
def OpWithTransformAttrsOpInterface :
56+
OpInterface<"OpWithTransformAttrsOpInterface"> {
57+
let description = [{
58+
Interface for interacting with transforms attributes. These attributes
59+
allow configuring transformations like dialect conversion with information
60+
present in the IR.
61+
}];
62+
let cppNamespace = "::mlir";
63+
let methods = [
64+
InterfaceMethod<
65+
/*desc=*/[{
66+
Populate the provided vector with a list of conversion pattern
67+
attributes to apply.
68+
}],
69+
/*retTy=*/"void",
70+
/*methodName=*/"getConversionPatternAttrs",
71+
/*args=*/(ins
72+
"::llvm::SmallVectorImpl<::mlir::ConversionPatternsAttrInterface>&":$attrs)
73+
>
74+
];
75+
}
76+
77+
#endif // MLIR_INTERFACES_TRANSFORMSINTERFACES_TD

mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88

99
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1010
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
11+
#include "mlir/Conversion/LLVMCommon/ConversionAttrOptions.h"
1112
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1213
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1314
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1415
#include "mlir/IR/PatternMatch.h"
16+
#include "mlir/Interfaces/TransformsInterfaces.h"
1517
#include "mlir/Pass/Pass.h"
1618
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
1719
#include "mlir/Transforms/DialectConversion.h"
@@ -61,9 +63,8 @@ class LoadDependentDialectExtension : public DialectExtensionBase {
6163
/// the injection of conversion patterns.
6264
class ConvertToLLVMPass
6365
: public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
64-
std::shared_ptr<const FrozenRewritePatternSet> patterns;
65-
std::shared_ptr<const ConversionTarget> target;
66-
std::shared_ptr<const LLVMTypeConverter> typeConverter;
66+
std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
67+
interfaces;
6768

6869
public:
6970
using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -73,11 +74,8 @@ class ConvertToLLVMPass
7374
}
7475

7576
LogicalResult initialize(MLIRContext *context) final {
76-
RewritePatternSet tempPatterns(context);
77-
auto target = std::make_shared<ConversionTarget>(*context);
78-
target->addLegalDialect<LLVM::LLVMDialect>();
79-
auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
80-
77+
auto interfaces =
78+
std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
8179
if (!filterDialects.empty()) {
8280
// Test mode: Populate only patterns from the specified dialects. Produce
8381
// an error if the dialect is not loaded or does not implement the
@@ -92,8 +90,7 @@ class ConvertToLLVMPass
9290
return emitError(UnknownLoc::get(context))
9391
<< "dialect does not implement ConvertToLLVMPatternInterface: "
9492
<< dialectName << "\n";
95-
iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
96-
tempPatterns);
93+
interfaces->push_back(iface);
9794
}
9895
} else {
9996
// Normal mode: Populate all patterns from all dialects that implement the
@@ -104,20 +101,33 @@ class ConvertToLLVMPass
104101
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
105102
if (!iface)
106103
continue;
107-
iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
108-
tempPatterns);
104+
interfaces->push_back(iface);
109105
}
110106
}
111107

112-
this->patterns =
113-
std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
114-
this->target = target;
115-
this->typeConverter = typeConverter;
108+
this->interfaces = interfaces;
116109
return success();
117110
}
118111

119112
void runOnOperation() final {
120-
if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
113+
MLIRContext *context = &getContext();
114+
RewritePatternSet patterns(context);
115+
ConversionTarget target(*context);
116+
target.addLegalDialect<LLVM::LLVMDialect>();
117+
LLVMTypeConverter typeConverter(context);
118+
119+
// Configure the conversion with dialect level interfaces.
120+
for (ConvertToLLVMPatternInterface *iface : *interfaces)
121+
iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
122+
patterns);
123+
124+
// Configure the conversion attribute interfaces.
125+
LLVMConversionPatternAttrOptions opts(target, typeConverter);
126+
populateOpConversionPatterns(getOperation(), opts, patterns);
127+
128+
// Apply the conversion.
129+
if (failed(applyPartialConversion(getOperation(), target,
130+
std::move(patterns))))
121131
signalPassFailure();
122132
}
123133
};

0 commit comments

Comments
 (0)