Skip to content

Commit be48849

Browse files
committed
push further with the separation of concerns
1 parent 86ceb57 commit be48849

File tree

4 files changed

+205
-144
lines changed

4 files changed

+205
-144
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -392,24 +392,29 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
392392
void populateVectorTransposeNarrowTypeRewritePatterns(
393393
RewritePatternSet &patterns, PatternBenefit benefit = 1);
394394

395-
/// Populate `typeConverter` and `conversionTarget` with the definition of
396-
/// legal types and operations, for the specific case where vectors with
397-
/// trailing dimensions of size greater than `targetBitWidth` are legal.
398-
void populateVectorLinearizeBitWidthTargetAndConverter(
399-
TypeConverter &typeConverter, ConversionTarget &conversionTarget,
400-
unsigned targetBitWidth);
401-
402-
/// Populates `patterns` for ND vector (N >= 2) linearization. Patterns for
403-
/// converting ConstantLike, Vectorizable, and vector::BitCast.
395+
/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
396+
/// This registers (1) which operations are legal and hence should not be
397+
/// linearized, (2) what converted types are (rank-1 vectors) and how to
398+
/// materialze the conversion (with shape_cast)
399+
///
400+
/// Note: the set of legal operations can be extended by a user if for example
401+
/// certain rank>1 vectors are considered valid, but adding additional
402+
/// dynamically legal ops to `conversionTarget`.
403+
void populateForVectorLinearize(TypeConverter &typeConverter,
404+
ConversionTarget &conversionTarget);
405+
406+
/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
407+
/// contains patterns for converting ConstantLike, Vectorizable, and
408+
/// vector::BitCast ops.
404409
void populateVectorLinearizeBasePatterns(const TypeConverter &,
405-
RewritePatternSet &patterns,
406-
const ConversionTarget &);
410+
const ConversionTarget &,
411+
RewritePatternSet &patterns);
407412

408413
/// Populates `patterns` for linearizing ND (N >= 2) vector operations
409414
/// to 1D vector shuffle operations.
410415
void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
411-
RewritePatternSet &patterns,
412-
const ConversionTarget &);
416+
const ConversionTarget &,
417+
RewritePatternSet &patterns);
413418

414419
} // namespace vector
415420
} // namespace mlir

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 52 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,10 @@ struct LinearizeConstantLike final
6262
if (op->getNumResults() != 1)
6363
return rewriter.notifyMatchFailure(loc, "expected 1 result");
6464

65-
const TypeConverter &converter = *getTypeConverter();
65+
const TypeConverter &typeConverter = *getTypeConverter();
6666
auto resType =
67-
converter.convertType<VectorType>(op->getResult(0).getType());
68-
69-
if (!resType)
70-
return rewriter.notifyMatchFailure(loc, "can't convert return type");
67+
typeConverter.convertType<VectorType>(op->getResult(0).getType());
68+
assert(resType && "expected 1-D vector type");
7169

7270
StringAttr attrName = rewriter.getStringAttr("value");
7371
Attribute value = op->getAttr(attrName);
@@ -80,7 +78,7 @@ struct LinearizeConstantLike final
8078
return failure();
8179

8280
FailureOr<Operation *> convertResult =
83-
convertOpResultTypes(op, /*operands=*/{}, converter, rewriter);
81+
convertOpResultTypes(op, /*operands=*/{}, typeConverter, rewriter);
8482
if (failed(convertResult))
8583
return failure();
8684

@@ -244,14 +242,6 @@ struct LinearizeVectorShuffle final
244242
VectorType dstType =
245243
getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
246244
assert(dstType && "vector type destination expected.");
247-
// The assert is used because vector.shuffle does not support scalable
248-
// vectors.
249-
bool scalable = shuffleOp.getV1VectorType().isScalable() ||
250-
shuffleOp.getV2VectorType().isScalable() ||
251-
dstType.isScalable();
252-
if (scalable)
253-
return rewriter.notifyMatchFailure(shuffleOp,
254-
"scalable vectors are not supported.");
255245

256246
Value vec1 = adaptor.getV1();
257247
Value vec2 = adaptor.getV2();
@@ -270,7 +260,7 @@ struct LinearizeVectorShuffle final
270260
}
271261

272262
// For each value in the mask, we generate the indices of the source vectors
273-
// that needs to be shuffled to the destination vector. If shuffleSliceLen >
263+
// that need to be shuffled to the destination vector. If shuffleSliceLen >
274264
// 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
275265
// elements) instead of scalars.
276266
ArrayRef<int64_t> mask = shuffleOp.getMask();
@@ -309,14 +299,7 @@ struct LinearizeVectorExtract final
309299
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
310300
ConversionPatternRewriter &rewriter) const override {
311301
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
312-
if (!dstTy)
313-
return rewriter.notifyMatchFailure(extractOp,
314-
"expected n-D vector type.");
315-
316-
if (extractOp.getVector().getType().isScalable() ||
317-
cast<VectorType>(dstTy).isScalable())
318-
return rewriter.notifyMatchFailure(extractOp,
319-
"scalable vectors are not supported.");
302+
assert(dstTy && "expected 1-D vector type");
320303

321304
// Dynamic position is not supported.
322305
if (extractOp.hasDynamicPosition())
@@ -367,9 +350,6 @@ struct LinearizeVectorInsert final
367350
VectorType dstTy = getTypeConverter()->convertType<VectorType>(
368351
insertOp.getDestVectorType());
369352
assert(dstTy && "vector type destination expected.");
370-
if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
371-
return rewriter.notifyMatchFailure(insertOp,
372-
"scalable vectors are not supported.");
373353

374354
// dynamic position is not supported
375355
if (insertOp.hasDynamicPosition())
@@ -436,11 +416,8 @@ struct LinearizeVectorBitCast final
436416
LogicalResult
437417
matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
438418
ConversionPatternRewriter &rewriter) const override {
439-
Location loc = castOp.getLoc();
440419
auto resType = getTypeConverter()->convertType(castOp.getType());
441-
if (!resType)
442-
return rewriter.notifyMatchFailure(loc, "can't convert return type.");
443-
420+
assert(resType && "expected 1-D vector type");
444421
rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
445422
adaptor.getSource());
446423
return mlir::success();
@@ -449,56 +426,15 @@ struct LinearizeVectorBitCast final
449426

450427
} // namespace
451428

452-
/// If `type` is VectorType with trailing dimension of (bit) size greater than
453-
/// or equal to `targetBitWidth`, its defining op is considered legal.
454-
static bool legalBecauseOfBitwidth(Type type, unsigned targetBitWidth) {
455-
456-
VectorType vecType = dyn_cast<VectorType>(type);
457-
458-
if (!vecType)
459-
return true;
460-
461-
// The width of the type 'index' is unbounded (and therefore potentially above
462-
// the target width).
463-
if (vecType.getElementType().isIndex())
464-
return true;
465-
466-
unsigned finalDimSize =
467-
vecType.getRank() == 0 ? 0 : vecType.getShape().back();
468-
469-
unsigned trailingVecDimBitWidth =
470-
finalDimSize * vecType.getElementTypeBitWidth();
471-
472-
return trailingVecDimBitWidth >= targetBitWidth;
473-
}
474-
475-
static SmallVector<std::pair<Type, unsigned>>
476-
getChecksForBitwidth(Operation *op, unsigned targetBitWidth) {
477-
478-
if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
479-
auto w = targetBitWidth < std::numeric_limits<unsigned>::max()
480-
? targetBitWidth + 1
481-
: targetBitWidth;
482-
return {{insertOp.getValueToStoreType(), w}};
483-
}
484-
auto resultTypes = op->getResultTypes();
485-
SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
486-
resultsWithBitWidth.reserve(resultTypes.size());
487-
for (Type type : resultTypes) {
488-
resultsWithBitWidth.push_back({type, targetBitWidth});
489-
}
490-
return resultsWithBitWidth;
491-
}
492-
493429
/// Return true if the operation `op` does not support scalable vectors and
494-
/// has at least 1 scalable vector result.
495-
static bool legalBecauseScalable(Operation *op) {
496-
497-
bool scalableSupported = op->hasTrait<OpTrait::ConstantLike>() ||
498-
op->hasTrait<OpTrait::Vectorizable>() ||
499-
isa<vector::BitCastOp>(op);
500-
501-
if (scalableSupported)
430+
/// has at least 1 scalable vector result. These ops should all eventually
431+
/// support scalable vectors, and this function should be removed.
432+
static bool isNotLinearizableBecauseScalable(Operation *op) {
433+
434+
bool unsupported =
435+
isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
436+
op);
437+
if (!unsupported)
502438
return false;
503439

504440
// Check if any of the results is a scalable vector type.
@@ -512,73 +448,74 @@ static bool legalBecauseScalable(Operation *op) {
512448
return containsScalableResult;
513449
}
514450

515-
static bool dynamicallyLegal(Operation *op, unsigned targetBitWidth) {
451+
static bool isNotLinearizable(Operation *op) {
516452

517453
// Only ops that are in the vector dialect, are ConstantLike, or
518-
// are Vectorizable might be linearized currently, so legalize the others.
519-
bool opIsVectorDialect = op->getDialect()->getNamespace() ==
520-
vector::VectorDialect::getDialectNamespace();
521-
if (!opIsVectorDialect && !op->hasTrait<OpTrait::ConstantLike>() &&
522-
!op->hasTrait<OpTrait::Vectorizable>())
454+
// are Vectorizable might be linearized currently.
455+
StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
456+
StringRef opDialect = op->getDialect()->getNamespace();
457+
bool unsupported = (opDialect != vectorDialect) &&
458+
!op->hasTrait<OpTrait::ConstantLike>() &&
459+
!op->hasTrait<OpTrait::Vectorizable>();
460+
if (unsupported)
523461
return true;
524462

525-
// Some ops will not be linearized if they have scalable vector results.
526-
if (legalBecauseScalable(op))
463+
// Some ops currently don't support scalable vectors.
464+
if (isNotLinearizableBecauseScalable(op))
527465
return true;
528466

529-
// Check on bitwidths.
530-
auto typesToCheck = getChecksForBitwidth(op, targetBitWidth);
531-
return std::any_of(typesToCheck.begin(), typesToCheck.end(),
532-
[&](std::pair<Type, unsigned> typeWidth) {
533-
return legalBecauseOfBitwidth(typeWidth.first,
534-
typeWidth.second);
535-
});
467+
return false;
536468
}
537469

538-
void mlir::vector::populateVectorLinearizeBitWidthTargetAndConverter(
539-
TypeConverter &typeConverter, ConversionTarget &target,
540-
unsigned targetBitWidth) {
470+
void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
471+
ConversionTarget &target) {
541472

542-
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
543-
if (!isLinearizableVector(type))
473+
auto convertType = [](Type type) -> std::optional<Type> {
474+
VectorType vectorType = dyn_cast<VectorType>(type);
475+
if (!vectorType || !isLinearizableVector(vectorType))
544476
return type;
545477

546-
return VectorType::get(type.getNumElements(), type.getElementType(),
547-
type.isScalable());
548-
});
478+
VectorType linearizedType =
479+
VectorType::get(vectorType.getNumElements(),
480+
vectorType.getElementType(), vectorType.isScalable());
481+
return linearizedType;
482+
};
483+
typeConverter.addConversion(convertType);
549484

550485
auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
551486
Location loc) -> Value {
552-
if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
553-
!isa<VectorType>(type))
487+
if (inputs.size() != 1)
554488
return nullptr;
555-
return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
556-
};
557489

490+
Value value = inputs.front();
491+
if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
492+
return nullptr;
493+
494+
return builder.create<vector::ShapeCastOp>(loc, type, value);
495+
};
558496
typeConverter.addSourceMaterialization(materializeCast);
559497
typeConverter.addTargetMaterialization(materializeCast);
560498

561499
target.markUnknownOpDynamicallyLegal(
562500
[=](Operation *op) -> std::optional<bool> {
563-
bool isDynamicallyLegal = dynamicallyLegal(op, targetBitWidth);
564-
if (isDynamicallyLegal)
501+
if (isNotLinearizable(op))
565502
return true;
566-
567-
bool shapeUnchanged = typeConverter.isLegal(op);
568-
return shapeUnchanged;
503+
// This will return true if, for all operand and result types `t`,
504+
// convertType(t) = t. This is true if there are no rank>=2 vectors.
505+
return typeConverter.isLegal(op);
569506
});
570507
}
571508

572509
void mlir::vector::populateVectorLinearizeBasePatterns(
573-
const TypeConverter &typeConverter, RewritePatternSet &patterns,
574-
const ConversionTarget &target) {
510+
const TypeConverter &typeConverter, const ConversionTarget &target,
511+
RewritePatternSet &patterns) {
575512
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
576513
LinearizeVectorBitCast>(typeConverter, patterns.getContext());
577514
}
578515

579516
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
580-
const TypeConverter &typeConverter, RewritePatternSet &patterns,
581-
const ConversionTarget &target) {
517+
const TypeConverter &typeConverter, const ConversionTarget &target,
518+
RewritePatternSet &patterns) {
582519
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
583520
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
584521
typeConverter, patterns.getContext());

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
2-
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
3-
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
2+
3+
// RUN: mlir-opt %s -split-input-file -test-bit-width-contrained-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
4+
// RUN: mlir-opt %s -split-input-file -test-bit-width-contrained-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
45

56
// ALL-LABEL: test_linearize
67
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -97,7 +98,7 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
9798

9899
// ALL-LABEL: test_index_no_linearize
99100
func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
100-
// ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
101+
// BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
101102
%0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
102103
return %0 : vector<2x2xindex>
103104
}

0 commit comments

Comments
 (0)