Skip to content

Commit 54a8299

Browse files
committed
Remove 'gatherScatterPreconditions' function
Signed-off-by: dchigarev <[email protected]>
1 parent eb5fbc7 commit 54a8299

File tree

1 file changed

+6
-15
lines changed

1 file changed

+6
-15
lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,6 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
9797
return success();
9898
}
9999

100-
static LogicalResult gatherScatterPreconditions(PatternRewriter &rewriter,
101-
Operation *op, Type baseType) {
102-
auto srcTy = dyn_cast<MemRefType>(baseType);
103-
if (!srcTy)
104-
return rewriter.notifyMatchFailure(op, "Expects memref source");
105-
106-
return success();
107-
}
108-
109100
static xegpu::CreateNdDescOp
110101
createNdDescriptor(PatternRewriter &rewriter, Location loc,
111102
xegpu::TensorDescType descType, TypedValue<MemRefType> src,
@@ -610,9 +601,9 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
610601

611602
LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
612603
PatternRewriter &rewriter) const override {
613-
if (failed(gatherScatterPreconditions(rewriter, gatherOp,
614-
gatherOp.getBase().getType())))
615-
return failure();
604+
auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
605+
if (!srcTy)
606+
return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");
616607

617608
Location loc = gatherOp.getLoc();
618609
VectorType vectorType = gatherOp.getVectorType();
@@ -645,9 +636,9 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
645636

646637
LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
647638
PatternRewriter &rewriter) const override {
648-
if (failed(gatherScatterPreconditions(rewriter, scatterOp,
649-
scatterOp.getBase().getType())))
650-
return failure();
639+
auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
640+
if (!srcTy)
641+
return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");
651642

652643
Location loc = scatterOp.getLoc();
653644
auto meta = computeMemrefMeta(scatterOp, rewriter);

0 commit comments

Comments
 (0)