@@ -293,6 +293,10 @@ struct LinearizeVectorExtract final
293293 LogicalResult
294294 matchAndRewrite (vector::ExtractOp extractOp, OpAdaptor adaptor,
295295 ConversionPatternRewriter &rewriter) const override {
296+ // Skip if result is not a vector type
297+ if (!isa<VectorType>(extractOp.getType ()))
298+ return rewriter.notifyMatchFailure (extractOp,
299+ " scalar extract is not supported." );
296300 Type dstTy = getTypeConverter ()->convertType (extractOp.getType ());
297301 assert (dstTy && " expected 1-D vector type" );
298302
@@ -415,6 +419,32 @@ struct LinearizeVectorBitCast final
415419 }
416420};
417421
422+ // / This pattern converts the SplatOp to work on a linearized vector.
423+ // / Following,
424+ // / vector.splat %value : vector<4x4xf32>
425+ // / is converted to:
426+ // / %out_1d = vector.splat %value : vector<16xf32>
427+ // / %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
428+ struct LinearizeVectorSplat final
429+ : public OpConversionPattern<vector::SplatOp> {
430+ using OpConversionPattern::OpConversionPattern;
431+
432+ LinearizeVectorSplat (const TypeConverter &typeConverter, MLIRContext *context,
433+ PatternBenefit benefit = 1 )
434+ : OpConversionPattern(typeConverter, context, benefit) {}
435+
436+ LogicalResult
437+ matchAndRewrite (vector::SplatOp splatOp, OpAdaptor adaptor,
438+ ConversionPatternRewriter &rewriter) const override {
439+ auto dstTy = getTypeConverter ()->convertType (splatOp.getType ());
440+ if (!dstTy)
441+ return rewriter.notifyMatchFailure (splatOp, " cannot convert type." );
442+ rewriter.replaceOpWithNewOp <vector::SplatOp>(splatOp, adaptor.getInput (),
443+ dstTy);
444+ return success ();
445+ }
446+ };
447+
418448} // namespace
419449
420450// / Return true if the operation `op` does not support scalable vectors and
@@ -501,7 +531,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
501531 const TypeConverter &typeConverter, const ConversionTarget &target,
502532 RewritePatternSet &patterns) {
503533 patterns.add <LinearizeConstantLike, LinearizeVectorizable,
504- LinearizeVectorBitCast>(typeConverter, patterns.getContext ());
534+ LinearizeVectorBitCast, LinearizeVectorSplat>(
535+ typeConverter, patterns.getContext ());
505536}
506537
507538void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments