1313using namespace mlir ;
1414using namespace mlir ::func;
1515
16- // ===----------------------------------------------------------------------===//
17- // ValueDecomposer
18- // ===----------------------------------------------------------------------===//
19-
20- void ValueDecomposer::decomposeValue (OpBuilder &builder, Location loc,
21- Type type, Value value,
22- SmallVectorImpl<Value> &results) {
23- for (auto &conversion : decomposeValueConversions)
24- if (conversion (builder, loc, type, value, results))
25- return ;
26- results.push_back (value);
27- }
28-
29- // ===----------------------------------------------------------------------===//
30- // DecomposeCallGraphTypesOpConversionPattern
31- // ===----------------------------------------------------------------------===//
32-
33- namespace {
34- // / Base OpConversionPattern class to make a ValueDecomposer available to
35- // / inherited patterns.
36- template <typename SourceOp>
37- class DecomposeCallGraphTypesOpConversionPattern
38- : public OpConversionPattern<SourceOp> {
39- public:
40- DecomposeCallGraphTypesOpConversionPattern (const TypeConverter &typeConverter,
41- MLIRContext *context,
42- ValueDecomposer &decomposer,
43- PatternBenefit benefit = 1 )
44- : OpConversionPattern<SourceOp>(typeConverter, context, benefit),
45- decomposer (decomposer) {}
46-
47- protected:
48- ValueDecomposer &decomposer;
49- };
50- } // namespace
51-
5216// ===----------------------------------------------------------------------===//
5317// DecomposeCallGraphTypesForFuncArgs
5418// ===----------------------------------------------------------------------===//
5519
5620namespace {
57- // / Expand function arguments according to the provided TypeConverter and
58- // / ValueDecomposer.
21+ // / Expand function arguments according to the provided TypeConverter.
5922struct DecomposeCallGraphTypesForFuncArgs
60- : public DecomposeCallGraphTypesOpConversionPattern<func::FuncOp> {
61- using DecomposeCallGraphTypesOpConversionPattern::
62- DecomposeCallGraphTypesOpConversionPattern;
23+ : public OpConversionPattern<func::FuncOp> {
24+ using OpConversionPattern::OpConversionPattern;
6325
6426 LogicalResult
6527 matchAndRewrite (func::FuncOp op, OpAdaptor adaptor,
@@ -100,19 +62,17 @@ struct DecomposeCallGraphTypesForFuncArgs
10062// ===----------------------------------------------------------------------===//
10163
10264namespace {
103- // / Expand return operands according to the provided TypeConverter and
104- // / ValueDecomposer.
65+ // / Expand return operands according to the provided TypeConverter.
10566struct DecomposeCallGraphTypesForReturnOp
106- : public DecomposeCallGraphTypesOpConversionPattern <ReturnOp> {
107- using DecomposeCallGraphTypesOpConversionPattern::
108- DecomposeCallGraphTypesOpConversionPattern;
67+ : public OpConversionPattern <ReturnOp> {
68+ using OpConversionPattern::OpConversionPattern;
69+
10970 LogicalResult
110- matchAndRewrite (ReturnOp op, OpAdaptor adaptor,
71+ matchAndRewrite (ReturnOp op, OneToNOpAdaptor adaptor,
11172 ConversionPatternRewriter &rewriter) const final {
11273 SmallVector<Value, 2 > newOperands;
113- for (Value operand : adaptor.getOperands ())
114- decomposer.decomposeValue (rewriter, op.getLoc (), operand.getType (),
115- operand, newOperands);
74+ for (ValueRange operand : adaptor.getOperands ())
75+ llvm::append_range (newOperands, operand);
11676 rewriter.replaceOpWithNewOp <ReturnOp>(op, newOperands);
11777 return success ();
11878 }
@@ -124,74 +84,53 @@ struct DecomposeCallGraphTypesForReturnOp
12484// ===----------------------------------------------------------------------===//
12585
12686namespace {
127- // / Expand call op operands and results according to the provided TypeConverter
128- // / and ValueDecomposer.
87+ // / Expand call op operands and results according to the provided TypeConverter.
12988struct DecomposeCallGraphTypesForCallOp
130- : public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
131- using DecomposeCallGraphTypesOpConversionPattern::
132- DecomposeCallGraphTypesOpConversionPattern;
89+ : public OpConversionPattern<CallOp> {
90+ using OpConversionPattern::OpConversionPattern;
13391
13492 LogicalResult
135- matchAndRewrite (CallOp op, OpAdaptor adaptor,
93+ matchAndRewrite (CallOp op, OneToNOpAdaptor adaptor,
13694 ConversionPatternRewriter &rewriter) const final {
13795
13896 // Create the operands list of the new `CallOp`.
13997 SmallVector<Value, 2 > newOperands;
140- for (Value operand : adaptor.getOperands ())
141- decomposer.decomposeValue (rewriter, op.getLoc (), operand.getType (),
142- operand, newOperands);
143-
144- // Create the new result types for the new `CallOp` and track the indices in
145- // the new call op's results that correspond to the old call op's results.
146- //
147- // expandedResultIndices[i] = "list of new result indices that old result i
148- // expanded to".
98+ for (ValueRange operand : adaptor.getOperands ())
99+ llvm::append_range (newOperands, operand);
100+
101+ // Create the new result types for the new `CallOp` and track the number of
102+ // replacement types for each original op result.
149103 SmallVector<Type, 2 > newResultTypes;
150- SmallVector<SmallVector< unsigned , 2 >, 4 > expandedResultIndices ;
104+ SmallVector<unsigned > expandedResultSizes ;
151105 for (Type resultType : op.getResultTypes ()) {
152106 unsigned oldSize = newResultTypes.size ();
153107 if (failed (typeConverter->convertType (resultType, newResultTypes)))
154108 return failure ();
155- auto &resultMapping = expandedResultIndices.emplace_back ();
156- for (unsigned i = oldSize, e = newResultTypes.size (); i < e; i++)
157- resultMapping.push_back (i);
109+ expandedResultSizes.push_back (newResultTypes.size () - oldSize);
158110 }
159111
160112 CallOp newCallOp = rewriter.create <CallOp>(op.getLoc (), op.getCalleeAttr (),
161113 newResultTypes, newOperands);
162114
163- // Build a replacement value for each result to replace its uses. If a
164- // result has multiple mapping values, it needs to be materialized as a
165- // single value.
166- SmallVector<Value, 2 > replacedValues;
115+ // Build a replacement value for each result to replace its uses.
116+ SmallVector<ValueRange> replacedValues;
167117 replacedValues.reserve (op.getNumResults ());
118+ unsigned startIdx = 0 ;
168119 for (unsigned i = 0 , e = op.getNumResults (); i < e; ++i) {
169- auto decomposedValues = llvm::to_vector<6 >(
170- llvm::map_range (expandedResultIndices[i],
171- [&](unsigned i) { return newCallOp.getResult (i); }));
172- if (decomposedValues.empty ()) {
173- // No replacement is required.
174- replacedValues.push_back (nullptr );
175- } else if (decomposedValues.size () == 1 ) {
176- replacedValues.push_back (decomposedValues.front ());
177- } else {
178- // Materialize a single Value to replace the original Value.
179- Value materialized = getTypeConverter ()->materializeArgumentConversion (
180- rewriter, op.getLoc (), op.getType (i), decomposedValues);
181- replacedValues.push_back (materialized);
182- }
120+ ValueRange repl = newCallOp.getResults ().slice (startIdx, expandedResultSizes[i]);
121+ replacedValues.push_back (repl);
122+ startIdx += expandedResultSizes[i];
183123 }
184- rewriter.replaceOp (op, replacedValues);
124+ rewriter.replaceOpWithMultiple (op, replacedValues);
185125 return success ();
186126 }
187127};
188128} // namespace
189129
190130void mlir::populateDecomposeCallGraphTypesPatterns (
191131 MLIRContext *context, const TypeConverter &typeConverter,
192- ValueDecomposer &decomposer, RewritePatternSet &patterns) {
132+ RewritePatternSet &patterns) {
193133 patterns
194134 .add <DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
195- DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
196- decomposer);
135+ DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
197136}
0 commit comments