 //
 //===----------------------------------------------------------------------===//

+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"

 using namespace mlir;
 using namespace mlir::tensor;
@@ -210,6 +214,217 @@ struct BubbleUpExpandThroughParallelCollapse
   }
 };

+/// Converts `tensor.extract_slice(tensor.expand_shape)` to
+/// `tensor.expand_shape(tensor.extract_slice)`.
+///
+/// For this transformation to be possible, the slice must be fully contiguous
+/// within each reassociation group of the expand_shape. A slice is considered
+/// fully contiguous within a reassociation group if, after flattening the
+/// group into a single 1D range, the slice taken out of the group forms a
+/// single contiguous subrange of that range.
+/// If the transformation is not possible, or if the slice is rank reducing,
+/// the pattern returns failure.
+///
+/// Example:
+/// ```
+/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
+/// %slice = tensor.extract_slice %reshape ...
+///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+///
+/// // The transformation is possible because each reassociation group has a
+/// // contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
+/// // After the transformation:
+///
+/// %slice = tensor.extract_slice %in ...
+///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
+/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+/// ```
+///
+/// Note: this pattern could be reworked into a swap pattern between
+/// `tensor.expand_shape` and `tensor.extract_slice`, but it is currently
+/// implemented only as a bubble-up pattern for `tensor.extract_slice`.
+struct BubbleUpExpandShapeThroughExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp =
+        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+
+    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
+                                                 rewriter)
+            .failed())
+      return failure();
+
+    // The tensor.extract_slice before applying the pattern works on the
+    // result of the tensor.expand_shape, so variables referring to the state
+    // before applying the pattern are named with the prefix "expanded", and
+    // ones referring to the state after applying the pattern are named with
+    // the prefix "collapsed".
+    SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> expandedShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+
+    // Helper variables and function for accumulating the size values.
+    Location loc = expandShapeOp->getLoc();
+    AffineExpr d0, d1, d2;
+    bindDims(rewriter.getContext(), d0, d1, d2);
+    // Multiply two integers.
+    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+                                                   {v1, v2});
+    };
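+    // (makeComposedFoldedAffineApply constant-folds where possible, so
+    // multiplying two static sizes yields an attribute rather than
+    // materializing an affine.apply op.)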
+
+    // Compute new offsets, sizes, and strides for tensor.extract_slice.
+    // The new tensor.extract_slice will work on a tensor that has a rank of
+    // ReassociationIndices.size(). In the loop, a single offset, size, and
+    // stride value is computed per reassociation group.
+    SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
+        collapsedStrides;
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      // collapsedSize will hold the size of the single dim that represents
+      // the reassociation group in the non-expanded tensor.
+      OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
+      // The basis and delinOffsets are used to create an
+      // affine.linearize_index op that linearizes the single offset value
+      // required for this reassociation group.
+      // basis holds the full sizes of the reassociation group dimensions of
+      // the expanded tensor.
+      // delinOffsets, as in "delinearized offsets", holds the offsets within
+      // the reassociation group dimensions of the expanded tensor.
+      SmallVector<OpFoldResult> basis, delinOffsets;
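+      // For example, for the second reassociation group in the doc comment
+      // above ([2, 3], expanded sizes [2, 8]), and assuming slice offsets
+      // [0, 3] and sizes [1, 5] on those dims (the doc example elides them):
+      // basis = [2, 8] and delinOffsets = [0, 3], so the linearized offset
+      // is 0 * 8 + 3 = 3 and collapsedSize = 1 * 5 = 5.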
+
+      for (long expandedDim : indices) {
+        // basis and delinOffsets can be obtained directly from the expanded
+        // state, but the collapsed size requires calculation as it did not
+        // previously exist.
+        basis.push_back(expandedShape[expandedDim]);
+        delinOffsets.push_back(expandedOffsets[expandedDim]);
+        collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
+      }
+
+      SmallVector<Value> offsetVals =
+          llvm::map_to_vector(delinOffsets, [&](OpFoldResult ofr) {
+            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+          });
+      OpFoldResult collapsedOffset =
+          rewriter
+              .create<affine::AffineLinearizeIndexOp>(loc, offsetVals, basis,
+                                                      /*disjoint=*/true)
+              .getResult();
+      collapsedOffsets.push_back(collapsedOffset);
+      collapsedSizes.push_back(collapsedSize);
+
+      // Only unit stride is supported.
+      collapsedStrides.push_back(rewriter.getIndexAttr(1));
+    }
+
+    // The shape of the result can be obtained from the sizes passed in.
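+    // (dispatchIndexOpFoldResults splits the OpFoldResults into a static
+    // shape vector, using ShapedType::kDynamic for each dynamic size, plus
+    // the list of dynamic size values.)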
+    SmallVector<Value> dynDims;
+    SmallVector<int64_t> shape;
+    dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
+    RankedTensorType resultType = RankedTensorType::get(
+        shape, expandShapeOp.getResultType().getElementType());
+
+    // Create a new ExtractSliceOp and ExpandShapeOp.
+    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
+        collapsedStrides);
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        sliceOp, resultType, newSliceOp,
+        expandShapeOp.getReassociationIndices(), expandedSizes);
+    return success();
+  }
+
+  // Helper function to check whether all the conditions required for bubbling
+  // the tensor.extract_slice up through the tensor.expand_shape are met.
+  LogicalResult
+  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
+                                           tensor::ExpandShapeOp expandShapeOp,
+                                           PatternRewriter &rewriter) const {
+
+    if (!expandShapeOp) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "tensor.extract_slice source not produced by expand_shape");
+    }
+
+    if (!sliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "unsupported: non-unit stride; only contiguous slices are "
+                   "supported by this transformation");
+    }
+
+    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+
+    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+        sizes.size()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unimplemented: rank reducing slice");
+    }
+
+    SmallVector<OpFoldResult> outputShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+
+    std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
+        isZeroOffsetAndFullSize =
+            [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
+              if (!isConstantIntValue(offset, 0))
+                return false;
+              FailureOr<bool> maybeEqual =
+                  ValueBoundsConstraintSet::areEqual(sliceSize, size);
+              return llvm::succeeded(maybeEqual) && maybeEqual.value();
+            };
+
+    // Check that the slice is contiguous within each reassociation group.
+    // The slice is contiguous only if, after the first dimension where a
+    // non-unit slice is taken, the slice size on all subsequent dimensions
+    // of the group is equal to the entire size of the dimension.
+    // Examples of contiguous slices:
+    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
+    //   full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
+    // Examples of non-contiguous slices:
+    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
+    //   full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      int64_t i = 0;
+      int64_t e = indices.size();
+      // Find the first expanded dim after the first dim with a non-unit
+      // extracted size.
+      for (; i < e; ++i) {
+        if (!isConstantIntValue(sizes[indices[i]], 1)) {
+          // +1 to skip the first non-unit size dim.
+          i++;
+          break;
+        }
+      }
+
+      // Verify that all subsequent dimensions extract the full size of the
+      // source tensor.
+      for (; i < e; ++i) {
+        int64_t expandedDim = indices[i];
+        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+                                     outputShape[expandedDim])) {
+          return rewriter.notifyMatchFailure(
+              sliceOp, "not a contiguous slice of the expanded tensor");
+        }
+      }
+    }
+
+    return success();
+  }
+};
+
 } // namespace

 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -227,3 +442,8 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
     RewritePatternSet &patterns) {
   patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
 }
+
+void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
+}
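
For reference, a minimal sketch of how a client might pick up the new `populateBubbleUpExtractSliceOpPatterns` entry point. This is not part of the commit: the pass name is hypothetical, and it assumes the standard greedy pattern rewrite driver.

```cpp
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
// Hypothetical pass that repeatedly bubbles tensor.extract_slice ops up
// through tensor.expand_shape ops until a fixed point is reached.
struct BubbleUpExtractSlicePass
    : public PassWrapper<BubbleUpExtractSlicePass, OperationPass<>> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
```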