@@ -97,15 +97,6 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
9797 return success ();
9898}
9999
100- static LogicalResult gatherScatterPreconditions (PatternRewriter &rewriter,
101- Operation *op, Type baseType) {
102- auto srcTy = dyn_cast<MemRefType>(baseType);
103- if (!srcTy)
104- return rewriter.notifyMatchFailure (op, " Expects memref source" );
105-
106- return success ();
107- }
108-
109100static xegpu::CreateNdDescOp
110101createNdDescriptor (PatternRewriter &rewriter, Location loc,
111102 xegpu::TensorDescType descType, TypedValue<MemRefType> src,
@@ -610,9 +601,9 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
610601
611602 LogicalResult matchAndRewrite (vector::GatherOp gatherOp,
612603 PatternRewriter &rewriter) const override {
613- if ( failed ( gatherScatterPreconditions (rewriter, gatherOp,
614- gatherOp. getBase (). getType ())) )
615- return failure ( );
604+ auto srcTy = dyn_cast<MemRefType>(gatherOp. getBase (). getType ());
605+ if (!srcTy )
606+ return rewriter. notifyMatchFailure (gatherOp, " Expects memref source " );
616607
617608 Location loc = gatherOp.getLoc ();
618609 VectorType vectorType = gatherOp.getVectorType ();
@@ -645,9 +636,9 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
645636
646637 LogicalResult matchAndRewrite (vector::ScatterOp scatterOp,
647638 PatternRewriter &rewriter) const override {
648- if ( failed ( gatherScatterPreconditions (rewriter, scatterOp,
649- scatterOp. getBase (). getType ())) )
650- return failure ( );
639+ auto srcTy = dyn_cast<MemRefType>(scatterOp. getBase (). getType ());
640+ if (!srcTy )
641+ return rewriter. notifyMatchFailure (scatterOp, " Expects memref source " );
651642
652643 Location loc = scatterOp.getLoc ();
653644 auto meta = computeMemrefMeta (scatterOp, rewriter);
0 commit comments