@@ -41,10 +41,69 @@ convertSparseTensorStorageTuple(Type t, SmallVectorImpl<Type> &result) {
41
41
return llvm::None;
42
42
}
43
43
44
+ // / Flatten a list of operands that may contain tuples.
45
+ static void flattenOperands (ValueRange operands,
46
+ SmallVectorImpl<Value> &flattened) {
47
+ // In case of
48
+ // tuple<a, b>, c, tuple<d, e>
49
+ // ==>
50
+ // a, b, c, d, e
51
+ for (auto operand : operands) {
52
+ if (auto cast =
53
+ dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp ());
54
+ cast && cast->getResultTypes ()[0 ].isa <TupleType>())
55
+ // An unrealized_conversion_cast will be inserted by type converter to
56
+ // inter-mix the gap between 1:N conversion between tuple and types.
57
+ // In this case, take the operands in the cast and replace the tuple
58
+ // output with the flattened type array.
59
+ flattened.append (cast.getOperands ().begin (), cast.getOperands ().end ());
60
+ else
61
+ flattened.push_back (operand);
62
+ }
63
+ }
44
64
// ===----------------------------------------------------------------------===//
45
65
// Conversion rules.
46
66
// ===----------------------------------------------------------------------===//
47
67
68
+ // / Sparse tensor storage conversion rule for sparse_tensor::storage_get.
69
+ class SparseStorageGetConverter : public OpConversionPattern <StorageGetOp> {
70
+ public:
71
+ using OpConversionPattern::OpConversionPattern;
72
+ LogicalResult
73
+ matchAndRewrite (StorageGetOp op, OpAdaptor adaptor,
74
+ ConversionPatternRewriter &rewriter) const override {
75
+ auto castOp =
76
+ cast<UnrealizedConversionCastOp>(adaptor.getStorage ().getDefiningOp ());
77
+ uint64_t idx = op.getIdx ().getZExtValue ();
78
+ assert (idx < castOp.getOperands ().size ());
79
+
80
+ rewriter.replaceOp (op, castOp.getOperand (idx));
81
+ return success ();
82
+ }
83
+ };
84
+
85
+ // / Sparse tensor storage conversion rule for sparse_tensor::storage_set.
86
+ class SparseStorageSetConverter : public OpConversionPattern <StorageSetOp> {
87
+ public:
88
+ using OpConversionPattern::OpConversionPattern;
89
+ LogicalResult
90
+ matchAndRewrite (StorageSetOp op, OpAdaptor adaptor,
91
+ ConversionPatternRewriter &rewriter) const override {
92
+ auto castOp =
93
+ cast<UnrealizedConversionCastOp>(adaptor.getStorage ().getDefiningOp ());
94
+ uint64_t idx = op.getIdx ().getZExtValue ();
95
+
96
+ SmallVector<Value, 8 > values (castOp.getOperands ());
97
+ assert (idx < values.size ());
98
+
99
+ // Updates the corresponding element.
100
+ values[idx] = adaptor.getValue ();
101
+ rewriter.replaceOpWithNewOp <UnrealizedConversionCastOp>(
102
+ op, TypeRange{op.getType ()}, values);
103
+ return success ();
104
+ }
105
+ };
106
+
48
107
// / Sparse tensor storage conversion rule for returns.
49
108
class SparseStorageReturnConverter
50
109
: public OpConversionPattern<func::ReturnOp> {
@@ -54,24 +113,69 @@ class SparseStorageReturnConverter
54
113
matchAndRewrite (func::ReturnOp op, OpAdaptor adaptor,
55
114
ConversionPatternRewriter &rewriter) const override {
56
115
SmallVector<Value, 8 > flattened;
57
- for (auto operand : adaptor.getOperands ()) {
58
- if (auto cast =
59
- dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp ());
60
- cast && cast->getResultTypes ()[0 ].isa <TupleType>())
61
- // An unrealized_conversion_cast will be inserted by type converter to
62
- // inter-mix the gap between 1:N conversion between tuple and types.
63
- // In this case, take the operands in the cast and replace the tuple
64
- // output with the flattened type array.
65
- flattened.append (cast.getOperands ().begin (), cast.getOperands ().end ());
66
- else
67
- flattened.push_back (operand);
68
- }
116
+ flattenOperands (adaptor.getOperands (), flattened);
69
117
// Create a return with the flattened value extracted from tuple.
70
118
rewriter.replaceOpWithNewOp <func::ReturnOp>(op, flattened);
71
119
return success ();
72
120
}
73
121
};
74
122
123
+ // / Sparse tensor storage conversion rule for calls.
124
+ class SparseStorageCallConverter : public OpConversionPattern <func::CallOp> {
125
+ public:
126
+ // The default CallOp converter can not handle 1:N type conversion properly
127
+ using OpConversionPattern::OpConversionPattern;
128
+ LogicalResult
129
+ matchAndRewrite (func::CallOp op, OpAdaptor adaptor,
130
+ ConversionPatternRewriter &rewriter) const override {
131
+ Location loc = op.getLoc ();
132
+ // In case of:
133
+ // tuple(a, b), f, tuple(c, d) = call @foo(...)
134
+ // ==>
135
+ // a, b, f, c, d = call @foo(...)
136
+ // cast(a, b)->tuple, f, cast(c,d)->tuple
137
+ SmallVector<Type, 8 > finalRetTy;
138
+ if (failed (typeConverter->convertTypes (op.getResultTypes (), finalRetTy)))
139
+ return failure ();
140
+
141
+ // (1) Genereates new call with flattened return value.
142
+ SmallVector<Value, 8 > flattened;
143
+ flattenOperands (adaptor.getOperands (), flattened);
144
+ auto newCall = rewriter.create <func::CallOp>(loc, op.getCallee (),
145
+ finalRetTy, flattened);
146
+
147
+ // (2) Create cast operation for tuple returns.
148
+ SmallVector<Value, 4 > castedRet;
149
+ // Tracks the offset of current return value (of the orignal call)
150
+ // relative to the new call (after tuple flattening);
151
+ unsigned retOffset = 0 ;
152
+ for (auto ret : op.getResults ()) {
153
+ assert (retOffset < newCall.getNumResults ());
154
+ auto tupleRet = ret.getType ().dyn_cast <TupleType>();
155
+ if (tupleRet) {
156
+ auto tupleSize = tupleRet.size ();
157
+ // NOTE: The range is computed under the assumption of non-recursive
158
+ // tuple type.
159
+ ValueRange tupleElem (iterator_range<ResultRange::iterator>(
160
+ newCall.result_begin () + retOffset,
161
+ newCall.result_begin () + retOffset + tupleSize));
162
+ auto castOp = rewriter.create <UnrealizedConversionCastOp>(
163
+ loc, TypeRange ({tupleRet}), tupleElem);
164
+ castedRet.push_back (castOp.getResult (0 ));
165
+ retOffset += tupleSize;
166
+ } else {
167
+ // If this not a tuple, simply add it into returned values.
168
+ castedRet.push_back (ret);
169
+ retOffset++;
170
+ }
171
+ }
172
+
173
+ assert (castedRet.size () == op.getNumResults ());
174
+ rewriter.replaceOp (op, castedRet);
175
+ return success ();
176
+ }
177
+ };
178
+
75
179
} // namespace
76
180
77
181
// ===----------------------------------------------------------------------===//
@@ -91,6 +195,7 @@ mlir::SparseTensorStorageTupleExpander::SparseTensorStorageTupleExpander() {
91
195
// / to expand compounded sparse tensor tuples.
92
196
void mlir::populateSparseTensorStorageExpansionPatterns (
93
197
TypeConverter &typeConverter, RewritePatternSet &patterns) {
94
- patterns.add <SparseStorageReturnConverter>(typeConverter,
95
- patterns.getContext ());
198
+ patterns.add <SparseStorageGetConverter, SparseStorageSetConverter,
199
+ SparseStorageReturnConverter, SparseStorageCallConverter>(
200
+ typeConverter, patterns.getContext ());
96
201
}
0 commit comments