|
@@ -42,6 +42,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/LogicalResult.h"
@@ -273,32 +274,6 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
 
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
-                                               OperationState &result,
-                                               Value target,
-                                               Attribute memorySpace) {
-  SmallVector<Type> resultTypes;
-  resultTypes.push_back(b.getType<transform::AnyValueType>());
-  resultTypes.push_back(b.getType<transform::AnyOpType>());
-  return build(b, result,
-               /*resultTypes=*/resultTypes,
-               /*target=*/target,
-               /*memory_space=*/memorySpace);
-}
-
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
-                                               OperationState &result,
-                                               Value target,
-                                               int64_t memorySpace) {
-  SmallVector<Type> resultTypes;
-  resultTypes.push_back(b.getType<transform::AnyValueType>());
-  resultTypes.push_back(b.getType<transform::AnyOpType>());
-  return build(b, result,
-               /*resultTypes=*/resultTypes,
-               /*target=*/target,
-               /*memory_space=*/b.getI64IntegerAttr(memorySpace));
-}
-
 namespace {
 class NewOpsListener : public RewriterBase::ForwardingListener {
 public:
@@ -408,6 +383,95 @@ LogicalResult transform::BufferizeToAllocationOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// PromoteTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Return true if the operand may be read from by its owner. This is currently
+/// very conservative and only looks inside linalg operations to prevent
+/// unintentional data loss.
+static bool mayBeRead(OpOperand &operand) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner());
+
+  // Be conservative about ops we cannot analyze deeper.
+  if (!linalgOp)
+    return true;
+
+  // Look inside linalg ops.
+  Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
+  return !blockArgument.use_empty();
+}
+
+/// Return true if the value may be read through any of its uses.
+static bool mayBeRead(Value value) {
+  // If the value has reference semantics, it may be read through
+  // any alias...
+  if (!isa<TensorType, FloatType, IntegerType>(value.getType()))
+    return true;
+  return llvm::any_of(value.getUses(),
+                      static_cast<bool (&)(OpOperand &)>(mayBeRead));
+}
+
+DiagnosedSilenceableFailure
+transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  SmallVector<Value> promoted;
+  for (Value tensor : state.getPayloadValues(getTensor())) {
+    auto type = dyn_cast<RankedTensorType>(tensor.getType());
+    if (!type) {
+      return emitSilenceableError() << "non-tensor type: " << tensor;
+    }
+
+    Operation *definingOp = tensor.getDefiningOp();
+    if (definingOp)
+      rewriter.setInsertionPointAfter(definingOp);
+    else
+      rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());
+
+    // Check this before we emit operations using this value.
+    bool needsMaterialization = mayBeRead(tensor);
+
+    SmallVector<Value> dynamicDims;
+    llvm::SmallPtrSet<Operation *, 4> preservedOps;
+    for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
+      if (!ShapedType::isDynamic(dim))
+        continue;
+      Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
+      auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
+      preservedOps.insert(dimOp);
+      dynamicDims.push_back(dimOp);
+    }
+    auto allocation = rewriter.create<bufferization::AllocTensorOp>(
+        tensor.getLoc(), type, dynamicDims);
+    // Set memory space if provided.
+    if (getMemorySpaceAttr())
+      allocation.setMemorySpaceAttr(getMemorySpaceAttr());
+    Value allocated = allocation;
+
+    // Only insert a materialization (typically bufferizes to a copy) when the
+    // value may be read from.
+    if (needsMaterialization) {
+      auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
+          tensor.getLoc(), tensor, allocated);
+      preservedOps.insert(copy);
+      promoted.push_back(copy.getResult());
+    } else {
+      promoted.push_back(allocated);
+    }
+    rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
+  }
+  results.setValues(cast<OpResult>(getPromoted()), promoted);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::PromoteTensorOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTensorMutable(), effects);
+  transform::producesHandle(getOperation()->getOpResults(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // DecomposeOp
 //===----------------------------------------------------------------------===//
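This hunk adds only the C++ implementation; the op's TableGen definition (mnemonic, assembly format, result types) lives in a separate file not shown here. Assuming the mnemonic is `transform.structured.promote_tensor` with a `to <memory-space>` clause for the optional `memory_space` attribute, a transform script using the new op might look like the following sketch:

```mlir
// Hypothetical usage sketch. The "promote_tensor" mnemonic, the "to 1"
// memory-space clause, and the result type are assumptions; only the
// semantics implemented in the C++ above are confirmed by this diff.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %root
        : (!transform.any_op) -> !transform.any_op
    // Promote the matmul's init operand (operand #2) into a fresh
    // allocation in memory space 1.
    %init = transform.get_operand %matmul[2]
        : (!transform.any_op) -> !transform.any_value
    %promoted = transform.structured.promote_tensor to 1 %init
        : !transform.any_value
    transform.yield
  }
}
```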
|
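The `mayBeRead` overloads encode the op's main optimization: a `bufferization.materialize_in_destination` copy is emitted only when the promoted value may actually be read, so promoting a write-only init costs no copy. A payload-level illustration in standard linalg/tensor IR (not part of this patch): the generic below never reads `%empty`, because the block argument matching the `outs` operand is unused, so promoting `%empty` skips the materialization; promoting `%in`, which is genuinely read, would insert one.

```mlir
func.func @copy_elision(%in: tensor<?xf32>, %d0: index) -> tensor<?xf32> {
  // Write-only init: its matching block argument %out is dead, so
  // mayBeRead(%empty) returns false and promotion needs no copy.
  %empty = tensor.empty(%d0) : tensor<?xf32>
  %r = linalg.generic
      {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
       iterator_types = ["parallel"]}
      ins(%in : tensor<?xf32>) outs(%empty : tensor<?xf32>) {
  ^bb0(%a: f32, %out: f32):
    linalg.yield %a : f32
  } -> tensor<?xf32>
  return %r : tensor<?xf32>
}
```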
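One subtlety in `apply`: the freshly created `tensor.dim` ops and the materialization op are collected in `preservedOps` and excluded from `replaceAllUsesExcept`. Without this, the dim ops would be rewritten to measure the very allocation they size, and the copy would read its own result. For a dynamic 1-D tensor that may be read, the rewrite therefore produces IR of roughly this shape (a sketch, with value names invented for illustration):

```mlir
// %t : tensor<?xf32> is the promoted value; every use of %t except the
// dim and the materialization below is redirected to %copied.
%c0 = arith.constant 0 : index
%d0 = tensor.dim %t, %c0 : tensor<?xf32>  // still reads the original %t
%alloc = bufferization.alloc_tensor(%d0) : tensor<?xf32>
%copied = bufferization.materialize_in_destination %t in %alloc
    : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>  // still reads %t
```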