Skip to content

Commit 5ce4278

Browse files
pifon2amemfrob
authored andcommitted
[mlir][linalg] Expose function to create op on buffers during bufferization.
Differential Revision: https://reviews.llvm.org/D109140
1 parent 33f85e0 commit 5ce4278

File tree

2 files changed

+44
-56
lines changed

2 files changed

+44
-56
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
8080
void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
8181
RewritePatternSet &patterns);
8282

83+
/// Create linalg op on buffers given the original tensor-based operation and
84+
/// the buffers for the outputs.
85+
LinalgOp createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
86+
LinalgOp linalgOp, ValueRange inputs,
87+
ValueRange outputs);
88+
8389
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
8490
/// tensors.
8591
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);

mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp

Lines changed: 38 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -73,56 +73,44 @@ allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
7373
return success();
7474
}
7575

76-
/// Specialization for `linalg::GenericOp`.
77-
/// A pattern to convert Generic Linalg operations which work on tensors to
78-
/// use buffers. BufferPlacement pass should be later used to move
79-
/// Alloc operations to the correct positions and insert the missing Dealloc
80-
/// operations in the correct places.
81-
static void
82-
finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
83-
GenericOp genericOp, ValueRange inputs,
84-
ValueRange outputs) {
85-
// Generate a new linalg operation that works on buffers.
86-
auto newGenericOp = rewriter.create<GenericOp>(
87-
genericOp.getLoc(),
88-
/*resultTensorTypes=*/llvm::None,
89-
/*inputs=*/inputs,
90-
/*outputs=*/outputs, genericOp.indexing_maps(),
91-
genericOp.iterator_types(), genericOp.docAttr(),
92-
genericOp.library_callAttr());
93-
94-
// Create a new block in the region of the new Generic Op.
95-
Block *oldBlock = genericOp.getBody();
96-
Region &newRegion = newGenericOp.region();
97-
Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
98-
oldBlock->getArgumentTypes());
99-
100-
// Clone the body of the old block to the new block.
101-
BlockAndValueMapping mapping;
102-
mapping.map(oldBlock->getArguments(), newBlock->getArguments());
103-
104-
OpBuilder::InsertionGuard guard(rewriter);
105-
rewriter.setInsertionPointToEnd(newBlock);
106-
for (auto &op : oldBlock->getOperations()) {
107-
Operation *clonedOp = rewriter.clone(op, mapping);
108-
mapping.map(op.getResults(), clonedOp->getResults());
76+
/// Create linalg op on buffers given the original tensor-based operation and
77+
/// the buffers for the outputs.
78+
LinalgOp
79+
mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
80+
LinalgOp linalgOp, ValueRange inputs,
81+
ValueRange outputs) {
82+
if (auto genericOp = mlir::dyn_cast<GenericOp>(*linalgOp)) {
83+
// Generate a new linalg operation that works on buffers.
84+
auto newGenericOp = rewriter.create<GenericOp>(
85+
genericOp.getLoc(),
86+
/*resultTensorTypes=*/llvm::None,
87+
/*inputs=*/inputs,
88+
/*outputs=*/outputs, genericOp.indexing_maps(),
89+
genericOp.iterator_types(), genericOp.docAttr(),
90+
genericOp.library_callAttr());
91+
92+
// Create a new block in the region of the new Generic Op.
93+
Block *oldBlock = genericOp.getBody();
94+
Region &newRegion = newGenericOp.region();
95+
Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
96+
oldBlock->getArgumentTypes());
97+
98+
// Clone the body of the old block to the new block.
99+
BlockAndValueMapping mapping;
100+
mapping.map(oldBlock->getArguments(), newBlock->getArguments());
101+
102+
OpBuilder::InsertionGuard guard(rewriter);
103+
rewriter.setInsertionPointToEnd(newBlock);
104+
for (auto &op : oldBlock->getOperations()) {
105+
Operation *clonedOp = rewriter.clone(op, mapping);
106+
mapping.map(op.getResults(), clonedOp->getResults());
107+
}
108+
return newGenericOp;
109109
}
110-
111-
// Replace the results of the old op with the new output buffers.
112-
rewriter.replaceOp(genericOp, outputs);
113-
}
114-
115-
/// Specialization for all other `linalg::LinalgOp`.
116-
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
117-
linalg::LinalgOp linalgOp,
118-
ValueRange inputs, ValueRange outputs) {
119-
assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
120110
SmallVector<Value, 8> newOperands = inputs;
121111
newOperands.append(outputs.begin(), outputs.end());
122-
linalgOp.clone(rewriter, linalgOp.getLoc(),
123-
/*resultTypes=*/ArrayRef<Type>{}, newOperands);
124-
// Replace the results of the old op with the new output buffers.
125-
rewriter.replaceOp(linalgOp, outputs);
112+
return linalgOp.clone(rewriter, linalgOp.getLoc(),
113+
/*resultTypes=*/ArrayRef<Type>{}, newOperands);
126114
}
127115

128116
//===----------------------------------------------------------------------===//
@@ -218,15 +206,9 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
218206
return op.emitOpError()
219207
<< "Failed to allocate buffers for tensor results.";
220208
}
221-
222-
// Delegate to the linalg generic pattern.
223-
if (auto genericOp = dyn_cast<linalg::GenericOp>(*op)) {
224-
finalizeBufferAllocationForGenericOp(rewriter, genericOp,
225-
adaptor.inputs(), newOutputBuffers);
226-
return success();
227-
}
228-
229-
finalizeBufferAllocation(rewriter, op, adaptor.inputs(), newOutputBuffers);
209+
createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
210+
// Replace the results of the old op with the new output buffers.
211+
rewriter.replaceOp(op, newOutputBuffers);
230212
return success();
231213
}
232214
};

0 commit comments

Comments
 (0)