//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"

using namespace mlir;
using namespace mlir::tensor;
@@ -210,6 +214,200 @@ struct BubbleUpExpandThroughParallelCollapse
  }
};

/// Converts `tensor.extract_slice(tensor.expand_shape)` to
/// `tensor.expand_shape(tensor.extract_slice)`.
/// For this transformation to be possible, the slice must be fully contiguous
/// within each reassociation group of the expand_shape. If the transformation
/// is not possible, or if the slice is rank reducing, the function returns
/// failure.
///
/// Example:
/// ```
/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
/// %slice = tensor.extract_slice %reshape ...
///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
///
/// // The transformation is possible because each reassociation group has a
/// // contiguous slice (i.e., [2x4 -> 2x4], [2x8 -> 1x5], [4x2x4 -> 1x1x4]).
/// // After the transformation:
///
/// %slice = tensor.extract_slice %in ...
///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
/// ```
///
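/// A non-matching example (for illustration only): with the same
/// expand_shape, extracting tensor<2x4x2x5x4x2x4xf32> would fail to match,
/// because in group [2, 3] the [2x8 -> 2x5] slice takes 5 of every 8
/// elements, which is not contiguous in the collapsed 16-element source dim.
///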
/// Note: this pattern could be reworked into a swap pattern between
/// `tensor.expand_shape` and `tensor.extract_slice`, but it is currently
/// implemented only as a bubble-up pattern for `tensor.extract_slice`.
struct BubbleUpExpandShapeThroughExtractSlice
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();

    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
                                                 rewriter)
            .failed())
      return failure();

    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> outputShape =
        getMixedValues(expandShapeOp.getStaticOutputShape(),
                       expandShapeOp.getOutputShape(), rewriter);

    // Helper variables and function for accumulating the new offset and
    // length values.
    Location loc = expandShapeOp->getLoc();
    AffineExpr d0, d1, d2;
    bindDims(rewriter.getContext(), d0, d1, d2);
    // Multiply two integers.
    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
                                                   {v1, v2});
    };

    // Compute new offsets, lengths, and strides.
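    // Worked example (hypothetical values): for a reassociation group [2, 8]
    // expanded from a source dim of size 16, extracting sizes [1, 5] at
    // offsets [1, 3] yields basis = [2, 8] and delinOffsets = [1, 3], so the
    // linearized source offset is 1 * 8 + 3 = 11 and the new length is 5.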
    SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
    for (const ReassociationIndices &indices :
         expandShapeOp.getReassociationIndices()) {
      OpFoldResult newSize = rewriter.getIndexAttr(1);
      SmallVector<OpFoldResult> basis, delinOffsets;

      int64_t i = 0;
      int64_t e = indices.size();
      // Leading dims with unit extracted size contribute only to the offset.
      for (; i < e; ++i) {
        int64_t expandedDim = indices[i];
        if (!isConstantIntValue(sizes[expandedDim], 1))
          break;

        basis.push_back(outputShape[expandedDim]);
        delinOffsets.push_back(offsets[expandedDim]);
      }

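      // The first dim with a non-unit extracted size determines the group's
      // new length; its offset still participates in the linearization.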
      if (i != e) {
        int64_t expandedDim = indices[i];
        basis.push_back(outputShape[expandedDim]);
        delinOffsets.push_back(offsets[expandedDim]);
        newSize = sizes[expandedDim];
        i++;
      }

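      // Any remaining dims are sliced in full (enforced by the precondition
      // check below), so they scale the new length and add a zero offset.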
      for (; i < e; ++i) {
        OpFoldResult fullSize = outputShape[indices[i]];
        basis.push_back(fullSize);
        delinOffsets.push_back(rewriter.getIndexAttr(0));
        newSize = mul(newSize, fullSize);
      }
      SmallVector<Value> offsetVals =
          llvm::map_to_vector(delinOffsets, [&](OpFoldResult ofr) {
            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
          });
      OpFoldResult newOffset =
          rewriter
              .create<affine::AffineLinearizeIndexOp>(loc, offsetVals, basis,
                                                      /*disjoint=*/true)
              .getResult();
      newOffsets.push_back(newOffset);
      newLengths.push_back(newSize);

      // Only unit stride supported.
      newStrides.push_back(rewriter.getIndexAttr(1));
    }

    // The shape of the result can be obtained from the sizes passed in.
    SmallVector<Value> dynDims;
    SmallVector<int64_t> shape;
    dispatchIndexOpFoldResults(sizes, dynDims, shape);
    RankedTensorType resultType = RankedTensorType::get(
        shape, expandShapeOp.getResultType().getElementType());

    // Create a new ExtractSliceOp and ExpandShapeOp.
    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, expandShapeOp.getSrc(), newOffsets, newLengths, newStrides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSliceOp,
        expandShapeOp.getReassociationIndices(), sizes);
    return success();
  }

  LogicalResult
  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
                                           tensor::ExpandShapeOp expandShapeOp,
                                           PatternRewriter &rewriter) const {
    if (!expandShapeOp) {
      return rewriter.notifyMatchFailure(
          sliceOp, "tensor.extract_slice source not produced by expand_shape");
    }

    if (!sliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
                   "be supported in this transformation.");
    }

    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();

    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
        sizes.size()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "unimplemented: rank reducing slice");
    }

    SmallVector<OpFoldResult> outputShape =
        getMixedValues(expandShapeOp.getStaticOutputShape(),
                       expandShapeOp.getOutputShape(), rewriter);

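    // Returns true only when `offset` is statically zero and `sliceSize` can
    // be proven equal to the full expanded `size`; if the value-bounds
    // analysis cannot prove equality, this conservatively returns false.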
    std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
        isZeroOffsetAndFullSize =
            [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
              if (!isConstantIntValue(offset, 0))
                return false;
              FailureOr<bool> maybeEqual =
                  ValueBoundsConstraintSet::areEqual(sliceSize, size);
              return llvm::succeeded(maybeEqual) && maybeEqual.value();
            };

    // First verify that this is a full slice of the expanded tensor.
    for (const ReassociationIndices &indices :
         expandShapeOp.getReassociationIndices()) {
      int64_t i = 0;
      int64_t e = indices.size();
      // Find the first expanded dim after the first dim with non-unit
      // extracted size.
      for (; i < e; ++i) {
        if (!isConstantIntValue(sizes[indices[i]], 1)) {
          // +1 to skip the first non-unit size dim.
          i++;
          break;
        }
      }

      // Verify that all subsequent dimensions extract the full size of the
      // source tensor.
      for (; i < e; ++i) {
        int64_t expandedDim = indices[i];
        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
                                     outputShape[expandedDim])) {
          return rewriter.notifyMatchFailure(
              sliceOp, "not a contiguous slice of the expanded tensor");
        }
      }
    }

    return success();
  }
};

} // namespace

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -227,3 +425,8 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}

void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
}
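
// Example usage (an illustrative sketch; `op` and the enclosing pass are
// assumed and not part of this file): a pass can gather these patterns and
// run the greedy rewrite driver over its root operation, e.g.:
//
//   RewritePatternSet patterns(op->getContext());
//   tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();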