@@ -73,56 +73,44 @@ allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
73
73
return success ();
74
74
}
75
75
76
- // / Specialization for ` linalg::GenericOp`.
77
- // / A pattern to convert Generic Linalg operations which work on tensors to
78
- // / use buffers. BufferPlacement pass should be later used to move
79
- // / Alloc operations to the correct positions and insert the missing Dealloc
80
- // / operations in the correct places.
81
- static void
82
- finalizeBufferAllocationForGenericOp (ConversionPatternRewriter &rewriter,
83
- GenericOp genericOp, ValueRange inputs,
84
- ValueRange outputs) {
85
- // Generate a new linalg operation that works on buffers.
86
- auto newGenericOp = rewriter. create <GenericOp>(
87
- genericOp. getLoc () ,
88
- /* resultTensorTypes =*/ llvm::None ,
89
- /* inputs= */ inputs ,
90
- /* outputs= */ outputs, genericOp.indexing_maps (),
91
- genericOp. iterator_types (), genericOp. docAttr (),
92
- genericOp. library_callAttr ());
93
-
94
- // Create a new block in the region of the new Generic Op.
95
- Block *oldBlock = genericOp. getBody ();
96
- Region &newRegion = newGenericOp. region ( );
97
- Block *newBlock = rewriter. createBlock (&newRegion, newRegion. begin (),
98
- oldBlock-> getArgumentTypes ());
99
-
100
- // Clone the body of the old block to the new block.
101
- BlockAndValueMapping mapping;
102
- mapping. map (oldBlock-> getArguments (), newBlock-> getArguments () );
103
-
104
- OpBuilder::InsertionGuard guard (rewriter);
105
- rewriter.setInsertionPointToEnd (newBlock );
106
- for ( auto &op : oldBlock-> getOperations ()) {
107
- Operation *clonedOp = rewriter. clone (op, mapping);
108
- mapping. map (op. getResults (), clonedOp-> getResults ()) ;
76
+ // / Create linalg op on buffers given the original tensor-based operation and
77
+ // / the buffers for the outputs.
78
+ LinalgOp
79
+ mlir::linalg::createLinalgOpOnBuffers (ConversionPatternRewriter &rewriter,
80
+ LinalgOp linalgOp, ValueRange inputs,
81
+ ValueRange outputs) {
82
+ if ( auto genericOp = mlir::dyn_cast<GenericOp>(*linalgOp)) {
83
+ // Generate a new linalg operation that works on buffers.
84
+ auto newGenericOp = rewriter. create <GenericOp>(
85
+ genericOp. getLoc (),
86
+ /* resultTensorTypes= */ llvm::None,
87
+ /* inputs= */ inputs ,
88
+ /* outputs =*/ outputs, genericOp. indexing_maps () ,
89
+ genericOp. iterator_types (), genericOp. docAttr () ,
90
+ genericOp.library_callAttr ());
91
+
92
+ // Create a new block in the region of the new Generic Op.
93
+ Block *oldBlock = genericOp. getBody ();
94
+ Region &newRegion = newGenericOp. region ();
95
+ Block *newBlock = rewriter. createBlock (&newRegion, newRegion. begin (),
96
+ oldBlock-> getArgumentTypes () );
97
+
98
+ // Clone the body of the old block to the new block.
99
+ BlockAndValueMapping mapping;
100
+ mapping. map (oldBlock-> getArguments (), newBlock-> getArguments ());
101
+
102
+ OpBuilder::InsertionGuard guard (rewriter );
103
+ rewriter. setInsertionPointToEnd (newBlock);
104
+ for ( auto &op : oldBlock-> getOperations ()) {
105
+ Operation *clonedOp = rewriter.clone (op, mapping );
106
+ mapping. map (op. getResults (), clonedOp-> getResults ());
107
+ }
108
+ return newGenericOp ;
109
109
}
110
-
111
- // Replace the results of the old op with the new output buffers.
112
- rewriter.replaceOp (genericOp, outputs);
113
- }
114
-
115
- // / Specialization for all other `linalg::LinalgOp`.
116
- static void finalizeBufferAllocation (ConversionPatternRewriter &rewriter,
117
- linalg::LinalgOp linalgOp,
118
- ValueRange inputs, ValueRange outputs) {
119
- assert (!isa<linalg::GenericOp>(linalgOp.getOperation ()));
120
110
SmallVector<Value, 8 > newOperands = inputs;
121
111
newOperands.append (outputs.begin (), outputs.end ());
122
- linalgOp.clone (rewriter, linalgOp.getLoc (),
123
- /* resultTypes=*/ ArrayRef<Type>{}, newOperands);
124
- // Replace the results of the old op with the new output buffers.
125
- rewriter.replaceOp (linalgOp, outputs);
112
+ return linalgOp.clone (rewriter, linalgOp.getLoc (),
113
+ /* resultTypes=*/ ArrayRef<Type>{}, newOperands);
126
114
}
127
115
128
116
// ===----------------------------------------------------------------------===//
@@ -218,15 +206,9 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
218
206
return op.emitOpError ()
219
207
<< " Failed to allocate buffers for tensor results." ;
220
208
}
221
-
222
- // Delegate to the linalg generic pattern.
223
- if (auto genericOp = dyn_cast<linalg::GenericOp>(*op)) {
224
- finalizeBufferAllocationForGenericOp (rewriter, genericOp,
225
- adaptor.inputs (), newOutputBuffers);
226
- return success ();
227
- }
228
-
229
- finalizeBufferAllocation (rewriter, op, adaptor.inputs (), newOutputBuffers);
209
+ createLinalgOpOnBuffers (rewriter, op, adaptor.inputs (), newOutputBuffers);
210
+ // Replace the results of the old op with the new output buffers.
211
+ rewriter.replaceOp (op, newOutputBuffers);
230
212
return success ();
231
213
}
232
214
};
0 commit comments