@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
110110 }));
111111 }
112112
113- // Instantiate the tiled implementation of the operation.
113+ // / Instantiate the tiled implementation of the operation.
114114 FailureOr<TilingResult>
115115 getTiledImplementation (Operation *op, OpBuilder &b,
116116 ArrayRef<OpFoldResult> offsets,
@@ -132,14 +132,66 @@ struct LinalgOpTilingInterface
132132 return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
133133 }
134134
135- // Return the details of the output tile generated by the tiled
136- // implementation.
135+ void
136+ getMappedOffsetAndSize (LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
137+ ArrayRef<OpFoldResult> offsets,
138+ ArrayRef<OpFoldResult> sizes,
139+ SmallVectorImpl<OpFoldResult> &mappedOffsets,
140+ SmallVectorImpl<OpFoldResult> &mappedSizes) const {
141+ unsigned numLoops = linalgOp.getNumLoops ();
142+ auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation ());
143+ mappedOffsets.resize (numLoops);
144+ mappedSizes.resize (numLoops);
145+ if (!indexingMap.isPermutation ()) {
146+ SmallVector<Range> iterationDomain =
147+ tilingInterfaceOp.getIterationDomain (b);
148+ for (const auto &&[index, value] : llvm::enumerate (iterationDomain)) {
149+ mappedOffsets[index] = value.offset ;
150+ mappedSizes[index] = value.size ;
151+ }
152+ }
153+ for (const auto &&[index, value] :
154+ llvm::enumerate (indexingMap.getResults ())) {
155+ unsigned dimPosition = cast<AffineDimExpr>(value).getPosition ();
156+ mappedOffsets[dimPosition] = offsets[index];
157+ mappedSizes[dimPosition] = sizes[index];
158+ }
159+ }
160+
161+ // / Return the details of the output tile generated by the tiled
162+ // / implementation.
163+ LogicalResult getIterationDomainTileFromOperandTile (
164+ Operation *op, OpBuilder &b, unsigned operandNumber,
165+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
166+ SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
167+ SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
168+ auto linalgOp = cast<LinalgOp>(op);
169+
170+ // Check that the indexing map used for the operand is a projected
171+ // permutation. This could be relaxed with a more general approach that can
172+ // map the offsets and sizes from the operand to iteration space tiles
173+ // (filling in full extent for dimensions not used to access the result).
174+ AffineMap indexingMap =
175+ linalgOp.getMatchingIndexingMap (&op->getOpOperand (operandNumber));
176+ if (!indexingMap.isProjectedPermutation ()) {
177+ return emitError (op->getLoc (),
178+ " unhandled get iter domain position when operand is not "
179+ " accessed using a permuted projection" );
180+ }
181+
182+ getMappedOffsetAndSize (linalgOp, b, indexingMap, offsets, sizes,
183+ iterDomainOffsets, iterDomainSizes);
184+ return success ();
185+ }
186+
187+ // / Return the details of the output tile generated by the tiled
188+ // / implementation.
137189 LogicalResult
138190 getResultTilePosition (Operation *op, OpBuilder &b, unsigned resultNumber,
139191 ArrayRef<OpFoldResult> offsets,
140192 ArrayRef<OpFoldResult> sizes,
141- SmallVector <OpFoldResult> &resultOffsets,
142- SmallVector <OpFoldResult> &resultSizes) const {
193+ SmallVectorImpl <OpFoldResult> &resultOffsets,
194+ SmallVectorImpl <OpFoldResult> &resultSizes) const {
143195 Location loc = op->getLoc ();
144196 LinalgOp linalgOp = cast<LinalgOp>(op);
145197
@@ -160,6 +212,21 @@ struct LinalgOpTilingInterface
160212 return success ();
161213 }
162214
215+ FailureOr<TilingResult> getTiledImplementationFromOperandTile (
216+ Operation *op, OpBuilder &b, unsigned operandNumber,
217+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
218+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
219+ auto tilingInterfaceOp = cast<TilingInterface>(op);
220+ if (failed (tilingInterfaceOp.getIterationDomainTileFromOperandTile (
221+ b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
222+ return emitError (
223+ op->getLoc (),
224+ " unable to obtain the iter domain position of the operation." );
225+ }
226+ return tilingInterfaceOp.getTiledImplementation (b, mappedOffsets,
227+ mappedSizes);
228+ }
229+
163230 FailureOr<TilingResult>
164231 generateResultTileValue (Operation *op, OpBuilder &b, unsigned resultNumber,
165232 ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
177244 " unhandled tiled implementation generation when result is not "
178245 " accessed using a permuted projection" );
179246 }
180-
181- auto numLoops = linalgOp.getNumLoops ();
247+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
248+ getMappedOffsetAndSize (linalgOp, b, indexingMap, offsets, sizes,
249+ mappedOffsets, mappedSizes);
182250 auto tilingInterfaceOp = cast<TilingInterface>(op);
183- SmallVector<OpFoldResult> iterationTileOffsets (numLoops),
184- iterationTileSizes (numLoops);
185- if (!indexingMap.isPermutation ()) {
186- SmallVector<Range> iterationDomain =
187- tilingInterfaceOp.getIterationDomain (b);
188- for (const auto &range : llvm::enumerate (iterationDomain)) {
189- iterationTileOffsets[range.index ()] = range.value ().offset ;
190- iterationTileSizes[range.index ()] = range.value ().size ;
191- }
192- }
193- for (const auto &resultExpr : llvm::enumerate (indexingMap.getResults ())) {
194- unsigned dimPosition =
195- cast<AffineDimExpr>(resultExpr.value ()).getPosition ();
196- iterationTileOffsets[dimPosition] = offsets[resultExpr.index ()];
197- iterationTileSizes[dimPosition] = sizes[resultExpr.index ()];
198- }
199-
200251 FailureOr<TilingResult> tilingResult =
201- tilingInterfaceOp.getTiledImplementation (b, iterationTileOffsets,
202- iterationTileSizes);
252+ tilingInterfaceOp.getTiledImplementation (b, mappedOffsets, mappedSizes);
253+
254+ if (failed (tilingResult))
255+ return failure ();
256+
203257 if (tilingResult->tiledOps .size () != 1 )
204258 return op->emitOpError (" failed to generate tiled implementation" );
205259
0 commit comments