@@ -72,13 +72,14 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
7272 auto resType =
7373 getTypeConverter ()->convertType <VectorType>(constOp.getType ());
7474
75+ if (!resType)
76+ return rewriter.notifyMatchFailure (loc, " can't convert return type" );
77+
7578 if (resType.isScalable () && !isa<SplatElementsAttr>(constOp.getValue ()))
7679 return rewriter.notifyMatchFailure (
7780 loc,
7881 " Cannot linearize a constant scalable vector that's not a splat" );
7982
80- if (!resType)
81- return rewriter.notifyMatchFailure (loc, " can't convert return type" );
8283 if (!isLessThanTargetBitWidth (constOp, targetVectorBitWidth))
8384 return rewriter.notifyMatchFailure (
8485 loc, " Can't flatten since targetBitWidth <= OpSize" );
@@ -459,6 +460,45 @@ struct LinearizeVectorInsert final
459460private:
460461 unsigned targetVectorBitWidth;
461462};
463+
464+ // / This pattern converts the BitCastOp that works on nD (n > 1)
465+ // / vectors to a BitCastOp that works on linearized vectors.
466+ // / Following,
467+ // / vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
468+ // / is converted to :
469+ // / %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
470+ // / %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
471+ // / %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
472+ struct LinearizeVectorBitCast final
473+ : public OpConversionPattern<vector::BitCastOp> {
474+ using OpConversionPattern::OpConversionPattern;
475+ LinearizeVectorBitCast (
476+ const TypeConverter &typeConverter, MLIRContext *context,
477+ unsigned targetVectBitWidth = std::numeric_limits<unsigned >::max(),
478+ PatternBenefit benefit = 1 )
479+ : OpConversionPattern(typeConverter, context, benefit),
480+ targetVectorBitWidth (targetVectBitWidth) {}
481+ LogicalResult
482+ matchAndRewrite (vector::BitCastOp castOp, OpAdaptor adaptor,
483+ ConversionPatternRewriter &rewriter) const override {
484+ Location loc = castOp.getLoc ();
485+ auto resType = getTypeConverter ()->convertType (castOp.getType ());
486+ if (!resType)
487+ return rewriter.notifyMatchFailure (loc, " can't convert return type." );
488+
489+ if (!isLessThanTargetBitWidth (castOp, targetVectorBitWidth))
490+ return rewriter.notifyMatchFailure (
491+ loc, " Can't flatten since targetBitWidth <= OpSize" );
492+
493+ rewriter.replaceOpWithNewOp <vector::BitCastOp>(castOp, resType,
494+ adaptor.getSource ());
495+ return mlir::success ();
496+ }
497+
498+ private:
499+ unsigned targetVectorBitWidth;
500+ };
501+
462502} // namespace
463503
464504void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality (
@@ -485,7 +525,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
485525 typeConverter.addTargetMaterialization (materializeCast);
486526 target.markUnknownOpDynamicallyLegal (
487527 [=](Operation *op) -> std::optional<bool > {
488- if ((isa<arith::ConstantOp>(op) ||
528+ if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
489529 op->hasTrait <OpTrait::Vectorizable>())) {
490530 return (isLessThanTargetBitWidth (op, targetBitWidth)
491531 ? typeConverter.isLegal (op)
@@ -494,8 +534,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
494534 return std::nullopt ;
495535 });
496536
497- patterns.add <LinearizeConstant, LinearizeVectorizable>(
498- typeConverter, patterns.getContext (), targetBitWidth);
537+ patterns
538+ .add <LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
539+ typeConverter, patterns.getContext (), targetBitWidth);
499540}
500541
501542void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments