@@ -34,6 +34,16 @@ namespace {
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
+/// Reorders stored dimension to original dimension.
+static unsigned toOrig(const SparseTensorEncodingAttr &enc, unsigned i) {
+  auto order = enc.getDimOrdering();
+  if (order) {
+    assert(order.isPermutation());
+    return order.getDimPosition(i);
+  }
+  return i;
+}
+
 /// Reorders original dimension to stored dimension.
 static unsigned toStored(const SparseTensorEncodingAttr &enc, unsigned i) {
   auto order = enc.getDimOrdering();
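Note: for intuition about the reordering helpers above, here is a minimal standalone sketch in plain C++ (no MLIR dependencies; the order array is a hypothetical stand-in for the encoding's dimOrdering permutation map). With an ordering such as (i, j) -> (j, i), stored dimension 0 holds original dimension 1, so toOrig reads the permutation directly while toStored inverts it:

#include <array>
#include <cassert>

int main() {
  // Hypothetical dimOrdering (i, j) -> (j, i): stored dim d holds original
  // dim order[d], mirroring order.getDimPosition(d) in the patch.
  std::array<unsigned, 2> order = {1, 0};
  auto toOrig = [&](unsigned d) { return order[d]; };
  // toStored applies the inverse permutation: find where original dim d landed.
  auto toStored = [&](unsigned d) {
    unsigned s = 0;
    while (order[s] != d)
      s++;
    return s;
  };
  assert(toOrig(0) == 1 && toOrig(1) == 0);
  assert(toStored(1) == 0 && toStored(0) == 1);
  return 0;
}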
@@ -87,7 +97,7 @@ static Optional<Type> convertSparseTensorType(Type type) {
     // tensor type.
     switch (enc.getDimLevelType()[r]) {
     case SparseTensorEncodingAttr::DimLevelType::Dense:
-      break;
+      break; // no fields
     case SparseTensorEncodingAttr::DimLevelType::Compressed:
     case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
     case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
@@ -111,7 +121,7 @@ static Optional<Type> convertSparseTensorType(Type type) {
   return TupleType::get(context, fields);
 }
 
-// Returns field index for pointers (d), indices (d) for set field.
+// Returns field index of sparse tensor type for pointers/indices, when set.
 static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
   auto enc = getSparseTensorEncoding(type);
   assert(enc);
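Note: a hedged plain-C++ sketch of the field numbering that getFieldIndex presumably walks, mirroring the layout built by createAllocTuple further down (field 0 = dimSizes, each compressed dimension contributes a pointers field then an indices field, each singleton dimension an indices field only, values last). The DLT enum and the -1 sentinel are simplifications, not the actual MLIR types:

#include <cassert>
#include <vector>

enum class DLT { Dense, Compressed, Singleton }; // simplified stand-in

// Returns the tuple field index of the pointers (ptrDim) or indices (idxDim)
// array, assuming the field order: dimSizes first, per-dim fields, values last.
static unsigned fieldIndex(const std::vector<DLT> &dims, int ptrDim, int idxDim) {
  unsigned field = 1; // skip the dimSizes field
  for (unsigned r = 0; r < dims.size(); r++) {
    switch (dims[r]) {
    case DLT::Dense:
      break; // no fields
    case DLT::Compressed:
      if ((int)r == ptrDim)
        return field;
      field++;
      if ((int)r == idxDim)
        return field;
      field++;
      break;
    case DLT::Singleton:
      if ((int)r == idxDim)
        return field;
      field++;
      break;
    }
  }
  assert(false && "dimension not found");
  return 0;
}

int main() {
  // CSR: dim 0 dense, dim 1 compressed.
  std::vector<DLT> csr = {DLT::Dense, DLT::Compressed};
  assert(fieldIndex(csr, 1, -1) == 1); // pointers of dim 1
  assert(fieldIndex(csr, -1, 1) == 2); // indices of dim 1
  return 0;
}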
@@ -161,6 +171,94 @@ static Value createTupleGet(OpBuilder &builder, Location loc, Value tuple,
      builder.getIntegerAttr(indexType, field));
 }
 
+/// Creates tuple.
+static Value createTupleMake(OpBuilder &builder, Location loc, Type type,
+                             ValueRange values) {
+  return builder.create<StorageNewOp>(loc, type, values);
+}
+
+/// Creates allocation operation.
+static Value createAllocation(OpBuilder &builder, Location loc, Type type,
+                              Value sz) {
+  auto memType = MemRefType::get({ShapedType::kDynamicSize}, type);
+  return builder.create<memref::AllocOp>(loc, memType, sz);
+}
+
+/// Creates allocation tuple for sparse tensor type.
+///
+/// TODO: for efficiency, we will need heuristics to make educated guesses
+///       on the required final sizes; also, we will need an improved
+///       memory allocation scheme with capacity and reallocation.
+///
+static Value createAllocTuple(OpBuilder &builder, Location loc, Type type,
+                              ValueRange dynSizes) {
+  auto enc = getSparseTensorEncoding(type);
+  assert(enc);
+  // Construct the basic types.
+  unsigned idxWidth = enc.getIndexBitWidth();
+  unsigned ptrWidth = enc.getPointerBitWidth();
+  RankedTensorType rType = type.cast<RankedTensorType>();
+  Type indexType = builder.getIndexType();
+  Type idxType = idxWidth ? builder.getIntegerType(idxWidth) : indexType;
+  Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType;
+  Type eltType = rType.getElementType();
+  // Build the allocation tuple, using heuristics for pre-allocation.
+  auto shape = rType.getShape();
+  unsigned rank = shape.size();
+  SmallVector<Value, 8> fields;
+  bool allDense = true;
+  Value one = constantIndex(builder, loc, 1);
+  Value linear = one;
+  Value heuristic = one; // FIX, see TODO above
+  // Build original sizes.
+  SmallVector<Value, 8> sizes;
+  for (unsigned r = 0, o = 0; r < rank; r++) {
+    if (ShapedType::isDynamic(shape[r]))
+      sizes.push_back(dynSizes[o++]);
+    else
+      sizes.push_back(constantIndex(builder, loc, shape[r]));
+  }
+  // The dimSizes array.
+  Value dimSizes =
+      builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
+  fields.push_back(dimSizes);
+  // Per-dimension storage.
+  for (unsigned r = 0; r < rank; r++) {
+    // Get the original dimension (ro) for the current stored dimension.
+    unsigned ro = toOrig(enc, r);
+    builder.create<memref::StoreOp>(loc, sizes[ro], dimSizes,
+                                    constantIndex(builder, loc, r));
+    linear = builder.create<arith::MulIOp>(loc, linear, sizes[ro]);
+    // Allocate fields.
+    switch (enc.getDimLevelType()[r]) {
+    case SparseTensorEncodingAttr::DimLevelType::Dense:
+      break; // no fields
+    case SparseTensorEncodingAttr::DimLevelType::Compressed:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+      fields.push_back(createAllocation(builder, loc, ptrType, heuristic));
+      fields.push_back(createAllocation(builder, loc, idxType, heuristic));
+      allDense = false;
+      break;
+    case SparseTensorEncodingAttr::DimLevelType::Singleton:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+      fields.push_back(createAllocation(builder, loc, idxType, heuristic));
+      allDense = false;
+      break;
+    }
+  }
+  // The values array. For all-dense, the full length is required.
+  // In all other cases, we resort to the heuristic initial value.
+  Value valuesSz = allDense ? linear : heuristic;
+  fields.push_back(createAllocation(builder, loc, eltType, valuesSz));
+  // Construct tuple allocation.
+  Type tupleType = *convertSparseTensorType(type);
+  return createTupleMake(builder, loc, tupleType, fields);
+}
+
 /// Returns integral constant, if defined.
 static Optional<int64_t> getConstantInt(Value val) {
   if (auto constantOp = val.getDefiningOp<arith::ConstantOp>())
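Note: to make createAllocTuple concrete, consider a CSR-like tensor<?x?xf64> (dim 0 dense, dim 1 compressed): the resulting tuple would hold four fields — dimSizes, pointers and indices for dim 1, and values — with the non-values buffers sized by the placeholder heuristic of one element. A hedged standalone sketch of the allDense/linear bookkeeping (plain C++, illustrative sizes, not the MLIR code itself):

#include <cstdio>
#include <vector>

enum class DLT { Dense, Compressed, Singleton }; // simplified stand-in

int main() {
  // CSR-like 4x8 tensor: dim 0 dense, dim 1 compressed.
  std::vector<DLT> dims = {DLT::Dense, DLT::Compressed};
  std::vector<long> sizes = {4, 8};
  long linear = 1;
  bool allDense = true;
  long heuristic = 1; // same placeholder as in the patch
  for (unsigned r = 0; r < dims.size(); r++) {
    linear *= sizes[r];
    if (dims[r] != DLT::Dense)
      allDense = false;
  }
  // Only an all-dense tensor needs the full linearized size up front.
  long valuesSz = allDense ? linear : heuristic;
  std::printf("values allocation size: %ld\n", valuesSz); // prints 1
  return 0;
}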
@@ -233,6 +331,28 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
   }
 };
 
+/// Sparse codegen rule for the alloc operator.
+class SparseTensorAllocConverter
+    : public OpConversionPattern<bufferization::AllocTensorOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    RankedTensorType resType = op.getType();
+    auto enc = getSparseTensorEncoding(resType);
+    if (!enc)
+      return failure();
+    if (op.getCopy())
+      return rewriter.notifyMatchFailure(op, "tensor copy not implemented");
+    // Construct allocation tuple.
+    Value tuple = createAllocTuple(rewriter, op->getLoc(), resType,
+                                   adaptor.getOperands());
+    rewriter.replaceOp(op, tuple);
+    return success();
+  }
+};
+
 /// Sparse codegen rule for the dealloc operator.
 class SparseTensorDeallocConverter
     : public OpConversionPattern<bufferization::DeallocTensorOp> {
@@ -311,6 +431,22 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
   }
 };
 
+/// Sparse codegen rule for tensor rematerialization.
+class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op.getHasInserts()) {
+      // Finalize any pending insertions.
+      // TODO: implement
+    }
+    rewriter.replaceOp(op, adaptor.getOperands());
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -331,7 +467,8 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
   patterns.add<SparseReturnConverter, SparseDimOpConverter, SparseCastConverter,
-               SparseTensorDeallocConverter, SparseToPointersConverter,
-               SparseToIndicesConverter, SparseToValuesConverter>(
+               SparseTensorAllocConverter, SparseTensorDeallocConverter,
+               SparseToPointersConverter, SparseToIndicesConverter,
+               SparseToValuesConverter, SparseTensorLoadConverter>(
       typeConverter, patterns.getContext());
 }
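Note: for context, a hedged sketch of how these registered patterns would typically be driven — standard MLIR dialect-conversion boilerplate rather than part of this patch; the header paths and the empty conversion target are assumptions:

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Illustrative driver only: packs the type converter and the patterns
// registered above into the usual partial-conversion invocation; the
// patterns themselves bail out (failure()) on non-sparse types.
static LogicalResult runSparseTensorCodegen(Operation *root) {
  MLIRContext *context = root->getContext();
  SparseTensorTypeToBufferConverter typeConverter;
  RewritePatternSet patterns(context);
  ConversionTarget target(*context);
  populateSparseTensorCodegenPatterns(typeConverter, patterns);
  return applyPartialConversion(root, target, std::move(patterns));
}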