Skip to content

Commit e154b22

Browse files
committed
generalize to extract cast
1 parent c584ac8 commit e154b22

File tree

2 files changed

+149
-89
lines changed

2 files changed

+149
-89
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 55 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2387,10 +2387,8 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
23872387
return success();
23882388
}
23892389

2390-
/// Rewrite vector.from_elements(vector.extract, vector.extract, ...) as
2391-
/// vector.shape_cast(vector.extact) if possible.
2392-
///
2393-
/// Example:
2390+
/// Rewrite from_elements on multiple scalar extracts as a shape_cast
2391+
/// on a single extract. Example:
23942392
/// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
23952393
/// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
23962394
/// %2 = vector.from_elements %0, %1 : vector<2xi8>
@@ -2401,30 +2399,32 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
24012399
///
24022400
/// The requirements for this to be valid are
24032401
///
2404-
/// i) all elements are extracted from the same vector (%source)
2405-
/// ii) the elements form a suffix of %source
2406-
/// iii) the elements are extracted contiguously in ascending order
2402+
/// i) The elements are extracted from the same vector (%source).
2403+
///
2404+
/// ii) The elements form a suffix of %source. Specifically, the number
2405+
/// of elements is the same as the product of the last N dimension sizes
2406+
/// of %source, for some N.
2407+
///
2408+
/// iii) The elements are extracted contiguously in ascending order.
2409+
2410+
class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
24072411

2408-
class FromElementsToShapeCast
2409-
: public OpRewritePattern<FromElementsOp> {
2410-
public:
24112412
using OpRewritePattern::OpRewritePattern;
24122413

24132414
LogicalResult matchAndRewrite(FromElementsOp fromElements,
24142415
PatternRewriter &rewriter) const override {
24152416

2416-
// Left for `rewriteFromElementsAsSplat` to avoid divergent
2417-
// canonicalizations
2418-
if (fromElements.getType().getNumElements() == 1) {
2417+
// Handled by `rewriteFromElementsAsSplat`
2418+
if (fromElements.getType().getNumElements() == 1)
24192419
return failure();
2420-
}
24212420

2422-
// The source of the first element, the position (N-d vector) that the first
2423-
// element is extracted from, and the flattened position (index). These are
2424-
// all obtained in the first iteration of the loop over elements.
2425-
TypedValue<VectorType> firstElementSource;
2426-
ArrayRef<int64_t> firstElementExtractPosition;
2427-
int64_t firstElementExtractIndex;
2421+
// The common source that all elements are extracted from, if one exists.
2422+
TypedValue<VectorType> source;
2423+
// The position of the combined extract operation, if one is created.
2424+
ArrayRef<int64_t> combinedPosition;
2425+
// The expected n-D position of extraction of the current element in the loop, if
2426+
// elements are extracted contiguously in ascending order.
2427+
SmallVector<int64_t> expectedPosition;
24282428

24292429
for (auto [insertIndex, element] :
24302430
llvm::enumerate(fromElements.getElements())) {
@@ -2440,77 +2440,70 @@ class FromElementsToShapeCast
24402440
// Check condition (i) by checking that all elements have the same source as
24412441
// the first element.
24422442
if (insertIndex == 0) {
2443-
firstElementSource = extractOp.getVector();
2444-
} else if (extractOp.getVector() != firstElementSource) {
2443+
source = extractOp.getVector();
2444+
} else if (extractOp.getVector() != source) {
24452445
return rewriter.notifyMatchFailure(fromElements,
24462446
"element from different vector");
24472447
}
24482448

2449-
// Obtain the flattened index of extraction from the N-d position.
24502449
ArrayRef<int64_t> position = extractOp.getStaticPosition();
2451-
int64_t extractIndex{0};
2452-
int64_t stride{1};
2453-
assert(position.size() ==
2454-
static_cast<size_t>(firstElementSource.getType().getRank()) &&
2450+
int64_t rank = position.size();
2451+
assert(rank == source.getType().getRank() &&
24552452
"scalar extract must have full rank position");
2456-
for (auto [pos, size] :
2457-
llvm::zip(llvm::reverse(position),
2458-
llvm::reverse(firstElementSource.getType().getShape()))) {
2459-
if (pos == ShapedType::kDynamic) {
2460-
return rewriter.notifyMatchFailure(
2461-
fromElements, "elements not in ascending order (dynamic order)");
2462-
}
2463-
extractIndex += pos * stride;
2464-
stride *= size;
2465-
}
24662453

2467-
// Check condition (ii) using the extraction index of the first element.
2468-
// We check that the position that the first element is extracted
2469-
// from has sufficient trailing 0s. For example, in
2470-
// ```
2471-
// %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
2472-
// [...]
2473-
// %n = vector.from_elements %elm0, [...] : vector<12xi8>
2474-
// ```
2475-
// The 2 trailing 0s in the position of extraction of %0 cover 3*4 = 12
2454+
// Check condition (ii) by checking that the position that the first
2455+
// element is extracted from has sufficient trailing 0s. For example, in
2456+
//
2457+
// %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
2458+
// [...]
2459+
// %elms = vector.from_elements %elm0, [...] : vector<12xi8>
2460+
//
2461+
// The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
24762462
// elements, which is the number of elements of %elms, so this is valid.
24772463
if (insertIndex == 0) {
2478-
const int64_t numFinalElements =
2479-
fromElements.getType().getNumElements();
2480-
int64_t numElementsInSourceSuffix = 1;
2481-
int index = position.size();
2464+
const int64_t numElms = fromElements.getType().getNumElements();
2465+
int64_t numSuffixElms = 1;
2466+
int64_t index = rank;
24822467
while (index > 0 && position[index - 1] == 0 &&
2483-
numElementsInSourceSuffix < numFinalElements) {
2484-
numElementsInSourceSuffix *=
2485-
firstElementSource.getType().getDimSize(index - 1);
2468+
numSuffixElms < numElms) {
2469+
numSuffixElms *= source.getType().getDimSize(index - 1);
24862470
--index;
24872471
}
2488-
if (numElementsInSourceSuffix != numFinalElements) {
2472+
if (numSuffixElms != numElms) {
24892473
return rewriter.notifyMatchFailure(
24902474
fromElements, "elements do not form a suffix of source");
24912475
}
2492-
firstElementExtractIndex = extractIndex;
2493-
firstElementExtractPosition =
2494-
position.drop_back(position.size() - index);
2476+
expectedPosition = llvm::to_vector(position);
2477+
combinedPosition = position.drop_back(rank - index);
24952478
}
24962479

2497-
// Check condition (iii) by checking the index of extraction relative
2498-
// the first element.
2499-
else if (static_cast<int64_t>(insertIndex) + firstElementExtractIndex !=
2500-
extractIndex) {
2480+
// Check condition (iii).
2481+
else if (expectedPosition != position) {
25012482
return rewriter.notifyMatchFailure(
25022483
fromElements, "elements not in ascending order (static order)");
25032484
}
2485+
increment(expectedPosition, source.getType().getShape());
25042486
}
25052487

25062488
auto extracted = rewriter.createOrFold<vector::ExtractOp>(
2507-
fromElements.getLoc(), firstElementSource, firstElementExtractPosition);
2489+
fromElements.getLoc(), source, combinedPosition);
25082490

25092491
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
25102492
fromElements, fromElements.getType(), extracted);
25112493

25122494
return success();
25132495
}
2496+
2497+
/// Increments n-D `indices` by 1 starting from the innermost dimension.
2498+
static void increment(MutableArrayRef<int64_t> indices,
2499+
ArrayRef<int64_t> shape) {
2500+
for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
2501+
indices[dim] += 1;
2502+
if (indices[dim] < shape[dim])
2503+
break;
2504+
indices[dim] = 0;
2505+
}
2506+
}
25142507
};
25152508

25162509
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,

mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir

Lines changed: 94 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt --mlir-disable-threading %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
1+
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
22

33
// This file contains some tests of folding/canonicalizing vector.from_elements
44

@@ -7,7 +7,7 @@
77
///===----------------------------------------------===//
88

99
// CHECK-LABEL: func @extract_scalar_from_from_elements(
10-
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
10+
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
1111
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
1212
// Extract from 0D.
1313
%0 = vector.from_elements %a : vector<f32>
@@ -33,7 +33,7 @@ func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32
3333
// -----
3434

3535
// CHECK-LABEL: func @extract_1d_from_from_elements(
36-
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
36+
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
3737
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
3838
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
3939
// CHECK: %[[SPLAT1:.*]] = vector.splat %[[A]] : vector<3xf32>
@@ -47,7 +47,7 @@ func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, ve
4747
// -----
4848

4949
// CHECK-LABEL: func @extract_2d_from_from_elements(
50-
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
50+
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
5151
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
5252
%0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
5353
// CHECK: %[[SPLAT1:.*]] = vector.from_elements %[[A]], %[[A]], %[[A]], %[[B]] : vector<2x2xf32>
@@ -61,7 +61,7 @@ func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>,
6161
// -----
6262

6363
// CHECK-LABEL: func @from_elements_to_splat(
64-
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
64+
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
6565
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
6666
// CHECK: %[[SPLAT:.*]] = vector.splat %[[A]] : vector<2x3xf32>
6767
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
@@ -81,8 +81,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
8181

8282
// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
8383
// CHECK-SAME: %[[A:.*]]: vector<1x2xi8>)
84-
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
85-
// CHECK: return %[[SHAPE_CAST]] : vector<2xi8>
84+
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
85+
// CHECK: return %[[EXTRACT]] : vector<2xi8>
8686
func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
8787
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
8888
%1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
@@ -109,20 +109,13 @@ func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8>
109109
return %8 : vector<2x2x2xi8>
110110
}
111111

112-
113112
// -----
114113

115-
// func.func @bar(%arg0: vector<2x3x4xi8>) -> vector<12xi8> {
116-
// %0 = vector.extract %arg0[1] : vector<3x4xi8> from vector<2x3x4xi8>
117-
// %1 = vector.shape_cast %0 : vector<3x4xi8> to vector<12xi8>
118-
// return %1 : vector<12xi8>
119-
120114
// CHECK-LABEL: func @source_larger_than_out(
121-
// CHECK-SAME: %[[A:.*]]: vector<2x3x4xi8>)
122-
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]] [1] : vector<3x4xi8> from vector<2x3x4xi8>
123-
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<3x4xi8> to vector<12xi8>
124-
// CHECK: return %[[SHAPE_CAST]] : vector<12xi8>
125-
115+
// CHECK-SAME: %[[A:.*]]: vector<2x3x4xi8>)
116+
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][1] : vector<3x4xi8> from vector<2x3x4xi8>
117+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<3x4xi8> to vector<12xi8>
118+
// CHECK: return %[[SHAPE_CAST]] : vector<12xi8>
126119
func.func @source_larger_than_out(%arg0: vector<2x3x4xi8>) -> vector<12xi8> {
127120
%0 = vector.extract %arg0[1, 0, 0] : i8 from vector<2x3x4xi8>
128121
%1 = vector.extract %arg0[1, 0, 1] : i8 from vector<2x3x4xi8>
@@ -140,13 +133,70 @@ func.func @source_larger_than_out(%arg0: vector<2x3x4xi8>) -> vector<12xi8> {
140133
return %12 : vector<12xi8>
141134
}
142135

143-
// TODO(newling) add more tests where the source is not the same size as out.
136+
// -----
137+
138+
// This test is similar to `source_larger_than_out` except here the number of elements
139+
// extracted contiguously starting from the first position [0,0] could be 6 instead of 3
140+
// and the pattern would still match.
141+
// CHECK-LABEL: func @suffix_with_excess_zeros(
142+
// CHECK: %[[EXT:.*]] = vector.extract {{.*}}[0] : vector<3xi8> from vector<2x3xi8>
143+
// CHECK: return %[[EXT]] : vector<3xi8>
144+
func.func @suffix_with_excess_zeros(%arg0: vector<2x3xi8>) -> vector<3xi8> {
145+
%0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8>
146+
%1 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8>
147+
%2 = vector.extract %arg0[0, 2] : i8 from vector<2x3xi8>
148+
%3 = vector.from_elements %0, %1, %2 : vector<3xi8>
149+
return %3 : vector<3xi8>
150+
}
151+
152+
// -----
153+
154+
// CHECK-LABEL: func @large_source_with_shape_cast_required(
155+
// CHECK-SAME: %[[A:.*]]: vector<2x2x2x2xi8>)
156+
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0, 1] : vector<2x2xi8> from vector<2x2x2x2xi8>
157+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<2x2xi8> to vector<1x4x1xi8>
158+
// CHECK: return %[[SHAPE_CAST]] : vector<1x4x1xi8>
159+
func.func @large_source_with_shape_cast_required(%arg0: vector<2x2x2x2xi8>) -> vector<1x4x1xi8> {
160+
%0 = vector.extract %arg0[0, 1, 0, 0] : i8 from vector<2x2x2x2xi8>
161+
%1 = vector.extract %arg0[0, 1, 0, 1] : i8 from vector<2x2x2x2xi8>
162+
%2 = vector.extract %arg0[0, 1, 1, 0] : i8 from vector<2x2x2x2xi8>
163+
%3 = vector.extract %arg0[0, 1, 1, 1] : i8 from vector<2x2x2x2xi8>
164+
%4 = vector.from_elements %0, %1, %2, %3 : vector<1x4x1xi8>
165+
return %4 : vector<1x4x1xi8>
166+
}
167+
168+
// -----
169+
170+
// Could match, but handled by `rewriteFromElementsAsSplat`.
171+
// CHECK-LABEL: func @extract_single_elm(
172+
// CHECK-NEXT: vector.extract
173+
// CHECK-NEXT: vector.splat
174+
// CHECK-NEXT: return
175+
func.func @extract_single_elm(%arg0 : vector<2x3xi8>) -> vector<1xi8> {
176+
%0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8>
177+
%1 = vector.from_elements %0 : vector<1xi8>
178+
return %1 : vector<1xi8>
179+
}
180+
181+
// -----
182+
183+
// CHECK-LABEL: func @negative_source_contiguous_but_not_suffix(
184+
// CHECK-NOT: shape_cast
185+
// CHECK: from_elements
186+
func.func @negative_source_contiguous_but_not_suffix(%arg0: vector<2x3xi8>) -> vector<3xi8> {
187+
%0 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8>
188+
%1 = vector.extract %arg0[0, 2] : i8 from vector<2x3xi8>
189+
%2 = vector.extract %arg0[1, 0] : i8 from vector<2x3xi8>
190+
%3 = vector.from_elements %0, %1, %2 : vector<3xi8>
191+
return %3 : vector<3xi8>
192+
}
144193

145194
// -----
146195

147196
// The extracted elements are recombined into a single vector, but in a new order.
148197
// CHECK-LABEL: func @negative_nonascending_order(
149-
// CHECK-NOT: shape_cast
198+
// CHECK-NOT: shape_cast
199+
// CHECK: from_elements
150200
func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
151201
%0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
152202
%1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
@@ -157,7 +207,8 @@ func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
157207
// -----
158208

159209
// CHECK-LABEL: func @negative_nonstatic_extract(
160-
// CHECK-NOT: shape_cast
210+
// CHECK-NOT: shape_cast
211+
// CHECK: from_elements
161212
func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> {
162213
%0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8>
163214
%1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8>
@@ -168,7 +219,8 @@ func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 :
168219
// -----
169220

170221
// CHECK-LABEL: func @negative_different_sources(
171-
// CHECK-NOT: shape_cast
222+
// CHECK-NOT: shape_cast
223+
// CHECK: from_elements
172224
func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> {
173225
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
174226
%1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8>
@@ -178,9 +230,10 @@ func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi
178230

179231
// -----
180232

181-
// CHECK-LABEL: func @negative_source_too_large(
182-
// CHECK-NOT: shape_cast
183-
func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> {
233+
// CHECK-LABEL: func @negative_source_not_suffix(
234+
// CHECK-NOT: shape_cast
235+
// CHECK: from_elements
236+
func.func @negative_source_not_suffix(%arg0: vector<1x3xi8>) -> vector<2xi8> {
184237
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8>
185238
%1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
186239
%2 = vector.from_elements %0, %1 : vector<2xi8>
@@ -189,13 +242,27 @@ func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> {
189242

190243
// -----
191244

192-
// The inserted elements are are a subset of the extracted elements.
245+
// The inserted elements are a subset of the extracted elements.
193246
// [0, 1, 2] -> [1, 1, 2]
194247
// CHECK-LABEL: func @negative_nobijection_order(
195-
// CHECK-NOT: shape_cast
248+
// CHECK-NOT: shape_cast
249+
// CHECK: from_elements
196250
func.func @negative_nobijection_order(%arg0: vector<1x3xi8>) -> vector<3xi8> {
197251
%0 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
198252
%1 = vector.extract %arg0[0, 2] : i8 from vector<1x3xi8>
199253
%2 = vector.from_elements %0, %0, %1 : vector<3xi8>
200254
return %2 : vector<3xi8>
201255
}
256+
257+
// -----
258+
259+
// CHECK-LABEL: func @negative_source_too_small(
260+
// CHECK-NOT: shape_cast
261+
// CHECK: from_elements
262+
func.func @negative_source_too_small(%arg0: vector<2xi8>) -> vector<4xi8> {
263+
%0 = vector.extract %arg0[0] : i8 from vector<2xi8>
264+
%1 = vector.extract %arg0[1] : i8 from vector<2xi8>
265+
%2 = vector.from_elements %0, %1, %1, %1 : vector<4xi8>
266+
return %2 : vector<4xi8>
267+
}
268+

0 commit comments

Comments
 (0)