Skip to content

Commit d82a741

Browse files
author
Amy Zhuang
committed
[mlir][vector] Support index type in ND to 1D vector linearization
Currently index type is not supported because getElementTypeBitWidth aborts for index type. This patch adds indexBitWidth input to the vector linearization patterns.
1 parent c9fa319 commit d82a741

File tree

4 files changed

+79
-33
lines changed

4 files changed

+79
-33
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,13 +399,13 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
399399
/// the ops to get converted properly.
400400
void populateVectorLinearizeTypeConversionsAndLegality(
401401
TypeConverter &typeConverter, RewritePatternSet &patterns,
402-
ConversionTarget &target, unsigned targetBitWidth);
402+
ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
403403

404404
/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
405405
/// vector shuffle operations.
406406
void populateVectorLinearizeShuffleLikeOpsPatterns(
407407
const TypeConverter &typeConverter, RewritePatternSet &patterns,
408-
ConversionTarget &target, unsigned targetBitWidth);
408+
ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
409409

410410
} // namespace vector
411411
} // namespace mlir

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,34 +25,44 @@
2525

2626
using namespace mlir;
2727

28-
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
28+
static bool isLessThanTargetBitWidth(Operation *op, unsigned indexBitWidth,
29+
unsigned targetBitWidth) {
2930
auto resultTypes = op->getResultTypes();
3031
for (auto resType : resultTypes) {
3132
VectorType vecType = dyn_cast<VectorType>(resType);
32-
// Reject index since getElementTypeBitWidth will abort for Index types.
33-
if (!vecType || vecType.getElementType().isIndex())
33+
if (!vecType)
34+
return false;
35+
bool isIndexTy = vecType.getElementType().isIndex();
36+
// Reject index if `indexBitWidth` is not supplied.
37+
if (isIndexTy && indexBitWidth == 0)
3438
return false;
3539
// There are no dimension to fold if it is a 0-D vector.
3640
if (vecType.getRank() == 0)
3741
return false;
3842
unsigned trailingVecDimBitWidth =
39-
vecType.getShape().back() * vecType.getElementTypeBitWidth();
43+
vecType.getShape().back() *
44+
(isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
4045
if (trailingVecDimBitWidth >= targetBitWidth)
4146
return false;
4247
}
4348
return true;
4449
}
4550

46-
static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
51+
static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned indexBitWidth,
52+
unsigned targetBitWidth) {
4753
VectorType vecType = dyn_cast<VectorType>(t);
48-
// Reject index since getElementTypeBitWidth will abort for Index types.
49-
if (!vecType || vecType.getElementType().isIndex())
54+
if (!vecType)
55+
return false;
56+
bool isIndexTy = vecType.getElementType().isIndex();
57+
// Reject index if `indexBitWidth` is not supplied.
58+
if (isIndexTy && indexBitWidth == 0)
5059
return false;
5160
// There are no dimension to fold if it is a 0-D vector.
5261
if (vecType.getRank() == 0)
5362
return false;
5463
unsigned trailingVecDimBitWidth =
55-
vecType.getShape().back() * vecType.getElementTypeBitWidth();
64+
vecType.getShape().back() *
65+
(isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
5666
return trailingVecDimBitWidth <= targetBitWidth;
5767
}
5868

@@ -61,10 +71,12 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
6171
using OpConversionPattern::OpConversionPattern;
6272
LinearizeConstant(
6373
const TypeConverter &typeConverter, MLIRContext *context,
74+
unsigned indexBitWidth = 0,
6475
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
6576
PatternBenefit benefit = 1)
6677
: OpConversionPattern(typeConverter, context, benefit),
67-
targetVectorBitWidth(targetVectBitWidth) {}
78+
indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
79+
}
6880
LogicalResult
6981
matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
7082
ConversionPatternRewriter &rewriter) const override {
@@ -79,7 +91,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
7991

8092
if (!resType)
8193
return rewriter.notifyMatchFailure(loc, "can't convert return type");
82-
if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
94+
if (!isLessThanTargetBitWidth(constOp, indexBitWidth, targetVectorBitWidth))
8395
return rewriter.notifyMatchFailure(
8496
loc, "Can't flatten since targetBitWidth <= OpSize");
8597
auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
@@ -93,6 +105,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
93105
}
94106

95107
private:
108+
unsigned indexBitWidth;
96109
unsigned targetVectorBitWidth;
97110
};
98111

@@ -103,14 +116,16 @@ struct LinearizeVectorizable final
103116
public:
104117
LinearizeVectorizable(
105118
const TypeConverter &typeConverter, MLIRContext *context,
119+
unsigned indexBitWidth = 0,
106120
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
107121
PatternBenefit benefit = 1)
108122
: OpTraitConversionPattern(typeConverter, context, benefit),
109-
targetVectorBitWidth(targetVectBitWidth) {}
123+
indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
124+
}
110125
LogicalResult
111126
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
112127
ConversionPatternRewriter &rewriter) const override {
113-
if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
128+
if (!isLessThanTargetBitWidth(op, indexBitWidth, targetVectorBitWidth))
114129
return rewriter.notifyMatchFailure(
115130
op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
116131
FailureOr<Operation *> newOp =
@@ -123,6 +138,7 @@ struct LinearizeVectorizable final
123138
}
124139

125140
private:
141+
unsigned indexBitWidth;
126142
unsigned targetVectorBitWidth;
127143
};
128144

@@ -142,10 +158,12 @@ struct LinearizeVectorExtractStridedSlice final
142158
using OpConversionPattern::OpConversionPattern;
143159
LinearizeVectorExtractStridedSlice(
144160
const TypeConverter &typeConverter, MLIRContext *context,
161+
unsigned indexBitWidth = 0,
145162
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
146163
PatternBenefit benefit = 1)
147164
: OpConversionPattern(typeConverter, context, benefit),
148-
targetVectorBitWidth(targetVectBitWidth) {}
165+
indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
166+
}
149167

150168
LogicalResult
151169
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -156,7 +174,8 @@ struct LinearizeVectorExtractStridedSlice final
156174
if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
157175
return rewriter.notifyMatchFailure(extractOp,
158176
"scalable vectors are not supported.");
159-
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
177+
if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
178+
targetVectorBitWidth))
160179
return rewriter.notifyMatchFailure(
161180
extractOp, "Can't flatten since targetBitWidth <= OpSize");
162181

@@ -237,6 +256,7 @@ struct LinearizeVectorExtractStridedSlice final
237256
}
238257

239258
private:
259+
unsigned indexBitWidth;
240260
unsigned targetVectorBitWidth;
241261
};
242262

@@ -256,10 +276,12 @@ struct LinearizeVectorShuffle final
256276
using OpConversionPattern::OpConversionPattern;
257277
LinearizeVectorShuffle(
258278
const TypeConverter &typeConverter, MLIRContext *context,
279+
unsigned indexBitWidth = 0,
259280
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
260281
PatternBenefit benefit = 1)
261282
: OpConversionPattern(typeConverter, context, benefit),
262-
targetVectorBitWidth(targetVectBitWidth) {}
283+
indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
284+
}
263285

264286
LogicalResult
265287
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
@@ -273,7 +295,8 @@ struct LinearizeVectorShuffle final
273295
shuffleOp.getV2VectorType().isScalable() ||
274296
dstType.isScalable()) &&
275297
"scalable vectors are not supported.");
276-
if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
298+
if (!isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
299+
targetVectorBitWidth))
277300
return rewriter.notifyMatchFailure(
278301
shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
279302

@@ -312,6 +335,7 @@ struct LinearizeVectorShuffle final
312335
}
313336

314337
private:
338+
unsigned indexBitWidth;
315339
unsigned targetVectorBitWidth;
316340
};
317341

@@ -329,10 +353,12 @@ struct LinearizeVectorExtract final
329353
using OpConversionPattern::OpConversionPattern;
330354
LinearizeVectorExtract(
331355
const TypeConverter &typeConverter, MLIRContext *context,
356+
unsigned indexBitWidth = 0,
332357
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
333358
PatternBenefit benefit = 1)
334359
: OpConversionPattern(typeConverter, context, benefit),
335-
targetVectorBitWidth(targetVectBitWidth) {}
360+
indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
361+
}
336362
LogicalResult
337363
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
338364
ConversionPatternRewriter &rewriter) const override {
@@ -345,7 +371,8 @@ struct LinearizeVectorExtract final
345371
cast<VectorType>(dstTy).isScalable())
346372
return rewriter.notifyMatchFailure(extractOp,
347373
"scalable vectors are not supported.");
348-
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
374+
if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
375+
targetVectorBitWidth))
349376
return rewriter.notifyMatchFailure(
350377
extractOp, "Can't flatten since targetBitWidth <= OpSize");
351378

@@ -374,6 +401,7 @@ struct LinearizeVectorExtract final
374401
}
375402

376403
private:
404+
unsigned indexBitWidth;
377405
unsigned targetVectorBitWidth;
378406
};
379407

@@ -392,10 +420,12 @@ struct LinearizeVectorInsert final
392420
using OpConversionPattern::OpConversionPattern;
393421
LinearizeVectorInsert(
394422
const TypeConverter &typeConverter, MLIRContext *context,
423+
unsigned indexBitWidth = 0,
395424
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
396425
PatternBenefit benefit = 1)
397426
: OpConversionPattern(typeConverter, context, benefit),
398-
targetVectorBitWidth(targetVectBitWidth) {}
427+
indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
428+
}
399429
LogicalResult
400430
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
401431
ConversionPatternRewriter &rewriter) const override {
@@ -407,7 +437,7 @@ struct LinearizeVectorInsert final
407437
"scalable vectors are not supported.");
408438

409439
if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
410-
targetVectorBitWidth))
440+
indexBitWidth, targetVectorBitWidth))
411441
return rewriter.notifyMatchFailure(
412442
insertOp, "Can't flatten since targetBitWidth < OpSize");
413443

@@ -457,13 +487,14 @@ struct LinearizeVectorInsert final
457487
}
458488

459489
private:
490+
unsigned indexBitWidth;
460491
unsigned targetVectorBitWidth;
461492
};
462493
} // namespace
463494

464495
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
465496
TypeConverter &typeConverter, RewritePatternSet &patterns,
466-
ConversionTarget &target, unsigned targetBitWidth) {
497+
ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth) {
467498

468499
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
469500
if (!isLinearizableVector(type))
@@ -488,29 +519,31 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
488519
[=](Operation *op) -> std::optional<bool> {
489520
if ((isa<arith::ConstantOp>(op) ||
490521
op->hasTrait<OpTrait::Vectorizable>())) {
491-
return (isLessThanTargetBitWidth(op, targetBitWidth)
522+
return (isLessThanTargetBitWidth(op, indexBitWidth, targetBitWidth)
492523
? typeConverter.isLegal(op)
493524
: true);
494525
}
495526
return std::nullopt;
496527
});
497528

498529
patterns.add<LinearizeConstant, LinearizeVectorizable>(
499-
typeConverter, patterns.getContext(), targetBitWidth);
530+
typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
500531
}
501532

502533
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
503534
const TypeConverter &typeConverter, RewritePatternSet &patterns,
504-
ConversionTarget &target, unsigned int targetBitWidth) {
535+
ConversionTarget &target, unsigned indexBitWidth,
536+
unsigned int targetBitWidth) {
505537
target.addDynamicallyLegalOp<vector::ShuffleOp>(
506538
[=](vector::ShuffleOp shuffleOp) -> bool {
507-
return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
539+
return isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
540+
targetBitWidth)
508541
? (typeConverter.isLegal(shuffleOp) &&
509542
cast<mlir::VectorType>(shuffleOp.getResult().getType())
510543
.getRank() == 1)
511544
: true;
512545
});
513546
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
514547
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
515-
typeConverter, patterns.getContext(), targetBitWidth);
548+
typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
516549
}

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
22
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
33
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
4+
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=index-bitwidth=64 | FileCheck %s --check-prefixes=ALL,INDEX-BW-64
45

56
// ALL-LABEL: test_linearize
67
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -14,6 +15,8 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
1415
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
1516

1617
// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
18+
19+
// INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
1720
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
1821

1922
// DEFAULT: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
@@ -45,6 +48,8 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
4548
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
4649

4750
// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
51+
52+
// INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
4853
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
4954

5055
// DEFAULT: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<16xf32>
@@ -79,9 +84,12 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
7984

8085
// -----
8186

82-
// ALL-LABEL: test_index_no_linearize
83-
func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
84-
// ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
87+
// ALL-LABEL: test_index_linearize
88+
func.func @test_index_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
89+
// DEFAULT: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
90+
// BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
91+
// BW-0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
92+
// INDEX-BW-64: %[[ADD:.*]] = arith.addi {{.*}} : vector<4xindex>
8593
%0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
8694
return %0 : vector<2x2xindex>
8795
}
@@ -122,6 +130,7 @@ func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32
122130

123131
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
124132
// BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
133+
// INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<[4]xf32> to vector<2x[2]xf32>
125134
// ALL: return %[[RES]] : vector<2x[2]xf32>
126135
return %2 : vector<2x[2]xf32>
127136
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,10 @@ struct TestVectorLinearize final
853853
registry.insert<vector::VectorDialect>();
854854
}
855855

856+
Option<unsigned> indexBitwidth{*this, "index-bitwidth",
857+
llvm::cl::desc("Bitwidth of the index type"),
858+
llvm::cl::init(0)};
859+
856860
Option<unsigned> targetVectorBitwidth{
857861
*this, "target-vector-bitwidth",
858862
llvm::cl::desc(
@@ -866,9 +870,9 @@ struct TestVectorLinearize final
866870
ConversionTarget target(*context);
867871

868872
vector::populateVectorLinearizeTypeConversionsAndLegality(
869-
typeConverter, patterns, target, targetVectorBitwidth);
873+
typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
870874
vector::populateVectorLinearizeShuffleLikeOpsPatterns(
871-
typeConverter, patterns, target, targetVectorBitwidth);
875+
typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
872876
if (failed(applyPartialConversion(getOperation(), target,
873877
std::move(patterns))))
874878
return signalPassFailure();

0 commit comments

Comments
 (0)