@@ -16,20 +16,16 @@ using namespace mlir::scf;
1616
1717namespace {
1818
19- // Unpacks the single unrealized_conversion_cast using the list of inputs
20- // e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d)
21- static void unpackUnrealizedConversionCast (Value v,
22- SmallVectorImpl<Value> &unpacked) {
23- if (auto cast =
24- dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp ())) {
25- if (cast.getInputs ().size () != 1 ) {
26- // 1 : N type conversion.
27- unpacked.append (cast.getInputs ().begin (), cast.getInputs ().end ());
28- return ;
29- }
30- }
31- // 1 : 1 type conversion.
32- unpacked.push_back (v);
19+ static SmallVector<Value> flattenValues (ArrayRef<ArrayRef<Value>> values) {
20+ SmallVector<Value> result;
21+ for (ArrayRef<Value> v : values)
22+ llvm::append_range (result, v);
23+ return result;
24+ }
25+
26+ static Value getSingleValue (ArrayRef<Value> values) {
27+ assert (values.size () == 1 && " expected single value" );
28+ return values.front ();
3329}
3430
3531// CRTP
@@ -40,19 +36,21 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
4036public:
4137 using OpConversionPattern<SourceOp>::typeConverter;
4238 using OpConversionPattern<SourceOp>::OpConversionPattern;
43- using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
39+ using OneToNOpAdaptor =
40+ typename OpConversionPattern<SourceOp>::OneToNOpAdaptor;
4441
4542 //
4643 // Derived classes should provide the following method which performs the
4744 // actual conversion. It should return std::nullopt upon conversion failure
4845 // and return the converted operation upon success.
4946 //
50- // std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
51- // ConversionPatternRewriter &rewriter,
52- // TypeRange dstTypes) const;
47+ // std::optional<SourceOp> convertSourceOp(
48+ // SourceOp op, OneToNOpAdaptor adaptor,
49+ // ConversionPatternRewriter &rewriter,
50+ // TypeRange dstTypes) const;
5351
5452 LogicalResult
55- matchAndRewrite (SourceOp op, OpAdaptor adaptor,
53+ matchAndRewrite (SourceOp op, OneToNOpAdaptor adaptor,
5654 ConversionPatternRewriter &rewriter) const override {
5755 SmallVector<Type> dstTypes;
5856 SmallVector<unsigned > offsets;
@@ -73,28 +71,15 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
7371 return rewriter.notifyMatchFailure (op, " could not convert operation" );
7472
7573 // Packs the return value.
76- SmallVector<Value > packedRets;
74+ SmallVector<ValueRange > packedRets;
7775 for (unsigned i = 1 , e = offsets.size (); i < e; i++) {
7876 unsigned start = offsets[i - 1 ], end = offsets[i];
7977 unsigned len = end - start;
8078 ValueRange mappedValue = newOp->getResults ().slice (start, len);
81- if (len != 1 ) {
82- // 1 : N type conversion.
83- Type origType = op.getResultTypes ()[i - 1 ];
84- Value mat = typeConverter->materializeSourceConversion (
85- rewriter, op.getLoc (), origType, mappedValue);
86- if (!mat) {
87- return rewriter.notifyMatchFailure (
88- op, " Failed to materialize 1:N type conversion" );
89- }
90- packedRets.push_back (mat);
91- } else {
92- // 1 : 1 type conversion.
93- packedRets.push_back (mappedValue.front ());
94- }
79+ packedRets.push_back (mappedValue);
9580 }
9681
97- rewriter.replaceOp (op, packedRets);
82+ rewriter.replaceOpWithMultiple (op, packedRets);
9883 return success ();
9984 }
10085};
@@ -105,7 +90,7 @@ class ConvertForOpTypes
10590 using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
10691
10792 // The callback required by CRTP.
108- std::optional<ForOp> convertSourceOp (ForOp op, OpAdaptor adaptor,
93+ std::optional<ForOp> convertSourceOp (ForOp op, OneToNOpAdaptor adaptor,
10994 ConversionPatternRewriter &rewriter,
11095 TypeRange dstTypes) const {
11196 // Create a empty new op and inline the regions from the old op.
@@ -129,16 +114,13 @@ class ConvertForOpTypes
129114 if (failed (rewriter.convertRegionTypes (&op.getRegion (), *typeConverter)))
130115 return std::nullopt ;
131116
132- // Unpacked the iteration arguments.
133- SmallVector<Value> flatArgs;
134- for (Value arg : adaptor.getInitArgs ())
135- unpackUnrealizedConversionCast (arg, flatArgs);
136-
137117 // We can not do clone as the number of result types after conversion
138118 // might be different.
139- ForOp newOp = rewriter.create <ForOp>(op.getLoc (), adaptor.getLowerBound (),
140- adaptor.getUpperBound (),
141- adaptor.getStep (), flatArgs);
119+ ForOp newOp = rewriter.create <ForOp>(
120+ op.getLoc (), getSingleValue (adaptor.getLowerBound ()),
121+ getSingleValue (adaptor.getUpperBound ()),
122+ getSingleValue (adaptor.getStep ()),
123+ flattenValues (adaptor.getInitArgs ()));
142124
143125 // Reserve whatever attributes in the original op.
144126 newOp->setAttrs (op->getAttrs ());
@@ -160,12 +142,12 @@ class ConvertIfOpTypes
160142public:
161143 using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
162144
163- std::optional<IfOp> convertSourceOp (IfOp op, OpAdaptor adaptor,
145+ std::optional<IfOp> convertSourceOp (IfOp op, OneToNOpAdaptor adaptor,
164146 ConversionPatternRewriter &rewriter,
165147 TypeRange dstTypes) const {
166148
167- IfOp newOp = rewriter.create <IfOp>(op. getLoc (), dstTypes,
168- adaptor.getCondition (), true );
149+ IfOp newOp = rewriter.create <IfOp>(
150+ op. getLoc (), dstTypes, getSingleValue ( adaptor.getCondition () ), true );
169151 newOp->setAttrs (op->getAttrs ());
170152
171153 // We do not need the empty blocks created by rewriter.
@@ -189,15 +171,11 @@ class ConvertWhileOpTypes
189171public:
190172 using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
191173
192- std::optional<WhileOp> convertSourceOp (WhileOp op, OpAdaptor adaptor,
174+ std::optional<WhileOp> convertSourceOp (WhileOp op, OneToNOpAdaptor adaptor,
193175 ConversionPatternRewriter &rewriter,
194176 TypeRange dstTypes) const {
195- // Unpacked the iteration arguments.
196- SmallVector<Value> flatArgs;
197- for (Value arg : adaptor.getOperands ())
198- unpackUnrealizedConversionCast (arg, flatArgs);
199-
200- auto newOp = rewriter.create <WhileOp>(op.getLoc (), dstTypes, flatArgs);
177+ auto newOp = rewriter.create <WhileOp>(op.getLoc (), dstTypes,
178+ flattenValues (adaptor.getOperands ()));
201179
202180 for (auto i : {0u , 1u }) {
203181 if (failed (rewriter.convertRegionTypes (&op.getRegion (i), *typeConverter)))
@@ -218,13 +196,10 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
218196public:
219197 using OpConversionPattern::OpConversionPattern;
220198 LogicalResult
221- matchAndRewrite (scf::YieldOp op, OpAdaptor adaptor,
199+ matchAndRewrite (scf::YieldOp op, OneToNOpAdaptor adaptor,
222200 ConversionPatternRewriter &rewriter) const override {
223- SmallVector<Value> unpackedYield;
224- for (Value operand : adaptor.getOperands ())
225- unpackUnrealizedConversionCast (operand, unpackedYield);
226-
227- rewriter.replaceOpWithNewOp <scf::YieldOp>(op, unpackedYield);
201+ rewriter.replaceOpWithNewOp <scf::YieldOp>(
202+ op, flattenValues (adaptor.getOperands ()));
228203 return success ();
229204 }
230205};
@@ -235,13 +210,10 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
235210public:
236211 using OpConversionPattern<ConditionOp>::OpConversionPattern;
237212 LogicalResult
238- matchAndRewrite (ConditionOp op, OpAdaptor adaptor,
213+ matchAndRewrite (ConditionOp op, OneToNOpAdaptor adaptor,
239214 ConversionPatternRewriter &rewriter) const override {
240- SmallVector<Value> unpackedYield;
241- for (Value operand : adaptor.getOperands ())
242- unpackUnrealizedConversionCast (operand, unpackedYield);
243-
244- rewriter.modifyOpInPlace (op, [&]() { op->setOperands (unpackedYield); });
215+ rewriter.modifyOpInPlace (
216+ op, [&]() { op->setOperands (flattenValues (adaptor.getOperands ())); });
245217 return success ();
246218 }
247219};
0 commit comments