1111// ===----------------------------------------------------------------------===//
1212
1313#include " mlir/Dialect/Arith/IR/Arith.h"
14+ #include " mlir/Dialect/UB/IR/UBOps.h"
1415#include " mlir/Dialect/Vector/IR/VectorOps.h"
1516#include " mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
1617#include " mlir/IR/Attributes.h"
@@ -56,40 +57,71 @@ static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
5657 return trailingVecDimBitWidth <= targetBitWidth;
5758}
5859
60+ static FailureOr<Attribute>
61+ linearizeConstAttr (Location loc, ConversionPatternRewriter &rewriter,
62+ VectorType resType, Attribute value) {
63+ if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
64+ if (resType.isScalable () && !isa<SplatElementsAttr>(value))
65+ return rewriter.notifyMatchFailure (
66+ loc,
67+ " Cannot linearize a constant scalable vector that's not a splat" );
68+
69+ return dstElementsAttr.reshape (resType);
70+ }
71+
72+ if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
73+ return poisonAttr;
74+
75+ return rewriter.notifyMatchFailure (loc, " unsupported attr type" );
76+ }
77+
5978namespace {
60- struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
61- using OpConversionPattern::OpConversionPattern;
62- LinearizeConstant (
79+ struct LinearizeConstantLike final
80+ : OpTraitConversionPattern<OpTrait::ConstantLike> {
81+ using OpTraitConversionPattern::OpTraitConversionPattern;
82+
83+ LinearizeConstantLike (
6384 const TypeConverter &typeConverter, MLIRContext *context,
6485 unsigned targetVectBitWidth = std::numeric_limits<unsigned >::max(),
6586 PatternBenefit benefit = 1 )
66- : OpConversionPattern (typeConverter, context, benefit),
87+ : OpTraitConversionPattern (typeConverter, context, benefit),
6788 targetVectorBitWidth (targetVectBitWidth) {}
6889 LogicalResult
69- matchAndRewrite (arith::ConstantOp constOp, OpAdaptor adaptor ,
90+ matchAndRewrite (Operation *op, ArrayRef<Value> operands ,
7091 ConversionPatternRewriter &rewriter) const override {
71- Location loc = constOp.getLoc ();
92+ Location loc = op->getLoc ();
93+ if (op->getNumResults () != 1 )
94+ return rewriter.notifyMatchFailure (loc, " expected 1 result" );
95+
96+ const TypeConverter &converter = *getTypeConverter ();
7297 auto resType =
73- getTypeConverter ()-> convertType <VectorType>(constOp .getType ());
98+ converter. convertType <VectorType>(op-> getResult ( 0 ) .getType ());
7499
75100 if (!resType)
76101 return rewriter.notifyMatchFailure (loc, " can't convert return type" );
77102
78- if (resType.isScalable () && !isa<SplatElementsAttr>(constOp.getValue ()))
79- return rewriter.notifyMatchFailure (
80- loc,
81- " Cannot linearize a constant scalable vector that's not a splat" );
82-
83- if (!isLessThanTargetBitWidth (constOp, targetVectorBitWidth))
103+ if (!isLessThanTargetBitWidth (op, targetVectorBitWidth))
84104 return rewriter.notifyMatchFailure (
85105 loc, " Can't flatten since targetBitWidth <= OpSize" );
86- auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue ());
87- if (!dstElementsAttr)
88- return rewriter.notifyMatchFailure (loc, " unsupported attr type" );
89106
90- dstElementsAttr = dstElementsAttr.reshape (resType);
91- rewriter.replaceOpWithNewOp <arith::ConstantOp>(constOp, resType,
92- dstElementsAttr);
107+ StringAttr attrName = rewriter.getStringAttr (" value" );
108+ Attribute value = op->getAttr (attrName);
109+ if (!value)
110+ return rewriter.notifyMatchFailure (loc, " no 'value' attr" );
111+
112+ FailureOr<Attribute> newValue =
113+ linearizeConstAttr (loc, rewriter, resType, value);
114+ if (failed (newValue))
115+ return failure ();
116+
117+ FailureOr<Operation *> convertResult =
118+ convertOpResultTypes (op, /* operands=*/ {}, converter, rewriter);
119+ if (failed (convertResult))
120+ return failure ();
121+
122+ Operation *newOp = *convertResult;
123+ newOp->setAttr (attrName, *newValue);
124+ rewriter.replaceOp (op, newOp);
93125 return success ();
94126 }
95127
@@ -525,7 +557,8 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
525557 typeConverter.addTargetMaterialization (materializeCast);
526558 target.markUnknownOpDynamicallyLegal (
527559 [=](Operation *op) -> std::optional<bool > {
528- if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
560+ if ((isa<vector::BitCastOp>(op) ||
561+ op->hasTrait <OpTrait::ConstantLike>() ||
529562 op->hasTrait <OpTrait::Vectorizable>())) {
530563 return (isLessThanTargetBitWidth (op, targetBitWidth)
531564 ? typeConverter.isLegal (op)
@@ -534,9 +567,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
534567 return std::nullopt ;
535568 });
536569
537- patterns
538- . add <LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
539- typeConverter, patterns. getContext (), targetBitWidth);
570+ patterns. add <LinearizeConstantLike, LinearizeVectorizable,
571+ LinearizeVectorBitCast>(typeConverter, patterns. getContext (),
572+ targetBitWidth);
540573}
541574
542575void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments