Commit 8ab3695

Move already supported vector ops to VectorLinearize (#712)
1 parent 57e245e commit 8ab3695

File tree: 2 files changed, +272 −0 lines

lib/Transforms/VectorLinearize.cpp

Lines changed: 201 additions & 0 deletions
@@ -12,18 +12,211 @@
 ///
 //===----------------------------------------------------------------------===//

+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Transforms/Utils/AddDiscriminators.h"
+#include <cstdint>
 #include <imex/Transforms/Passes.h>

 #include <mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h>
 #include <mlir/Pass/Pass.h>
 #include <mlir/Transforms/DialectConversion.h>
+#include <numeric>

 namespace imex {
 #define GEN_PASS_DEF_VECTORLINEARIZE
 #include "imex/Transforms/Passes.h.inc"
 } // namespace imex

 namespace {
+
+struct VectorExtractStridedSliceConversion final
+    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
+  using mlir::OpConversionPattern<
+      mlir::vector::ExtractStridedSliceOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::vector::ExtractStridedSliceOp extractOp,
+                  OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto dstType = getTypeConverter()->convertType(extractOp.getType());
+    auto loc = extractOp.getLoc();
+    if (!dstType)
+      return rewriter.notifyMatchFailure(loc, "cannot convert type.");
+
+    if (extractOp.getVector().getType().isScalable() ||
+        dstType.cast<mlir::VectorType>().isScalable())
+      return rewriter.notifyMatchFailure(loc,
+                                         "scalable vectors are not supported.");
+
+    auto offsets = extractOp.getOffsets().getValue();
+    auto sizes = extractOp.getSizes().getValue();
+    auto strides = extractOp.getStrides().getValue();
+
+    if (!mlir::isConstantIntValue(strides[0], 1))
+      return rewriter.notifyMatchFailure(
+          extractOp, "Strided slice with stride != 1 is not supported.");
+
+    mlir::Value srcVector = adaptor.getVector();
+
+    // If kD offsets are specified for an nD source vector (n > k), the
+    // granularity of the extraction is greater than 1. In this case the last
+    // (n - k) dimensions form the extraction granularity. Example:
+    //   %0 = vector.extract_strided_slice %src { offsets = [0, 0],
+    //        sizes = [2, 2], strides = [1, 1] }
+    //        : vector<4x8x8xf32> to vector<2x2x8xf32>
+    // Here the extraction granularity is 8.
+    int64_t extractSliceLen = 1;
+    auto n = extractOp.getSourceVectorType().getRank();
+    auto k = (int64_t)offsets.size();
+    if (n > k) {
+      for (unsigned i = 0; i < n - k; i++) {
+        extractSliceLen *= extractOp.getSourceVectorType().getShape()[i + k];
+      }
+    }
+
+    // Get the total number of extracted slices.
+    int64_t nExtractedSlices = 1;
+    for (auto size : sizes) {
+      nExtractedSlices *= size.cast<mlir::IntegerAttr>().getInt();
+    }
+
+    // Compute the strides of the source vector considering the first k dims.
+    llvm::SmallVector<int64_t, 4> sourceStrides(k, extractSliceLen);
+    for (int i = k - 2; i >= 0; --i) {
+      sourceStrides[i] = sourceStrides[i + 1] *
+                         extractOp.getSourceVectorType().getShape()[i + 1];
+    }
+    // The final shuffle mask has nExtractedSlices * extractSliceLen elements.
+    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices * extractSliceLen);
+    // Compute the strides of the extracted kD vector.
+    llvm::SmallVector<int64_t, 4> extractedStrides(k, 1);
+    for (int i = k - 2; i >= 0; --i) {
+      extractedStrides[i] = extractedStrides[i + 1] *
+                            sizes[i + 1].cast<mlir::IntegerAttr>().getInt();
+    }
+    // Iterate over all extracted slices from 0 to nExtractedSlices - 1 and
+    // compute the multi-dimensional index and the corresponding linearized
+    // index within the source vector.
+    for (int64_t i = 0; i < nExtractedSlices; ++i) {
+      int64_t index = i;
+      // Compute the corresponding multi-dimensional index.
+      llvm::SmallVector<int64_t, 4> multiDimIndex(k, 0);
+      for (int64_t j = 0; j < k; ++j) {
+        multiDimIndex[j] = (index / extractedStrides[j]);
+        index -= multiDimIndex[j] * extractedStrides[j];
+      }
+      // Compute the corresponding linearized index in the source vector,
+      // i.e. shift the multiDimIndex by the offsets.
+      int64_t linearizedIndex = 0;
+      for (int64_t j = 0; j < k; ++j) {
+        linearizedIndex +=
+            (offsets[j].cast<mlir::IntegerAttr>().getInt() + multiDimIndex[j]) *
+            sourceStrides[j];
+      }
+      // Fill the indices array from linearizedIndex to
+      // linearizedIndex + extractSliceLen.
+      for (int64_t j = 0; j < extractSliceLen; ++j) {
+        indices[i * extractSliceLen + j] = linearizedIndex + j;
+      }
+    }
+    // Perform a shuffle to extract the kD vector.
+    rewriter.replaceOpWithNewOp<mlir::vector::ShuffleOp>(
+        extractOp, dstType, srcVector, srcVector,
+        rewriter.getI64ArrayAttr(indices));
+
+    return mlir::success();
+  }
+};
+
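To make the index arithmetic above concrete, here is a minimal standalone C++ sketch (illustrative only, not part of the patch) that reproduces the shuffle-mask computation for the example in the comment: extracting a vector<2x2x8xf32> slice at offsets [0, 0] from a vector<4x8x8xf32> source. The hard-coded shapes and the main() driver are assumptions of the sketch; the pattern itself reads these values from the op's MLIR attributes.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Example from the comment above:
  //   offsets = [0, 0], sizes = [2, 2], strides = [1, 1],
  //   source vector<4x8x8xf32> -> result vector<2x2x8xf32>.
  std::vector<int64_t> srcShape = {4, 8, 8}; // n = 3
  std::vector<int64_t> offsets = {0, 0};     // k = 2
  std::vector<int64_t> sizes = {2, 2};
  int64_t n = srcShape.size(), k = offsets.size();

  // Extraction granularity: product of the trailing (n - k) dims => 8.
  int64_t extractSliceLen = 1;
  for (int64_t i = k; i < n; ++i)
    extractSliceLen *= srcShape[i];

  // Number of extracted slices: product of sizes => 4.
  int64_t nExtractedSlices = 1;
  for (int64_t s : sizes)
    nExtractedSlices *= s;

  // Strides (in elements) of the source vector over the first k dims => [64, 8].
  std::vector<int64_t> sourceStrides(k, extractSliceLen);
  for (int64_t i = k - 2; i >= 0; --i)
    sourceStrides[i] = sourceStrides[i + 1] * srcShape[i + 1];

  // Strides of the extracted kD shape => [2, 1].
  std::vector<int64_t> extractedStrides(k, 1);
  for (int64_t i = k - 2; i >= 0; --i)
    extractedStrides[i] = extractedStrides[i + 1] * sizes[i + 1];

  std::vector<int64_t> indices(nExtractedSlices * extractSliceLen);
  for (int64_t i = 0; i < nExtractedSlices; ++i) {
    // Delinearize i into a kD index, shift it by the offsets, relinearize.
    int64_t rem = i, linearizedIndex = 0;
    for (int64_t j = 0; j < k; ++j) {
      int64_t d = rem / extractedStrides[j];
      rem -= d * extractedStrides[j];
      linearizedIndex += (offsets[j] + d) * sourceStrides[j];
    }
    for (int64_t j = 0; j < extractSliceLen; ++j)
      indices[i * extractSliceLen + j] = linearizedIndex + j;
  }

  // Prints 0..7, 8..15, 64..71, 72..79 -- the mask of the emitted vector.shuffle.
  for (int64_t idx : indices)
    std::cout << idx << ' ';
  std::cout << '\n';
}
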
+struct VectorShffleOpConversion final
+    : public mlir::OpConversionPattern<mlir::vector::ShuffleOp> {
+  using mlir::OpConversionPattern<mlir::vector::ShuffleOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto dstType = getTypeConverter()->convertType(shuffleOp.getType());
+    auto loc = shuffleOp.getLoc();
+    if (!dstType)
+      return rewriter.notifyMatchFailure(loc, "cannot convert type.");
+
+    auto vec1 = adaptor.getV1();
+    auto vec2 = adaptor.getV2();
+
+    int shuffleSliceLen = 1;
+    int rank = shuffleOp.getV1().getType().getRank();
+
+    // If rank > 1, we need to do the shuffle at the granularity of slices
+    // instead of scalars. The slice size is the product of the rank - 1
+    // innermost dims. The mask of the shuffle op specifies which slice to
+    // take from the outermost dim.
+    if (rank > 1) {
+      auto shape = shuffleOp.getV1().getType().getShape();
+      for (unsigned i = 1; i < shape.size(); i++) {
+        shuffleSliceLen *= shape[i];
+      }
+    }
+
+    auto mask = shuffleOp.getMask();
+    auto totalSize = mask.size() * shuffleSliceLen;
+
+    llvm::SmallVector<int64_t, 2> indices(totalSize);
+    for (auto [i, value] :
+         llvm::enumerate(mask.getAsValueRange<mlir::IntegerAttr>())) {
+
+      int64_t v = value.getZExtValue();
+      std::iota(indices.begin() + shuffleSliceLen * i,
+                indices.begin() + shuffleSliceLen * (i + 1),
+                shuffleSliceLen * v);
+    }
+
+    rewriter.replaceOpWithNewOp<mlir::vector::ShuffleOp>(
+        shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
+
+    return mlir::success();
+  }
+};
+
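The mask expansion for rank > 1 can be checked with a short standalone sketch (illustrative only, not part of the patch). It uses the mask from the test_vector_shuffle test added below, where the operands are vector<4x4xf32>, so the inner slice length is 4 and each mask entry v expands to the run 4*v .. 4*v+3.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7] on vector<4x4xf32> operands.
  std::vector<int64_t> mask = {0, 4, 1, 5, 2, 6, 3, 7};
  int64_t shuffleSliceLen = 4; // product of the non-leading dims of vector<4x4xf32>

  std::vector<int64_t> indices(mask.size() * shuffleSliceLen);
  for (int64_t i = 0; i < (int64_t)mask.size(); ++i)
    std::iota(indices.begin() + shuffleSliceLen * i,
              indices.begin() + shuffleSliceLen * (i + 1),
              shuffleSliceLen * mask[i]);

  // Prints the linearized mask applied to the flattened vector<16xf32> operands:
  // 0 1 2 3 16 17 18 19 4 5 6 7 20 21 22 23 8 9 10 11 24 25 26 27 12 13 14 15 28 29 30 31
  for (int64_t idx : indices)
    std::cout << idx << ' ';
  std::cout << '\n';
}
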
+struct VectorExtractOpConversion final
+    : public mlir::OpConversionPattern<mlir::vector::ExtractOp> {
+  using OpConversionPattern::OpConversionPattern;
+  mlir::LogicalResult
+  matchAndRewrite(mlir::vector::ExtractOp extractOp, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto dstTy = getTypeConverter()->convertType(extractOp.getType());
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(extractOp, "cannot convert type.");
+
+    // Dynamic position is not supported.
+    if (extractOp.hasDynamicPosition())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "dynamic position is not supported.");
+
+    auto shape = extractOp.getVector().getType().getShape();
+    auto size = extractOp.getVector().getType().getNumElements();
+
+    // Compute the linearized offset.
+    int64_t linearizedOffset = 0;
+    auto offsets = extractOp.getStaticPosition();
+    for (auto [i, off] : llvm::enumerate(offsets)) {
+      size /= shape[i];
+      linearizedOffset += offsets[i] * size;
+    }
+
+    llvm::SmallVector<int64_t, 2> indices(size);
+    std::iota(indices.begin(), indices.end(), linearizedOffset);
+    rewriter.replaceOpWithNewOp<mlir::vector::ShuffleOp>(
+        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
+        rewriter.getI64ArrayAttr(indices));
+
+    return mlir::success();
+  }
+};
+
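The offset computation here is the usual row-major linearization. A tiny standalone sketch (illustrative only, not part of the patch) for the test_vector_extract case added below, vector.extract %arg0[1] from vector<2x8x4xf32>, shows that position [1] selects the 32 contiguous elements starting at linear index 32.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // vector.extract %arg0[1] : vector<8x4xf32> from vector<2x8x4xf32>
  std::vector<int64_t> shape = {2, 8, 4};
  std::vector<int64_t> position = {1};

  int64_t size = 2 * 8 * 4; // total number of elements in the source vector
  int64_t linearizedOffset = 0;
  for (size_t i = 0; i < position.size(); ++i) {
    size /= shape[i];                      // elements spanned by one step in dim i
    linearizedOffset += position[i] * size;
  }

  // Prints "offset = 32, length = 32": the emitted shuffle mask is 32..63.
  std::cout << "offset = " << linearizedOffset << ", length = " << size << '\n';
}
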
 struct VectorLinearizePass final
     : public imex::impl::VectorLinearizeBase<VectorLinearizePass> {

@@ -34,6 +227,14 @@ struct VectorLinearizePass final
     mlir::RewritePatternSet patterns(context);
     mlir::ConversionTarget target(*context);

+    target.addDynamicallyLegalOp<mlir::vector::ShuffleOp>([&](mlir::Operation
+                                                                   *op) {
+      return op->getResult(0).getType().cast<mlir::VectorType>().getRank() == 1;
+    });
+
+    patterns.add<VectorExtractStridedSliceConversion, VectorShffleOpConversion,
+                 VectorExtractOpConversion>(typeConverter, context);
+
     typeConverter.addConversion([](mlir::Type type) { return type; });
     mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
         typeConverter, patterns, target);

test/Transforms/vector-linearize.mlir

Lines changed: 71 additions & 0 deletions
@@ -27,3 +27,74 @@ func.func @test_const_novector() -> i32 {
   %0 = arith.constant 42 : i32
   return %0 : i32
 }
+
+// -----
+// CHECK-LABEL: test_extract_strided_slice
+// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<8x16xf32>) -> vector<8x8xf32>
+// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<8x16xf32> to vector<128xf32>
+// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+// CHECK: [8, 9, 10, 11, 12, 13, 14, 15,
+// CHECK: 24, 25, 26, 27, 28, 29, 30, 31,
+// CHECK: 40, 41, 42, 43, 44, 45, 46, 47,
+// CHECK: 56, 57, 58, 59, 60, 61, 62, 63,
+// CHECK: 72, 73, 74, 75, 76, 77, 78, 79,
+// CHECK: 88, 89, 90, 91, 92, 93, 94, 95,
+// CHECK: 104, 105, 106, 107, 108, 109, 110, 111,
+// CHECK: 120, 121, 122, 123, 124, 125, 126, 127] : vector<128xf32>, vector<128xf32>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<8x8xf32>
+// CHECK: return %[[RES]] : vector<8x8xf32>
+func.func @test_extract_strided_slice_1(%arg0 : vector<8x16xf32>) -> vector<8x8xf32> {
+  %0 = vector.extract_strided_slice %arg0 { sizes = [8, 8], strides = [1, 1], offsets = [0, 8]}
+    : vector<8x16xf32> to vector<8x8xf32>
+  return %0 : vector<8x8xf32>
+}
+
+// -----
+// CHECK-LABEL: test_extract_strided_slice_2
+// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x32x8xf32>) -> vector<1x8x8xf32>
+// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x32x8xf32> to vector<512xf32>
+// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+// CHECK: [448, 449, 450, 451, 452, 453, 454, 455,
+// CHECK: 456, 457, 458, 459, 460, 461, 462, 463,
+// CHECK: 464, 465, 466, 467, 468, 469, 470, 471,
+// CHECK: 472, 473, 474, 475, 476, 477, 478, 479,
+// CHECK: 480, 481, 482, 483, 484, 485, 486, 487,
+// CHECK: 488, 489, 490, 491, 492, 493, 494, 495,
+// CHECK: 496, 497, 498, 499, 500, 501, 502, 503,
+// CHECK: 504, 505, 506, 507, 508, 509, 510, 511] : vector<512xf32>, vector<512xf32>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<1x8x8xf32>
+// CHECK: return %[[RES]] : vector<1x8x8xf32>
+func.func @test_extract_strided_slice_2(%arg0 : vector<2x32x8xf32>) -> vector<1x8x8xf32> {
+  %0 = vector.extract_strided_slice %arg0 { offsets = [1, 24], strides = [1, 1], sizes = [1, 8] }
+    : vector<2x32x8xf32> to vector<1x8x8xf32>
+  return %0 : vector<1x8x8xf32>
+}
+
+// -----
+// CHECK-LABEL: test_vector_shuffle
+// CHECK-SAME: (%[[ORIG_ARG1:.*]]: vector<4x4xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>) -> vector<8x4xf32> {
+// CHECK: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x4xf32> to vector<16xf32>
+// CHECK: %[[ARG2:.*]] = vector.shape_cast %[[ORIG_ARG2]] : vector<4x4xf32> to vector<16xf32>
+// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG1]], %[[ARG2]]
+// CHECK: [0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23,
+// CHECK: 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<32xf32> to vector<8x4xf32>
+// CHECK: return %[[RES]] : vector<8x4xf32>
+func.func @test_vector_shuffle(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>) -> vector<8x4xf32> {
+  %0 = vector.shuffle %arg0, %arg1 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x4xf32>, vector<4x4xf32>
+  return %0 : vector<8x4xf32>
+}
+
+// -----
+// CHECK-LABEL: test_vector_extract
+// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x4xf32>) -> vector<8x4xf32>
+// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x4xf32> to vector<64xf32>
+// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+// CHECK: [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+// CHECK: 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<64xf32>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<32xf32> to vector<8x4xf32>
+// CHECK: return %[[RES]] : vector<8x4xf32>
+func.func @test_vector_extract(%arg0: vector<2x8x4xf32>) -> vector<8x4xf32> {
+  %0 = vector.extract %arg0[1]: vector<8x4xf32> from vector<2x8x4xf32>
+  return %0 : vector<8x4xf32>
+}
