@@ -14,52 +14,46 @@ using namespace mlir;
1414using namespace mlir ::func;
1515
1616// ===----------------------------------------------------------------------===//
17- // ValueDecomposer
17+ // Helper functions
1818// ===----------------------------------------------------------------------===//
1919
20- void ValueDecomposer::decomposeValue (OpBuilder &builder, Location loc,
21- Type type, Value value,
22- SmallVectorImpl<Value> &results) {
23- for (auto &conversion : decomposeValueConversions)
24- if (conversion (builder, loc, type, value, results))
25- return ;
26- results.push_back (value);
20+ // / If the given value can be decomposed with the type converter, decompose it.
21+ // / Otherwise, return the given value.
22+ static SmallVector<Value> decomposeValue (OpBuilder &builder, Location loc,
23+ Value value,
24+ const TypeConverter *converter) {
25+ // Try to convert the given value's type. If that fails, just return the
26+ // given value.
27+ SmallVector<Type> convertedTypes;
28+ if (failed (converter->convertType (value.getType (), convertedTypes)))
29+ return {value};
30+ if (convertedTypes.empty ())
31+ return {};
32+
33+ // If the given value's type is already legal, just return the given value.
34+ TypeRange convertedTypeRange (convertedTypes);
35+ if (convertedTypeRange == TypeRange (value.getType ()))
36+ return {value};
37+
38+ // Try to materialize a target conversion. If the materialization did not
39+ // produce values of the requested type, the materialization failed. Just
40+ // return the given value in that case.
41+ SmallVector<Value> result = converter->materializeTargetConversion (
42+ builder, loc, convertedTypeRange, value);
43+ if (result.empty ())
44+ return {value};
45+ return result;
2746}
2847
29- // ===----------------------------------------------------------------------===//
30- // DecomposeCallGraphTypesOpConversionPattern
31- // ===----------------------------------------------------------------------===//
32-
33- namespace {
34- // / Base OpConversionPattern class to make a ValueDecomposer available to
35- // / inherited patterns.
36- template <typename SourceOp>
37- class DecomposeCallGraphTypesOpConversionPattern
38- : public OpConversionPattern<SourceOp> {
39- public:
40- DecomposeCallGraphTypesOpConversionPattern (const TypeConverter &typeConverter,
41- MLIRContext *context,
42- ValueDecomposer &decomposer,
43- PatternBenefit benefit = 1 )
44- : OpConversionPattern<SourceOp>(typeConverter, context, benefit),
45- decomposer (decomposer) {}
46-
47- protected:
48- ValueDecomposer &decomposer;
49- };
50- } // namespace
51-
5248// ===----------------------------------------------------------------------===//
5349// DecomposeCallGraphTypesForFuncArgs
5450// ===----------------------------------------------------------------------===//
5551
5652namespace {
57- // / Expand function arguments according to the provided TypeConverter and
58- // / ValueDecomposer.
53+ // / Expand function arguments according to the provided TypeConverter.
5954struct DecomposeCallGraphTypesForFuncArgs
60- : public DecomposeCallGraphTypesOpConversionPattern<func::FuncOp> {
61- using DecomposeCallGraphTypesOpConversionPattern::
62- DecomposeCallGraphTypesOpConversionPattern;
55+ : public OpConversionPattern<func::FuncOp> {
56+ using OpConversionPattern::OpConversionPattern;
6357
6458 LogicalResult
6559 matchAndRewrite (func::FuncOp op, OpAdaptor adaptor,
@@ -100,19 +94,22 @@ struct DecomposeCallGraphTypesForFuncArgs
10094// ===----------------------------------------------------------------------===//
10195
10296namespace {
103- // / Expand return operands according to the provided TypeConverter and
104- // / ValueDecomposer.
97+ // / Expand return operands according to the provided TypeConverter.
10598struct DecomposeCallGraphTypesForReturnOp
106- : public DecomposeCallGraphTypesOpConversionPattern <ReturnOp> {
107- using DecomposeCallGraphTypesOpConversionPattern::
108- DecomposeCallGraphTypesOpConversionPattern;
99+ : public OpConversionPattern <ReturnOp> {
100+ using OpConversionPattern::OpConversionPattern;
101+
109102 LogicalResult
110103 matchAndRewrite (ReturnOp op, OpAdaptor adaptor,
111104 ConversionPatternRewriter &rewriter) const final {
112105 SmallVector<Value, 2 > newOperands;
113- for (Value operand : adaptor.getOperands ())
114- decomposer.decomposeValue (rewriter, op.getLoc (), operand.getType (),
115- operand, newOperands);
106+ for (Value operand : adaptor.getOperands ()) {
107+ // TODO: We can directly take the values from the adaptor once this is a
108+ // 1:N conversion pattern.
109+ llvm::append_range (newOperands,
110+ decomposeValue (rewriter, operand.getLoc (), operand,
111+ getTypeConverter ()));
112+ }
116113 rewriter.replaceOpWithNewOp <ReturnOp>(op, newOperands);
117114 return success ();
118115 }
@@ -124,22 +121,23 @@ struct DecomposeCallGraphTypesForReturnOp
124121// ===----------------------------------------------------------------------===//
125122
126123namespace {
127- // / Expand call op operands and results according to the provided TypeConverter
128- // / and ValueDecomposer.
129- struct DecomposeCallGraphTypesForCallOp
130- : public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
131- using DecomposeCallGraphTypesOpConversionPattern::
132- DecomposeCallGraphTypesOpConversionPattern;
124+ // / Expand call op operands and results according to the provided TypeConverter.
125+ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern <CallOp> {
126+ using OpConversionPattern::OpConversionPattern;
133127
134128 LogicalResult
135129 matchAndRewrite (CallOp op, OpAdaptor adaptor,
136130 ConversionPatternRewriter &rewriter) const final {
137131
138132 // Create the operands list of the new `CallOp`.
139133 SmallVector<Value, 2 > newOperands;
140- for (Value operand : adaptor.getOperands ())
141- decomposer.decomposeValue (rewriter, op.getLoc (), operand.getType (),
142- operand, newOperands);
134+ for (Value operand : adaptor.getOperands ()) {
135+ // TODO: We can directly take the values from the adaptor once this is a
136+ // 1:N conversion pattern.
137+ llvm::append_range (newOperands,
138+ decomposeValue (rewriter, operand.getLoc (), operand,
139+ getTypeConverter ()));
140+ }
143141
144142 // Create the new result types for the new `CallOp` and track the indices in
145143 // the new call op's results that correspond to the old call op's results.
@@ -189,9 +187,8 @@ struct DecomposeCallGraphTypesForCallOp
189187
190188void mlir::populateDecomposeCallGraphTypesPatterns (
191189 MLIRContext *context, const TypeConverter &typeConverter,
192- ValueDecomposer &decomposer, RewritePatternSet &patterns) {
190+ RewritePatternSet &patterns) {
193191 patterns
194192 .add <DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
195- DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
196- decomposer);
193+ DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
197194}
0 commit comments