Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 90 additions & 74 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,43 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#define DEBUG_TYPE "vector-shape-cast-lowering"

using namespace mlir;
using namespace mlir::vector;

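/// Increments n-D `indices` by `step` starting from the innermost dimension.
/// When an index reaches the size of its dimension in `vecType`, it is reset
/// to 0 and a carry of 1 is propagated to the next outer dimension.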
static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
int step = 1) {
for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
assert(indices[dim] < vecType.getDimSize(dim) &&
"Indices are out of bound");
indices[dim] += step;
if (indices[dim] < vecType.getDimSize(dim))
break;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add an assertion that indices[dim] == vecType.getDimSize(dim) at this point? Any other value here would look weird to me, so the assertion would serve as a sanity check.

indices[dim] = 0;
step = 1;
}
}

namespace {
/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to target llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
/// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
class ShapeCastOp2DDownCastRewritePattern
/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening N-D to 1-D
/// vectors progressively. This iterates over the n-1 major dimensions of the
/// n-D vector and performs rewrites into:
/// vector.extract from n-D + vector.insert_strided_slice offset into 1-D
class ShapeCastOpNDDownCastRewritePattern
: public OpRewritePattern<vector::ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
Expand All @@ -53,35 +56,52 @@ class ShapeCastOp2DDownCastRewritePattern
PatternRewriter &rewriter) const override {
auto sourceVectorType = op.getSourceVectorType();
auto resultVectorType = op.getResultVectorType();

if (sourceVectorType.isScalable() || resultVectorType.isScalable())
return failure();

if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
int64_t srcRank = sourceVectorType.getRank();
int64_t resRank = resultVectorType.getRank();
if (srcRank < 2 || resRank != 1)
return failure();

// Compute the number of 1-D vector elements involved in the reshape.
int64_t numElts = 1;
for (int64_t dim = 0; dim < srcRank - 1; ++dim)
numElts *= sourceVectorType.getDimSize(dim);

auto loc = op.getLoc();
Value desc = rewriter.create<arith::ConstantOp>(
SmallVector<int64_t> srcIdx(srcRank - 1, 0);
SmallVector<int64_t> resIdx(resRank, 0);
int64_t extractSize = sourceVectorType.getShape().back();
Value result = rewriter.create<arith::ConstantOp>(
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
desc = rewriter.create<vector::InsertStridedSliceOp>(
loc, vec, desc,
/*offsets=*/i * mostMinorVectorSize, /*strides=*/1);

// Compute the indices of each 1-D vector element of the source extraction
// and destination slice insertion and generate such instructions.
for (int64_t i = 0; i < numElts; ++i) {
if (i != 0) {
incIdx(srcIdx, sourceVectorType, /*step=*/1);
incIdx(resIdx, resultVectorType, /*step=*/extractSize);
}

Value extract =
rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
result = rewriter.create<vector::InsertStridedSliceOp>(
loc, extract, result,
/*offsets=*/resIdx, /*strides=*/1);
}
rewriter.replaceOp(op, desc);

rewriter.replaceOp(op, result);
return success();
}
};

/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
/// vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
/// vectors progressively. This iterates over the n-1 major dimensions of the
/// n-D vector and performs rewrites into:
/// vector.extract_strided_slice from 1-D + vector.insert into n-D
/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
class ShapeCastOp2DUpCastRewritePattern
class ShapeCastOpNDUpCastRewritePattern
: public OpRewritePattern<vector::ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
Expand All @@ -90,43 +110,43 @@ class ShapeCastOp2DUpCastRewritePattern
PatternRewriter &rewriter) const override {
auto sourceVectorType = op.getSourceVectorType();
auto resultVectorType = op.getResultVectorType();

if (sourceVectorType.isScalable() || resultVectorType.isScalable())
return failure();

if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
int64_t srcRank = sourceVectorType.getRank();
int64_t resRank = resultVectorType.getRank();
if (srcRank != 1 || resRank < 2)
return failure();

// Compute the number of 1-D vector elements involved in the reshape.
int64_t numElts = 1;
for (int64_t dim = 0; dim < resRank - 1; ++dim)
numElts *= resultVectorType.getDimSize(dim);

// Compute the indices of each 1-D vector element of the source slice
// extraction and destination insertion and generate such instructions.
auto loc = op.getLoc();
Value desc = rewriter.create<arith::ConstantOp>(
SmallVector<int64_t> srcIdx(srcRank, 0);
SmallVector<int64_t> resIdx(resRank - 1, 0);
int64_t extractSize = resultVectorType.getShape().back();
Value result = rewriter.create<arith::ConstantOp>(
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
/*sizes=*/mostMinorVectorSize,
for (int64_t i = 0; i < numElts; ++i) {
if (i != 0) {
incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
incIdx(resIdx, resultVectorType, /*step=*/1);
}

Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize,
/*strides=*/1);
desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
}
rewriter.replaceOp(op, desc);
rewriter.replaceOp(op, result);
return success();
}
};

/// Advances the n-D index `idx` in place, odometer-style, within the shape of
/// `tp`. Dimension `dimIdx` is bumped by `initialStep`; whenever a dimension
/// reaches or exceeds its size it is reset to 0 and the next outer dimension
/// is bumped by 1. The walk stops at the first dimension that stays in bounds,
/// or after dimension 0 has been processed.
static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
                   int dimIdx, int initialStep = 1) {
  int carry = initialStep;
  int dim = dimIdx;
  while (dim >= 0) {
    idx[dim] += carry;
    // Still inside this dimension: nothing to propagate outward.
    if (idx[dim] < tp.getDimSize(dim))
      return;
    // Wrapped around: restart this dimension and carry one into the next
    // outer dimension.
    idx[dim] = 0;
    carry = 1;
    --dim;
  }
}

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
Expand All @@ -145,18 +165,14 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
if (sourceVectorType.isScalable() || resultVectorType.isScalable())
return failure();

// Special case 2D / 1D lowerings with better implementations.
// TODO: make is ND / 1D to allow generic ND -> 1D -> MD.
// Special case for n-D / 1-D lowerings with better implementations.
int64_t srcRank = sourceVectorType.getRank();
int64_t resRank = resultVectorType.getRank();
if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
return failure();

// Generic ShapeCast lowering path goes all the way down to unrolled scalar
// extract/insert chains.
// TODO: consider evolving the semantics to only allow 1D source or dest and
// drop this potentially very expensive lowering.
// Compute number of elements involved in the reshape.
int64_t numElts = 1;
for (int64_t r = 0; r < srcRank; r++)
numElts *= sourceVectorType.getDimSize(r);
Expand All @@ -166,14 +182,14 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
// x[0,1,0] = y[0,2]
// etc., incrementing the two index vectors "row-major"
// within the source and result shape.
SmallVector<int64_t> srcIdx(srcRank);
SmallVector<int64_t> resIdx(resRank);
SmallVector<int64_t> srcIdx(srcRank, 0);
SmallVector<int64_t> resIdx(resRank, 0);
Value result = rewriter.create<arith::ConstantOp>(
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
for (int64_t i = 0; i < numElts; i++) {
if (i != 0) {
incIdx(srcIdx, sourceVectorType, srcRank - 1);
incIdx(resIdx, resultVectorType, resRank - 1);
incIdx(srcIdx, sourceVectorType);
incIdx(resIdx, resultVectorType);
}

Value extract;
Expand Down Expand Up @@ -252,7 +268,7 @@ class ScalableShapeCastOpRewritePattern
// have a single trailing scalable dimension. This is because there is no
// legal representation of other scalable types in LLVM (and likely won't be
// soon). There are also (currently) no operations that can index or extract
// from >= 2D scalable vectors or scalable vectors of fixed vectors.
// from >= 2-D scalable vectors or scalable vectors of fixed vectors.
if (!isTrailingDimScalable(sourceVectorType) ||
!isTrailingDimScalable(resultVectorType)) {
return failure();
Expand All @@ -278,8 +294,8 @@ class ScalableShapeCastOpRewritePattern
Value result = rewriter.create<arith::ConstantOp>(
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));

SmallVector<int64_t> srcIdx(srcRank);
SmallVector<int64_t> resIdx(resRank);
SmallVector<int64_t> srcIdx(srcRank, 0);
SmallVector<int64_t> resIdx(resRank, 0);

// TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
// once D150000 lands.
Expand Down Expand Up @@ -334,8 +350,8 @@ class ScalableShapeCastOpRewritePattern

// 4. Increment the insert/extract indices, stepping by minExtractionSize
// for the trailing dimensions.
incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
}

rewriter.replaceOp(op, result);
Expand All @@ -352,8 +368,8 @@ class ScalableShapeCastOpRewritePattern

/// Adds the vector.shape_cast lowering patterns to `patterns`. Each pattern is
/// constructed with the context obtained from `patterns` and the given
/// `benefit`.
void mlir::vector::populateVectorShapeCastLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<ShapeCastOp2DDownCastRewritePattern,
ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
patterns.add<ShapeCastOpNDDownCastRewritePattern,
ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
ScalableShapeCastOpRewritePattern>(patterns.getContext(),
benefit);
}
45 changes: 18 additions & 27 deletions mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s

// CHECK-LABEL: func @nop_shape_cast
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
Expand Down Expand Up @@ -82,19 +82,16 @@ func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
// CHECK-LABEL: func @shape_cast_3d1d
// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : f32 from vector<1x3x2xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : f32 from vector<1x3x2xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : f32 from vector<1x3x2xf32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : f32 from vector<1x3x2xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : f32 from vector<1x3x2xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32>
// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : f32 from vector<1x3x2xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32>
// CHECK: return %[[T11]] : vector<6xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[C]]
// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32>
// CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]]
// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32>
// CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]]
// CHECK-SAME: {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
// CHECK: return %[[T5]] : vector<6xf32>

func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
%s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
Expand All @@ -104,19 +101,13 @@ func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
// CHECK-LABEL: func @shape_cast_1d3d
// CHECK-SAME: %[[A:.*]]: vector<6xf32>
// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<6xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : f32 from vector<6xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : f32 from vector<6xf32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : f32 from vector<6xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : f32 from vector<6xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32>
// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : f32 from vector<6xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32>
// CHECK: return %[[T11]] : vector<2x1x3xf32>
// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
// CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : vector<3xf32> into vector<2x1x3xf32>
// CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]]
// CHECK: {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : vector<3xf32> into vector<2x1x3xf32>
// CHECK: return %[[T3]] : vector<2x1x3xf32>

func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
%s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
Expand Down
Loading