@@ -636,15 +636,12 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
636636 computeOffsets (rewriter, gatherOp, meta.first , meta.second );
637637 Value flatMemref = memrefToIndexPtr (gatherOp, rewriter);
638638
639- auto numOffsets = gatherOp.getOffsets ().size ();
640639 auto layoutRes = xegpu::getDistributeLayoutAttr (gatherOp.getResult ());
641640 auto layoutIndices =
642- xegpu::getDistributeLayoutAttr (gatherOp->getOpOperand (numOffsets + 1 ));
643- auto layoutMask =
644- xegpu::getDistributeLayoutAttr (gatherOp->getOpOperand (numOffsets + 2 ));
641+ xegpu::getDistributeLayoutAttr (gatherOp.getIndicesMutable ());
642+ auto layoutMask = xegpu::getDistributeLayoutAttr (gatherOp.getMaskMutable ());
645643 auto layoutPassThru =
646- xegpu::getDistributeLayoutAttr (gatherOp->getOpOperand (numOffsets + 3 ));
647-
644+ xegpu::getDistributeLayoutAttr (gatherOp.getPassThruMutable ());
648645 SmallVector<xegpu::CachePolicyAttr, 3 > cacheHints =
649646 getOpCacheHints (gatherOp);
650647 auto xeGatherOp = xegpu::LoadGatherOp::create (
@@ -654,15 +651,17 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
654651 /* l2_hint=*/ cacheHints[1 ],
655652 /* l3_hint=*/ cacheHints[2 ]);
656653 xegpu::setDistributeLayoutAttr (xeGatherOp->getOpResult (0 ), layoutRes);
657- xegpu::setDistributeLayoutAttr (xeGatherOp->getOpOperand (1 ), layoutIndices);
658- xegpu::setDistributeLayoutAttr (xeGatherOp->getOpOperand (2 ), layoutMask);
654+ xegpu::setDistributeLayoutAttr (xeGatherOp.getOffsetsMutable ()[0 ],
655+ layoutIndices);
656+ xegpu::setDistributeLayoutAttr (xeGatherOp.getMaskMutable (), layoutMask);
659657
660658 auto selectOp =
661659 arith::SelectOp::create (rewriter, loc, gatherOp.getMask (),
662660 xeGatherOp.getResult (), gatherOp.getPassThru ());
663- xegpu::setDistributeLayoutAttr (selectOp->getOpOperand (0 ), layoutMask);
664- xegpu::setDistributeLayoutAttr (selectOp->getOpOperand (1 ), layoutRes);
665- xegpu::setDistributeLayoutAttr (selectOp->getOpOperand (2 ), layoutPassThru);
661+ xegpu::setDistributeLayoutAttr (selectOp.getConditionMutable (), layoutMask);
662+ xegpu::setDistributeLayoutAttr (selectOp.getTrueValueMutable (), layoutRes);
663+ xegpu::setDistributeLayoutAttr (selectOp.getFalseValueMutable (),
664+ layoutPassThru);
666665 xegpu::setDistributeLayoutAttr (selectOp->getOpResult (0 ), layoutRes);
667666
668667 rewriter.replaceOp (gatherOp, selectOp.getResult ());
@@ -689,13 +688,12 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
689688 computeOffsets (rewriter, scatterOp, meta.first , meta.second );
690689 Value flatMemref = memrefToIndexPtr (scatterOp, rewriter);
691690
692- auto numOffsets = scatterOp.getOffsets ().size ();
693691 auto layoutIndices =
694- xegpu::getDistributeLayoutAttr (scatterOp-> getOpOperand (numOffsets + 1 ));
692+ xegpu::getDistributeLayoutAttr (scatterOp. getIndicesMutable ( ));
695693 auto layoutMask =
696- xegpu::getDistributeLayoutAttr (scatterOp-> getOpOperand (numOffsets + 2 ));
694+ xegpu::getDistributeLayoutAttr (scatterOp. getMaskMutable ( ));
697695 auto layoutVal =
698- xegpu::getDistributeLayoutAttr (scatterOp-> getOpOperand (numOffsets + 3 ));
696+ xegpu::getDistributeLayoutAttr (scatterOp. getValueToStoreMutable ( ));
699697 SmallVector<xegpu::CachePolicyAttr, 3 > cacheHints =
700698 getOpCacheHints (scatterOp);
701699 auto storeOp = xegpu::StoreScatterOp::create (
@@ -705,9 +703,10 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
705703 /* l1_hint=*/ cacheHints[0 ],
706704 /* l2_hint=*/ cacheHints[1 ],
707705 /* l3_hint=*/ cacheHints[2 ]);
708- xegpu::setDistributeLayoutAttr (storeOp->getOpOperand (0 ), layoutVal);
709- xegpu::setDistributeLayoutAttr (storeOp->getOpOperand (2 ), layoutIndices);
710- xegpu::setDistributeLayoutAttr (storeOp->getOpOperand (3 ), layoutMask);
706+ xegpu::setDistributeLayoutAttr (storeOp.getValueMutable (), layoutVal);
707+ xegpu::setDistributeLayoutAttr (storeOp.getOffsetsMutable ()[0 ],
708+ layoutIndices);
709+ xegpu::setDistributeLayoutAttr (storeOp.getMaskMutable (), layoutMask);
711710 rewriter.eraseOp (scatterOp);
712711 return success ();
713712 }
0 commit comments