Commit a7a4c16

[mlir][Vector] Support efficient shape cast lowering for n-D vectors (#123497)
This PR generalizes the existing, more efficient lowering of shape casts between 2-D and 1-D vectors to casts between n-D and 1-D vectors. This significantly reduces code size and generates more performant code for n-D shape casts that make their way to LLVM/SPIR-V.
1 parent 3b2b7ec commit a7a4c16
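To make the effect concrete, here is a hedged sketch (hand-written for this summary, not taken from the commit; it mirrors the updated @shape_cast_3d1d test expectations below) of what the generalized n-D down-cast pattern now produces for a vector<1x3x2xf32> to vector<6xf32> cast. Instead of six scalar extract/insert pairs, the lowering emits one vector.extract plus one vector.insert_strided_slice per innermost 1-D vector:

  %cst = arith.constant dense<0.000000e+00> : vector<6xf32>
  // Pull out each innermost vector<2xf32> and place it at offsets 0, 2, 4.
  %e0 = vector.extract %arg0[0, 0] : vector<2xf32> from vector<1x3x2xf32>
  %r0 = vector.insert_strided_slice %e0, %cst {offsets = [0], strides = [1]}
          : vector<2xf32> into vector<6xf32>
  %e1 = vector.extract %arg0[0, 1] : vector<2xf32> from vector<1x3x2xf32>
  %r1 = vector.insert_strided_slice %e1, %r0 {offsets = [2], strides = [1]}
          : vector<2xf32> into vector<6xf32>
  %e2 = vector.extract %arg0[0, 2] : vector<2xf32> from vector<1x3x2xf32>
  %r2 = vector.insert_strided_slice %e2, %r1 {offsets = [4], strides = [1]}
          : vector<2xf32> into vector<6xf32>

The number of emitted op pairs now scales with the product of the n-1 major dimensions rather than with the total element count.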

File tree

2 files changed: +108 -101 lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp

Lines changed: 90 additions & 74 deletions
@@ -11,40 +11,43 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/VectorInterfaces.h"
 
 #define DEBUG_TYPE "vector-shape-cast-lowering"
 
 using namespace mlir;
 using namespace mlir::vector;
 
+/// Increments n-D `indices` by `step` starting from the innermost dimension.
+static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
+                   int step = 1) {
+  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
+    assert(indices[dim] < vecType.getDimSize(dim) &&
+           "Indices are out of bound");
+    indices[dim] += step;
+    if (indices[dim] < vecType.getDimSize(dim))
+      break;
+
+    indices[dim] = 0;
+    step = 1;
+  }
+}
+
 namespace {
-/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
-/// vectors progressively on the way to target llvm.matrix intrinsics.
-/// This iterates over the most major dimension of the 2-D vector and performs
-/// rewrites into:
-///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
-class ShapeCastOp2DDownCastRewritePattern
+/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening n-D to 1-D
+/// vectors progressively. This iterates over the n-1 major dimensions of the
+/// n-D vector and performs rewrites into:
+///   vector.extract from n-D + vector.insert_strided_slice offset into 1-D
+class ShapeCastOpNDDownCastRewritePattern
     : public OpRewritePattern<vector::ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -53,35 +56,52 @@ class ShapeCastOp2DDownCastRewritePattern
                                 PatternRewriter &rewriter) const override {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
-
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
       return failure();
 
-    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
+    int64_t srcRank = sourceVectorType.getRank();
+    int64_t resRank = resultVectorType.getRank();
+    if (srcRank < 2 || resRank != 1)
      return failure();
 
+    // Compute the number of 1-D vector elements involved in the reshape.
+    int64_t numElts = 1;
+    for (int64_t dim = 0; dim < srcRank - 1; ++dim)
+      numElts *= sourceVectorType.getDimSize(dim);
+
     auto loc = op.getLoc();
-    Value desc = rewriter.create<arith::ConstantOp>(
+    SmallVector<int64_t> srcIdx(srcRank - 1, 0);
+    SmallVector<int64_t> resIdx(resRank, 0);
+    int64_t extractSize = sourceVectorType.getShape().back();
+    Value result = rewriter.create<arith::ConstantOp>(
         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
-    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
-    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
-      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
-      desc = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, vec, desc,
-          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
+
+    // Compute the indices of each 1-D vector element of the source extraction
+    // and destination slice insertion and generate such instructions.
+    for (int64_t i = 0; i < numElts; ++i) {
+      if (i != 0) {
+        incIdx(srcIdx, sourceVectorType, /*step=*/1);
+        incIdx(resIdx, resultVectorType, /*step=*/extractSize);
+      }
+
+      Value extract =
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, extract, result,
+          /*offsets=*/resIdx, /*strides=*/1);
     }
-    rewriter.replaceOp(op, desc);
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
 
-/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
-/// vectors progressively.
-/// This iterates over the most major dimension of the 2-D vector and performs
-/// rewrites into:
-///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
+/// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
+/// vectors progressively. This iterates over the n-1 major dimension of the n-D
+/// vector and performs rewrites into:
+///   vector.extract_strided_slice from 1-D + vector.insert into n-D
 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
-class ShapeCastOp2DUpCastRewritePattern
+class ShapeCastOpNDUpCastRewritePattern
     : public OpRewritePattern<vector::ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -90,43 +110,43 @@ class ShapeCastOp2DUpCastRewritePattern
                                 PatternRewriter &rewriter) const override {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
-
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();
 
-    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
+    int64_t srcRank = sourceVectorType.getRank();
+    int64_t resRank = resultVectorType.getRank();
+    if (srcRank != 1 || resRank < 2)
      return failure();
 
+    // Compute the number of 1-D vector elements involved in the reshape.
+    int64_t numElts = 1;
+    for (int64_t dim = 0; dim < resRank - 1; ++dim)
+      numElts *= resultVectorType.getDimSize(dim);
+
+    // Compute the indices of each 1-D vector element of the source slice
+    // extraction and destination insertion and generate such instructions.
     auto loc = op.getLoc();
-    Value desc = rewriter.create<arith::ConstantOp>(
+    SmallVector<int64_t> srcIdx(srcRank, 0);
+    SmallVector<int64_t> resIdx(resRank - 1, 0);
+    int64_t extractSize = resultVectorType.getShape().back();
+    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
-    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
-    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
-      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
-          /*sizes=*/mostMinorVectorSize,
+    for (int64_t i = 0; i < numElts; ++i) {
+      if (i != 0) {
+        incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
+        incIdx(resIdx, resultVectorType, /*step=*/1);
+      }
+
+      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize,
           /*strides=*/1);
-      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
+      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
     }
-    rewriter.replaceOp(op, desc);
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
 
-static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
-                   int dimIdx, int initialStep = 1) {
-  int step = initialStep;
-  for (int d = dimIdx; d >= 0; d--) {
-    idx[d] += step;
-    if (idx[d] >= tp.getDimSize(d)) {
-      idx[d] = 0;
-      step = 1;
-    } else {
-      break;
-    }
-  }
-}
-
 // We typically should not lower general shape cast operations into data
 // movement instructions, since the assumption is that these casts are
 // optimized away during progressive lowering. For completeness, however,
@@ -145,18 +165,14 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();
 
-    // Special case 2D / 1D lowerings with better implementations.
-    // TODO: make is ND / 1D to allow generic ND -> 1D -> MD.
+    // Special case for n-D / 1-D lowerings with better implementations.
     int64_t srcRank = sourceVectorType.getRank();
     int64_t resRank = resultVectorType.getRank();
-    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
+    if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
      return failure();
 
     // Generic ShapeCast lowering path goes all the way down to unrolled scalar
     // extract/insert chains.
-    // TODO: consider evolving the semantics to only allow 1D source or dest and
-    // drop this potentially very expensive lowering.
-    // Compute number of elements involved in the reshape.
     int64_t numElts = 1;
     for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
@@ -166,14 +182,14 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
     //   x[0,1,0] = y[0,2]
     // etc., incrementing the two index vectors "row-major"
     // within the source and result shape.
-    SmallVector<int64_t> srcIdx(srcRank);
-    SmallVector<int64_t> resIdx(resRank);
+    SmallVector<int64_t> srcIdx(srcRank, 0);
+    SmallVector<int64_t> resIdx(resRank, 0);
     Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
     for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
-        incIdx(srcIdx, sourceVectorType, srcRank - 1);
-        incIdx(resIdx, resultVectorType, resRank - 1);
+        incIdx(srcIdx, sourceVectorType);
+        incIdx(resIdx, resultVectorType);
      }
 
      Value extract;
@@ -252,7 +268,7 @@ class ScalableShapeCastOpRewritePattern
     // have a single trailing scalable dimension. This is because there are no
     // legal representation of other scalable types in LLVM (and likely won't be
     // soon). There are also (currently) no operations that can index or extract
-    // from >= 2D scalable vectors or scalable vectors of fixed vectors.
+    // from >= 2-D scalable vectors or scalable vectors of fixed vectors.
     if (!isTrailingDimScalable(sourceVectorType) ||
        !isTrailingDimScalable(resultVectorType)) {
      return failure();
@@ -278,8 +294,8 @@ class ScalableShapeCastOpRewritePattern
     Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
 
-    SmallVector<int64_t> srcIdx(srcRank);
-    SmallVector<int64_t> resIdx(resRank);
+    SmallVector<int64_t> srcIdx(srcRank, 0);
+    SmallVector<int64_t> resIdx(resRank, 0);
 
     // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
     // once D150000 lands.
@@ -334,8 +350,8 @@ class ScalableShapeCastOpRewritePattern
 
      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimensions.
-      incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
-      incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
+      incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
+      incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
     }
 
     rewriter.replaceOp(op, result);
@@ -352,8 +368,8 @@ class ScalableShapeCastOpRewritePattern
 
 void mlir::vector::populateVectorShapeCastLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<ShapeCastOp2DDownCastRewritePattern,
-               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
+  patterns.add<ShapeCastOpNDDownCastRewritePattern,
+               ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
                ScalableShapeCastOpRewritePattern>(patterns.getContext(),
                                                   benefit);
 }
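Note that casts where neither the source nor the result is 1-D still fall through to the generic ShapeCastOpRewritePattern, which unrolls to scalar extract/insert chains driven by the simplified incIdx helper. As a hedged illustration (hand-written, matching the shape of the existing @shape_cast_2d2d test), a vector<3x2xf32> to vector<2x3xf32> cast lowers roughly to:

  %cst = arith.constant dense<0.000000e+00> : vector<2x3xf32>
  // incIdx steps srcIdx row-major over 3x2 ((0,0), (0,1), (1,0), ...) and
  // resIdx row-major over 2x3 ((0,0), (0,1), (0,2), (1,0), ...).
  %e0 = vector.extract %arg0[0, 0] : f32 from vector<3x2xf32>
  %r0 = vector.insert %e0, %cst [0, 0] : f32 into vector<2x3xf32>
  %e1 = vector.extract %arg0[0, 1] : f32 from vector<3x2xf32>
  %r1 = vector.insert %e1, %r0 [0, 1] : f32 into vector<2x3xf32>
  // ... four more scalar extract/insert pairs for the remaining elements ...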

mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir

Lines changed: 18 additions & 27 deletions
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
 
 // CHECK-LABEL: func @nop_shape_cast
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>
@@ -82,19 +82,16 @@ func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
 // CHECK-LABEL: func @shape_cast_3d1d
 // CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
 // CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32>
-// CHECK: return %[[T11]] : vector<6xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[C]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]]
+// CHECK-SAME: {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: return %[[T5]] : vector<6xf32>
 
 func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
   %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
@@ -104,19 +101,13 @@ func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
 // CHECK-LABEL: func @shape_cast_1d3d
 // CHECK-SAME: %[[A:.*]]: vector<6xf32>
 // CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<6xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : f32 from vector<6xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : f32 from vector<6xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : f32 from vector<6xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : f32 from vector<6xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : f32 from vector<6xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32>
-// CHECK: return %[[T11]] : vector<2x1x3xf32>
+// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : vector<3xf32> into vector<2x1x3xf32>
+// CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK: {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : vector<3xf32> into vector<2x1x3xf32>
+// CHECK: return %[[T3]] : vector<2x1x3xf32>
 
 func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
   %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
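For reference, the tests in this file are driven through the transform interpreter. A minimal driver sequence along these lines (a sketch assuming the vector dialect's transform op transform.apply_patterns.vector.lower_shape_cast, which registers the patterns from populateVectorShapeCastLoweringPatterns) exercises the lowering:

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
      %f = transform.structured.match ops{["func.func"]} in %module_op
        : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %f {
        transform.apply_patterns.vector.lower_shape_cast
      } : !transform.any_op
      transform.yield
    }
  }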
