|
10 | 10 | // visible buffers and actual compiler IR that implements these primitives on |
11 | 11 | // the selected sparse tensor storage schemes. This pass provides an alternative |
12 | 12 | // to the SparseTensorConversion pass, eliminating the dependence on a runtime |
13 | | -// support library, and providing much more opportunities for subsequent |
14 | | -// compiler optimization of the generated code. |
| 13 | +// support library (other than for file I/O), and providing many more |
| 14 | +// opportunities for subsequent compiler optimization of the generated code. |
15 | 15 | // |
16 | 16 | //===----------------------------------------------------------------------===// |
17 | 17 |
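[Editor's sketch, not part of this diff: the header above describes the pass as a library-free alternative to SparseTensorConversion. For orientation, a pipeline that exercises this codegen path might be assembled roughly as below; the pass-creation helper and headers are assumptions based on the MLIR sparse-tensor transforms API, not something introduced by this commit.]

    // Sketch only: assumes createSparseTensorCodegenPass() declared in
    // mlir/Dialect/SparseTensor/Transforms/Passes.h and a standard PassManager.
    #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
    #include "mlir/Pass/PassManager.h"

    static mlir::LogicalResult runSparseCodegen(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      // Lower sparse tensor primitives to explicit buffer IR instead of
      // runtime-library calls, per the pass description above.
      pm.addPass(mlir::createSparseTensorCodegenPass());
      return pm.run(module);
    }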
|
|
37 | 37 | using namespace mlir; |
38 | 38 | using namespace mlir::sparse_tensor; |
39 | 39 |
|
40 | | -namespace { |
41 | | - |
42 | | -using FuncGeneratorType = |
43 | | - function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, RankedTensorType)>; |
44 | | - |
45 | 40 | //===----------------------------------------------------------------------===// |
46 | 41 | // Helper methods. |
47 | 42 | //===----------------------------------------------------------------------===// |
48 | 43 |
|
49 | | -/// Flatten a list of operands that may contain sparse tensors. |
| 44 | +/// Flattens a list of operands that may contain sparse tensors. |
50 | 45 | static void flattenOperands(ValueRange operands, |
51 | 46 | SmallVectorImpl<Value> &flattened) { |
52 | 47 | // In case of |
@@ -97,6 +92,7 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper, |
97 | 92 | return forOp; |
98 | 93 | } |
99 | 94 |
|
| 95 | +/// Creates a push back operation. |
100 | 96 | static void createPushback(OpBuilder &builder, Location loc, |
101 | 97 | MutSparseTensorDescriptor desc, |
102 | 98 | SparseTensorFieldKind kind, std::optional<Level> lvl, |
@@ -368,6 +364,95 @@ static Value genCompressed(OpBuilder &builder, Location loc, |
368 | 364 | return ifOp2.getResult(o); |
369 | 365 | } |
370 | 366 |
|
| 367 | +/// Generates insertion finalization code. |
| 368 | +static void genEndInsert(OpBuilder &builder, Location loc, |
| 369 | + SparseTensorDescriptor desc) { |
| 370 | + const SparseTensorType stt(desc.getRankedTensorType()); |
| 371 | + const Level lvlRank = stt.getLvlRank(); |
| 372 | + for (Level l = 0; l < lvlRank; l++) { |
| 373 | + const auto dlt = stt.getLvlType(l); |
| 374 | + if (isLooseCompressedDLT(dlt)) |
| 375 | + llvm_unreachable("TODO: Not yet implemented"); |
| 376 | + if (isCompressedDLT(dlt)) { |
| 377 | + // Compressed dimensions need a position cleanup for all entries |
| 378 | + // that were not visited during the insertion pass. |
| 379 | + // |
| 380 | + // TODO: avoid cleanup and keep compressed scheme consistent at all |
| 381 | + // times? |
| 382 | + // |
| 383 | + if (l > 0) { |
| 384 | + Type posType = stt.getPosType(); |
| 385 | + Value posMemRef = desc.getPosMemRef(l); |
| 386 | + Value hi = desc.getPosMemSize(builder, loc, l); |
| 387 | + Value zero = constantIndex(builder, loc, 0); |
| 388 | + Value one = constantIndex(builder, loc, 1); |
| 389 | + // Vector of only one, but needed by createFor's prototype. |
| 390 | + SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)}; |
| 391 | + scf::ForOp loop = createFor(builder, loc, hi, inits, one); |
| 392 | + Value i = loop.getInductionVar(); |
| 393 | + Value oldv = loop.getRegionIterArg(0); |
| 394 | + Value newv = genLoad(builder, loc, posMemRef, i); |
| 395 | + Value posZero = constantZero(builder, loc, posType); |
| 396 | + Value cond = builder.create<arith::CmpIOp>( |
| 397 | + loc, arith::CmpIPredicate::eq, newv, posZero); |
| 398 | + scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType), |
| 399 | + cond, /*else*/ true); |
| 400 | + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| 401 | + genStore(builder, loc, oldv, posMemRef, i); |
| 402 | + builder.create<scf::YieldOp>(loc, oldv); |
| 403 | + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| 404 | + builder.create<scf::YieldOp>(loc, newv); |
| 405 | + builder.setInsertionPointAfter(ifOp); |
| 406 | + builder.create<scf::YieldOp>(loc, ifOp.getResult(0)); |
| 407 | + builder.setInsertionPointAfter(loop); |
| 408 | + } |
| 409 | + } else { |
| 410 | + assert(isDenseDLT(dlt) || isSingletonDLT(dlt)); |
| 411 | + } |
| 412 | + } |
| 413 | +} |
| 414 | + |
| 415 | +/// Generates a subview into the sizes. |
| 416 | +static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, |
| 417 | + Value sz) { |
| 418 | + auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType(); |
| 419 | + return builder |
| 420 | + .create<memref::SubViewOp>( |
| 421 | + loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem, |
| 422 | + ValueRange{}, ValueRange{sz}, ValueRange{}, |
| 423 | + ArrayRef<int64_t>{0}, // static offset |
| 424 | + ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size |
| 425 | + ArrayRef<int64_t>{1}) // static stride |
| 426 | + .getResult(); |
| 427 | +} |
| 428 | + |
| 429 | +/// Creates the reassociation array. |
| 430 | +static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) { |
| 431 | + ReassociationIndices reassociation; |
| 432 | + for (int i = 0, e = srcTp.getRank(); i < e; i++) |
| 433 | + reassociation.push_back(i); |
| 434 | + return reassociation; |
| 435 | +} |
| 436 | + |
| 437 | +/// Generates scalar to tensor cast. |
| 438 | +static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem, |
| 439 | + Type dstTp) { |
| 440 | + if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) { |
| 441 | + // Scalars can only be converted to 0-ranked tensors. |
| 442 | + if (rtp.getRank() != 0) |
| 443 | + return nullptr; |
| 444 | + elem = genCast(builder, loc, elem, rtp.getElementType()); |
| 445 | + return builder.create<tensor::FromElementsOp>(loc, rtp, elem); |
| 446 | + } |
| 447 | + return genCast(builder, loc, elem, dstTp); |
| 448 | +} |
| 449 | + |
| 450 | +//===----------------------------------------------------------------------===// |
| 451 | +// Codegen rules. |
| 452 | +//===----------------------------------------------------------------------===// |
| 453 | + |
| 454 | +namespace { |
| 455 | + |
371 | 456 | /// Helper class to help lowering sparse_tensor.insert operation. |
372 | 457 | class SparseInsertGenerator |
373 | 458 | : public FuncCallOrInlineGenerator<SparseInsertGenerator> { |
@@ -472,90 +557,6 @@ class SparseInsertGenerator |
472 | 557 | TensorType rtp; |
473 | 558 | }; |
474 | 559 |
|
475 | | -/// Generations insertion finalization code. |
476 | | -static void genEndInsert(OpBuilder &builder, Location loc, |
477 | | - SparseTensorDescriptor desc) { |
478 | | - const SparseTensorType stt(desc.getRankedTensorType()); |
479 | | - const Level lvlRank = stt.getLvlRank(); |
480 | | - for (Level l = 0; l < lvlRank; l++) { |
481 | | - const auto dlt = stt.getLvlType(l); |
482 | | - if (isLooseCompressedDLT(dlt)) |
483 | | - llvm_unreachable("TODO: Not yet implemented"); |
484 | | - if (isCompressedDLT(dlt)) { |
485 | | - // Compressed dimensions need a position cleanup for all entries |
486 | | - // that were not visited during the insertion pass. |
487 | | - // |
488 | | - // TODO: avoid cleanup and keep compressed scheme consistent at all |
489 | | - // times? |
490 | | - // |
491 | | - if (l > 0) { |
492 | | - Type posType = stt.getPosType(); |
493 | | - Value posMemRef = desc.getPosMemRef(l); |
494 | | - Value hi = desc.getPosMemSize(builder, loc, l); |
495 | | - Value zero = constantIndex(builder, loc, 0); |
496 | | - Value one = constantIndex(builder, loc, 1); |
497 | | - // Vector of only one, but needed by createFor's prototype. |
498 | | - SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)}; |
499 | | - scf::ForOp loop = createFor(builder, loc, hi, inits, one); |
500 | | - Value i = loop.getInductionVar(); |
501 | | - Value oldv = loop.getRegionIterArg(0); |
502 | | - Value newv = genLoad(builder, loc, posMemRef, i); |
503 | | - Value posZero = constantZero(builder, loc, posType); |
504 | | - Value cond = builder.create<arith::CmpIOp>( |
505 | | - loc, arith::CmpIPredicate::eq, newv, posZero); |
506 | | - scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType), |
507 | | - cond, /*else*/ true); |
508 | | - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
509 | | - genStore(builder, loc, oldv, posMemRef, i); |
510 | | - builder.create<scf::YieldOp>(loc, oldv); |
511 | | - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
512 | | - builder.create<scf::YieldOp>(loc, newv); |
513 | | - builder.setInsertionPointAfter(ifOp); |
514 | | - builder.create<scf::YieldOp>(loc, ifOp.getResult(0)); |
515 | | - builder.setInsertionPointAfter(loop); |
516 | | - } |
517 | | - } else { |
518 | | - assert(isDenseDLT(dlt) || isSingletonDLT(dlt)); |
519 | | - } |
520 | | - } |
521 | | -} |
522 | | - |
523 | | -static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, |
524 | | - Value sz) { |
525 | | - auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType(); |
526 | | - return builder |
527 | | - .create<memref::SubViewOp>( |
528 | | - loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem, |
529 | | - ValueRange{}, ValueRange{sz}, ValueRange{}, |
530 | | - ArrayRef<int64_t>{0}, // static offset |
531 | | - ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size |
532 | | - ArrayRef<int64_t>{1}) // static stride |
533 | | - .getResult(); |
534 | | -} |
535 | | - |
536 | | -static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) { |
537 | | - ReassociationIndices reassociation; |
538 | | - for (int i = 0, e = srcTp.getRank(); i < e; i++) |
539 | | - reassociation.push_back(i); |
540 | | - return reassociation; |
541 | | -} |
542 | | - |
543 | | -static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem, |
544 | | - Type dstTp) { |
545 | | - if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) { |
546 | | - // Scalars can only be converted to 0-ranked tensors. |
547 | | - if (rtp.getRank() != 0) |
548 | | - return nullptr; |
549 | | - elem = genCast(builder, loc, elem, rtp.getElementType()); |
550 | | - return builder.create<tensor::FromElementsOp>(loc, rtp, elem); |
551 | | - } |
552 | | - return genCast(builder, loc, elem, dstTp); |
553 | | -} |
554 | | - |
555 | | -//===----------------------------------------------------------------------===// |
556 | | -// Codegen rules. |
557 | | -//===----------------------------------------------------------------------===// |
558 | | - |
559 | 560 | /// Sparse tensor storage conversion rule for returns. |
560 | 561 | class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> { |
561 | 562 | public: |
|
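[Editor's sketch, not part of this diff: the position cleanup that genEndInsert emits for compressed levels is easier to follow in scalar form. The hypothetical helper below (cleanupPositions is an illustrative name only) restates the generated loop: any positions entry still zero because its level was never visited during insertion inherits the last propagated value.]

    // Hypothetical scalar equivalent of the IR built by genEndInsert for one
    // compressed level: zero entries in the positions buffer are back-filled
    // with the running (previous) position value.
    #include <cstdint>
    #include <vector>

    void cleanupPositions(std::vector<uint64_t> &positions) {
      if (positions.empty())
        return;
      uint64_t prev = positions[0]; // initial loop iteration argument
      for (size_t i = 0; i < positions.size(); ++i) {
        if (positions[i] == 0)
          positions[i] = prev;      // entry never visited; reuse previous position
        else
          prev = positions[i];      // propagate the most recent valid position
      }
    }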