Skip to content

Commit 3afe5d5

Browse files
committed
use 'op.get*Mutable()' when getting layout attrs
Signed-off-by: dchigarev <[email protected]>
1 parent 79e37d8 commit 3afe5d5

File tree

1 file changed

+17
-18
lines changed

1 file changed

+17
-18
lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -636,15 +636,12 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
636636
computeOffsets(rewriter, gatherOp, meta.first, meta.second);
637637
Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
638638

639-
auto numOffsets = gatherOp.getOffsets().size();
640639
auto layoutRes = xegpu::getDistributeLayoutAttr(gatherOp.getResult());
641640
auto layoutIndices =
642-
xegpu::getDistributeLayoutAttr(gatherOp->getOpOperand(numOffsets + 1));
643-
auto layoutMask =
644-
xegpu::getDistributeLayoutAttr(gatherOp->getOpOperand(numOffsets + 2));
641+
xegpu::getDistributeLayoutAttr(gatherOp.getIndicesMutable());
642+
auto layoutMask = xegpu::getDistributeLayoutAttr(gatherOp.getMaskMutable());
645643
auto layoutPassThru =
646-
xegpu::getDistributeLayoutAttr(gatherOp->getOpOperand(numOffsets + 3));
647-
644+
xegpu::getDistributeLayoutAttr(gatherOp.getPassThruMutable());
648645
SmallVector<xegpu::CachePolicyAttr, 3> cacheHints =
649646
getOpCacheHints(gatherOp);
650647
auto xeGatherOp = xegpu::LoadGatherOp::create(
@@ -654,15 +651,17 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
654651
/*l2_hint=*/cacheHints[1],
655652
/*l3_hint=*/cacheHints[2]);
656653
xegpu::setDistributeLayoutAttr(xeGatherOp->getOpResult(0), layoutRes);
657-
xegpu::setDistributeLayoutAttr(xeGatherOp->getOpOperand(1), layoutIndices);
658-
xegpu::setDistributeLayoutAttr(xeGatherOp->getOpOperand(2), layoutMask);
654+
xegpu::setDistributeLayoutAttr(xeGatherOp.getOffsetsMutable()[0],
655+
layoutIndices);
656+
xegpu::setDistributeLayoutAttr(xeGatherOp.getMaskMutable(), layoutMask);
659657

660658
auto selectOp =
661659
arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
662660
xeGatherOp.getResult(), gatherOp.getPassThru());
663-
xegpu::setDistributeLayoutAttr(selectOp->getOpOperand(0), layoutMask);
664-
xegpu::setDistributeLayoutAttr(selectOp->getOpOperand(1), layoutRes);
665-
xegpu::setDistributeLayoutAttr(selectOp->getOpOperand(2), layoutPassThru);
661+
xegpu::setDistributeLayoutAttr(selectOp.getConditionMutable(), layoutMask);
662+
xegpu::setDistributeLayoutAttr(selectOp.getTrueValueMutable(), layoutRes);
663+
xegpu::setDistributeLayoutAttr(selectOp.getFalseValueMutable(),
664+
layoutPassThru);
666665
xegpu::setDistributeLayoutAttr(selectOp->getOpResult(0), layoutRes);
667666

668667
rewriter.replaceOp(gatherOp, selectOp.getResult());
@@ -689,13 +688,12 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
689688
computeOffsets(rewriter, scatterOp, meta.first, meta.second);
690689
Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
691690

692-
auto numOffsets = scatterOp.getOffsets().size();
693691
auto layoutIndices =
694-
xegpu::getDistributeLayoutAttr(scatterOp->getOpOperand(numOffsets + 1));
692+
xegpu::getDistributeLayoutAttr(scatterOp.getIndicesMutable());
695693
auto layoutMask =
696-
xegpu::getDistributeLayoutAttr(scatterOp->getOpOperand(numOffsets + 2));
694+
xegpu::getDistributeLayoutAttr(scatterOp.getMaskMutable());
697695
auto layoutVal =
698-
xegpu::getDistributeLayoutAttr(scatterOp->getOpOperand(numOffsets + 3));
696+
xegpu::getDistributeLayoutAttr(scatterOp.getValueToStoreMutable());
699697
SmallVector<xegpu::CachePolicyAttr, 3> cacheHints =
700698
getOpCacheHints(scatterOp);
701699
auto storeOp = xegpu::StoreScatterOp::create(
@@ -705,9 +703,10 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
705703
/*l1_hint=*/cacheHints[0],
706704
/*l2_hint=*/cacheHints[1],
707705
/*l3_hint=*/cacheHints[2]);
708-
xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(0), layoutVal);
709-
xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(2), layoutIndices);
710-
xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(3), layoutMask);
706+
xegpu::setDistributeLayoutAttr(storeOp.getValueMutable(), layoutVal);
707+
xegpu::setDistributeLayoutAttr(storeOp.getOffsetsMutable()[0],
708+
layoutIndices);
709+
xegpu::setDistributeLayoutAttr(storeOp.getMaskMutable(), layoutMask);
711710
rewriter.eraseOp(scatterOp);
712711
return success();
713712
}

0 commit comments

Comments
 (0)