//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
@@ -210,6 +213,178 @@ struct BubbleUpExpandThroughParallelCollapse
  }
};

/// Converts `tensor.extract_slice(tensor.expand_shape)` to
/// `tensor.expand_shape(tensor.extract_slice)`.
/// For this transformation to be possible, the slice must be fully contiguous
/// within each reassociation group of the expand_shape. If the transformation
/// is not possible, or if the slice is rank reducing, the function returns
/// failure.
///
/// Example:
/// ```
/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
/// %slice = tensor.extract_slice %reshape ...
///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
///
/// // The transformation is possible because each reassociation group has a
/// // contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
/// // After the transformation:
///
/// %slice = tensor.extract_slice %in ...
///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
/// ```
///
/// Note - this pattern could be reworked to be a swap pattern between
/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
/// implemented only as a bubble up pattern for `tensor.extract_slice`.
struct BubbleUpExpandShapeThroughExtractSlice
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
    if (!expandShapeOp) {
      return rewriter.notifyMatchFailure(
          sliceOp, "slice source not produced by expand_shape");
    }

    if (!sliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "unsupported: non-unit stride");
    }

    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();

    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
        sizes.size()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "unimplemented: rank reducing slice");
    }

    // Helper variables and function for accumulating the new offset and
    // length values.
    Location loc = expandShapeOp->getLoc();
    AffineExpr d0, d1, d2;
    bindDims(rewriter.getContext(), d0, d1, d2);
    // Multiply two integers.
    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
                                                   {v1, v2});
    };
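    // For example, mul(2, 8) folds away entirely to the index attribute 16;
    // if either operand is a dynamic Value, an affine.apply op is
    // materialized instead.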

    SmallVector<OpFoldResult> outputShape =
        getMixedValues(expandShapeOp.getStaticOutputShape(),
                       expandShapeOp.getOutputShape(), rewriter);

    auto isZeroOffsetAndFullSize =
        [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
          if (!isConstantIntValue(offset, 0))
            return false;
          FailureOr<bool> maybeEqual =
              ValueBoundsConstraintSet::areEqual(sliceSize, size);
          return llvm::succeeded(maybeEqual) && maybeEqual.value();
        };
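    // Note: ValueBoundsConstraintSet::areEqual is used instead of a static
    // size comparison so that the check can also succeed when `sliceSize` and
    // `size` are dynamic SSA values that are provably equal.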

    // First verify that this is a full slice of the expanded tensor.
    for (const ReassociationIndices &indices :
         expandShapeOp.getReassociationIndices()) {
      int64_t i = 0;
      int64_t e = indices.size();
      // Find the first expanded dim after the first dim with non-unit
      // extracted size.
      for (; i < e; ++i) {
        if (!isConstantIntValue(sizes[indices[i]], 1)) {
          // +1 to skip the first non-unit size dim.
          i++;
          break;
        }
      }

      // Verify that all subsequent dimensions extract the full size of the
      // source tensor.
      for (; i < e; ++i) {
        int64_t expandedDim = indices[i];
        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
                                     outputShape[expandedDim])) {
          return rewriter.notifyMatchFailure(
              sliceOp, "Not a contiguous slice of the expanded tensor.");
        }
      }
    }
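    // Tracing the example from the documentation above: the group [2, 8]
    // sliced to [1, 5] passes because the scan stops after the first non-unit
    // size (5) and no dims follow it; the group [2, 4] sliced to [2, 4]
    // passes because the dim after the non-unit size 2 is a zero-offset,
    // full-size (4) extraction.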

    // Compute new offsets, lengths, and strides.
    SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
    for (const ReassociationIndices &indices :
         expandShapeOp.getReassociationIndices()) {
      OpFoldResult newSize = rewriter.getIndexAttr(1);
      SmallVector<OpFoldResult> basis, delinOffsets;

      int64_t i = 0;
      int64_t e = indices.size();
      // Offset = cumulative product of leading unit extracted dims.
      for (; i < e; ++i) {
        int64_t expandedDim = indices[i];
        if (!isConstantIntValue(sizes[expandedDim], 1))
          break;

        basis.push_back(outputShape[expandedDim]);
        delinOffsets.push_back(offsets[expandedDim]);
      }

      if (i != e) {
        int64_t expandedDim = indices[i];
        basis.push_back(outputShape[expandedDim]);
        delinOffsets.push_back(offsets[expandedDim]);
        newSize = sizes[expandedDim];
        i++;
      }

      for (; i < e; ++i) {
        OpFoldResult fullSize = outputShape[indices[i]];
        basis.push_back(fullSize);
        delinOffsets.push_back(rewriter.getIndexAttr(0));
        newSize = mul(newSize, fullSize);
      }
      SmallVector<Value> offsetVals =
          llvm::map_to_vector(delinOffsets, [&](OpFoldResult ofr) {
            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
          });
      OpFoldResult newOffset =
          rewriter
              .create<affine::AffineLinearizeIndexOp>(loc, offsetVals, basis,
                                                      /*disjoint=*/true)
              .getResult();
      newOffsets.push_back(newOffset);
      newLengths.push_back(newSize);

      // Only unit stride supported.
      newStrides.push_back(rewriter.getIndexAttr(1));
    }
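    // For the group [2, 8] sliced to [1, 5] in the example above, with
    // expanded offsets [%o0, %o1], this yields basis = (2, 8),
    // newOffset = affine.linearize_index [%o0, %o1] by (2, 8)
    // (i.e., %o0 * 8 + %o1), and newLength = 5 on the source dim of size 16.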

    // The shape of the result can be obtained from the sizes passed in.
    SmallVector<Value> dynDims;
    SmallVector<int64_t> shape;
    dispatchIndexOpFoldResults(sizes, dynDims, shape);
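    // (dispatchIndexOpFoldResults splits `sizes` into a static shape, with
    // ShapedType::kDynamic placeholders, and the corresponding dynamic
    // Values.)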
    RankedTensorType resultType = RankedTensorType::get(
        shape, expandShapeOp.getResultType().getElementType());

    // Create a new ExtractSliceOp and ExpandShapeOp.
    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, expandShapeOp.getSrc(), newOffsets, newLengths, newStrides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSliceOp,
        expandShapeOp.getReassociationIndices(), sizes);
    return success();
  }
};

} // namespace

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -227,3 +402,8 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}

void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
}
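
// A minimal usage sketch (the surrounding pass boilerplate is hypothetical;
// applyPatternsAndFoldGreedily is the standard greedy rewrite driver):
//
//   RewritePatternSet patterns(op->getContext());
//   tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();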