Skip to content

Commit 1a160b8

Browse files
committed
use 'op.get*Mutable()' when getting layout attrs
Signed-off-by: dchigarev <[email protected]>
1 parent 79e37d8 commit 1a160b8

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -636,15 +636,12 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
636636
computeOffsets(rewriter, gatherOp, meta.first, meta.second);
637637
Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
638638

639-
auto numOffsets = gatherOp.getOffsets().size();
640639
auto layoutRes = xegpu::getDistributeLayoutAttr(gatherOp.getResult());
641640
auto layoutIndices =
642-
xegpu::getDistributeLayoutAttr(gatherOp->getOpOperand(numOffsets + 1));
643-
auto layoutMask =
644-
xegpu::getDistributeLayoutAttr(gatherOp->getOpOperand(numOffsets + 2));
641+
xegpu::getDistributeLayoutAttr(gatherOp.getIndicesMutable());
642+
auto layoutMask = xegpu::getDistributeLayoutAttr(gatherOp.getMaskMutable());
645643
auto layoutPassThru =
646-
xegpu::getDistributeLayoutAttr(gatherOp->getOpOperand(numOffsets + 3));
647-
644+
xegpu::getDistributeLayoutAttr(gatherOp.getPassThruMutable());
648645
SmallVector<xegpu::CachePolicyAttr, 3> cacheHints =
649646
getOpCacheHints(gatherOp);
650647
auto xeGatherOp = xegpu::LoadGatherOp::create(
@@ -689,13 +686,12 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
689686
computeOffsets(rewriter, scatterOp, meta.first, meta.second);
690687
Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
691688

692-
auto numOffsets = scatterOp.getOffsets().size();
693689
auto layoutIndices =
694-
xegpu::getDistributeLayoutAttr(scatterOp->getOpOperand(numOffsets + 1));
690+
xegpu::getDistributeLayoutAttr(scatterOp.getIndicesMutable());
695691
auto layoutMask =
696-
xegpu::getDistributeLayoutAttr(scatterOp->getOpOperand(numOffsets + 2));
692+
xegpu::getDistributeLayoutAttr(scatterOp.getMaskMutable());
697693
auto layoutVal =
698-
xegpu::getDistributeLayoutAttr(scatterOp->getOpOperand(numOffsets + 3));
694+
xegpu::getDistributeLayoutAttr(scatterOp.getValueToStoreMutable());
699695
SmallVector<xegpu::CachePolicyAttr, 3> cacheHints =
700696
getOpCacheHints(scatterOp);
701697
auto storeOp = xegpu::StoreScatterOp::create(

0 commit comments

Comments
 (0)