 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -210,6 +214,214 @@ struct BubbleUpExpandThroughParallelCollapse
   }
 };
 
+/// Converts `tensor.extract_slice(tensor.expand_shape)` to
+/// `tensor.expand_shape(tensor.extract_slice)`.
+///
+/// For this transformation to be possible, the slice must be fully contiguous
+/// within each reassociation group of the expand_shape. A slice is fully
+/// contiguous within a reassociation group if, after flattening the
+/// reassociation group to a single 1D range, the slice taken out of the group
+/// can be defined as a single contiguous subrange within that range.
+///
+/// Rank reducing slices are not supported.
+///
+/// Example:
+/// The transformation is possible because each reassociation group has a
+/// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
+/// ```
+/// BEFORE:
+/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
+/// %slice = tensor.extract_slice %reshape ...
+///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+///
+/// AFTER:
+/// %slice = tensor.extract_slice %in ...
+///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
+/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+/// ```
+///
+/// Note - this pattern could be extended to be a swap pattern between
+/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
+/// implemented only as a bubble up pattern for `tensor.extract_slice`.
+struct BubbleUpExpandShapeThroughExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp =
+        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+
+    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
+                                                 rewriter)
+            .failed())
+      return failure();
+
+    // The tensor.extract_slice before applying the pattern works on the
+    // result of the tensor.expand_shape, so variables (i.e. inputs for
+    // ExtractSliceOp) referring to the state before applying the pattern are
+    // named with the prefix "expanded", and ones referring to the state after
+    // applying the pattern are named with the prefix "collapsed".
+    SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> expandedShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+
+    // Helper variables and function for accumulating the size values.
+    Location loc = expandShapeOp->getLoc();
+    AffineExpr d0, d1, d2;
+    bindDims(rewriter.getContext(), d0, d1, d2);
+    // Multiply two integers.
+    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+                                                   {v1, v2});
+    };
+
+    // Compute new offsets, sizes, and strides for tensor.extract_slice.
+    // The new tensor.extract_slice will work on a tensor that has a rank of
+    // ReassociationIndices.size(). In the loop a single offset, size, and
+    // stride value is computed per reassociation group.
+    SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
+        collapsedStrides;
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      // collapsedSize will hold the size of the single dim that represents
+      // the reassociation group in the non-expanded tensor.
+      OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
+      // The reassocGroupSizes and reassocGroupOffsets are used to create an
+      // affine.linearize_index op to linearize the single offset value
+      // required for this reassociation group.
+      SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
+
+      for (long expandedDim : indices) {
+        // reassocGroupSizes and reassocGroupOffsets can be obtained directly
+        // from the expanded state, but the collapsed size requires calculation
+        // as it did not previously exist.
+        reassocGroupSizes.push_back(expandedShape[expandedDim]);
+        reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
+        collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
+      }
+
+      SmallVector<Value> offsetVals =
+          llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
+            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+          });
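+      // The group's expanded offsets are linearized in row-major order over
+      // the group's expanded sizes (illustrative values, not from the patch:
+      // sizes [2, 8] with offsets [1, 3] linearize to 1 * 8 + 3 = 11).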
+      OpFoldResult collapsedOffset =
+          rewriter
+              .create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
+                                                      reassocGroupSizes,
+                                                      /*disjoint=*/true)
+              .getResult();
+      collapsedOffsets.push_back(collapsedOffset);
+      collapsedSizes.push_back(collapsedSize);
+
+      // Only unit stride is supported.
+      collapsedStrides.push_back(rewriter.getIndexAttr(1));
+    }
+
+    // The shape of the result can be obtained from the sizes passed in.
+    SmallVector<Value> dynDims;
+    SmallVector<int64_t> shape;
+    dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
+    RankedTensorType resultType = RankedTensorType::get(
+        shape, expandShapeOp.getResultType().getElementType());
+
+    // Create a new ExtractSliceOp and ExpandShapeOp.
+    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
+        collapsedStrides);
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        sliceOp, resultType, newSliceOp,
+        expandShapeOp.getReassociationIndices(), expandedSizes);
+    return success();
+  }
+
+  // Helper function to check if all the required conditions for the
+  // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
+  // met.
+  LogicalResult
+  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
+                                           tensor::ExpandShapeOp expandShapeOp,
+                                           PatternRewriter &rewriter) const {
+
+    if (!expandShapeOp) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "tensor.extract_slice source not produced by expand_shape");
+    }
+
+    if (!sliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
+                   "be supported in this transformation.");
+    }
+
+    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+
+    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+        sizes.size()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unimplemented: rank reducing slice");
+    }
+
+    SmallVector<OpFoldResult> outputShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+
+    std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
+        isZeroOffsetAndFullSize =
+            [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
+              if (!isConstantIntValue(offset, 0))
+                return false;
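+              // If equality of the slice size and the dimension size cannot
+              // be proven, conservatively treat the slice as non-contiguous
+              // so the pattern does not apply.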
+              FailureOr<bool> maybeEqual =
+                  ValueBoundsConstraintSet::areEqual(sliceSize, size);
+              return llvm::succeeded(maybeEqual) && maybeEqual.value();
+            };
+
+    // Check that the slice is contiguous within each reassociation group.
+    // The slice is contiguous only if, after the first dimension where a
+    // non-unit slice is taken, the slice size on all subsequent dimensions of
+    // the group is equal to the entire size of the dimension.
+    // Examples of contiguous slices:
+    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
+    //   full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
+    // Examples of non-contiguous slices:
+    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
+    //   full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      int64_t i = 0;
+      int64_t e = indices.size();
+      // Find the first expanded dim after the first dim with non-unit
+      // extracted size.
+      for (; i < e; ++i) {
+        if (!isConstantIntValue(sizes[indices[i]], 1)) {
+          // +1 to skip the first non-unit size dim.
+          i++;
+          break;
+        }
+      }
+
+      // Verify that all subsequent dimensions extract the full size of the
+      // source tensor.
+      for (; i < e; ++i) {
+        int64_t expandedDim = indices[i];
+        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+                                     outputShape[expandedDim])) {
+          return rewriter.notifyMatchFailure(
+              sliceOp, "Not a contiguous slice of the expanded tensor.");
+        }
+      }
+    }
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -227,3 +439,8 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
     RewritePatternSet &patterns) {
   patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
 }
+
+void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
+}
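
For reference, below is a minimal sketch of how the new populate entry point might be driven from a pass, assuming the standard greedy pattern rewrite driver; the pass name `TestBubbleUpExtractSlicePass` is hypothetical and not part of this change.

```cpp
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
// Hypothetical driver pass; only the populate call comes from this patch.
struct TestBubbleUpExtractSlicePass
    : public mlir::PassWrapper<TestBubbleUpExtractSlicePass,
                               mlir::OperationPass<>> {
  void runOnOperation() override {
    mlir::RewritePatternSet patterns(&getContext());
    // Register the new extract_slice bubble-up pattern.
    mlir::tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
    // Greedily apply the patterns until a fixed point is reached.
    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                        std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
```

In recent MLIR revisions the driver entry point is spelled `applyPatternsGreedily`; use whichever spelling your tree provides.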