@@ -98,17 +98,18 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
9898}
9999
100100// Extract cache hints from the op attributes if available.
101- static void getOpCacheHints (Operation *op,
102- SmallVector<xegpu::CachePolicyAttr, 3 > &hints) {
103- assert (hints. size () == 3 &&
104- " Expecting a vector of size 3 for l1, l2, l3 hints. " ) ;
101+ static SmallVector<xegpu::CachePolicyAttr, 3 > getOpCacheHints (Operation *op) {
102+ SmallVector<xegpu::CachePolicyAttr, 3 > cacheHints{xegpu::CachePolicyAttr{},
103+ xegpu::CachePolicyAttr{},
104+ xegpu::CachePolicyAttr{}} ;
105105 // get l1, l2, l3 hints from attributes if available.
106106 if (auto l1Attr = op->getAttrOfType <xegpu::CachePolicyAttr>(" l1_hint" ))
107- hints [0 ] = l1Attr;
107+ cacheHints [0 ] = l1Attr;
108108 if (auto l2Attr = op->getAttrOfType <xegpu::CachePolicyAttr>(" l2_hint" ))
109- hints [1 ] = l2Attr;
109+ cacheHints [1 ] = l2Attr;
110110 if (auto l3Attr = op->getAttrOfType <xegpu::CachePolicyAttr>(" l3_hint" ))
111- hints[2 ] = l3Attr;
111+ cacheHints[2 ] = l3Attr;
112+ return cacheHints;
112113}
113114
114115static xegpu::CreateNdDescOp
@@ -389,28 +390,25 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
389390 }
390391 Value indices = gatScatOp.getIndices ();
391392 // Extract indices layout and propagate it to all 'vector' ops created here
392- auto indicesLayout = mlir:: xegpu::getDistributeLayoutAttr (indices);
393+ auto indicesLayout = xegpu::getDistributeLayoutAttr (indices);
393394 VectorType vecType = cast<VectorType>(indices.getType ());
394395
395396 auto strideVector =
396397 vector::BroadcastOp::create (rewriter, loc, vecType, strides.back ());
397- mlir::xegpu::setDistributeLayoutAttr (strideVector->getOpResult (0 ),
398- indicesLayout);
398+ xegpu::setDistributeLayoutAttr (strideVector->getOpResult (0 ), indicesLayout);
399399
400400 auto stridedIndices =
401401 arith::MulIOp::create (rewriter, loc, strideVector.getResult (), indices);
402- mlir::xegpu::setDistributeLayoutAttr (stridedIndices->getOpResult (0 ),
403- indicesLayout);
402+ xegpu::setDistributeLayoutAttr (stridedIndices->getOpResult (0 ), indicesLayout);
404403
405404 auto baseVector = vector::BroadcastOp::create (
406405 rewriter, loc,
407406 VectorType::get (vecType.getShape (), rewriter.getIndexType ()), baseOffset);
408- mlir::xegpu::setDistributeLayoutAttr (baseVector->getOpResult (0 ),
409- indicesLayout);
407+ xegpu::setDistributeLayoutAttr (baseVector->getOpResult (0 ), indicesLayout);
410408
411409 auto result = arith::AddIOp::create (rewriter, loc, baseVector.getResult (),
412410 stridedIndices.getResult ());
413- mlir:: xegpu::setDistributeLayoutAttr (result->getOpResult (0 ), indicesLayout);
411+ xegpu::setDistributeLayoutAttr (result->getOpResult (0 ), indicesLayout);
414412 return result.getResult ();
415413}
416414
@@ -639,37 +637,33 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
639637 Value flatMemref = memrefToIndexPtr (gatherOp, rewriter);
640638
641639 auto numOffsets = gatherOp.getOffsets ().size ();
642- auto layoutRes = mlir::xegpu::getDistributeLayoutAttr (gatherOp.getResult ());
643- auto layoutIndices = mlir::xegpu::getDistributeLayoutAttr (
644- gatherOp->getOpOperand (numOffsets + 1 ));
645- auto layoutMask = mlir::xegpu::getDistributeLayoutAttr (
646- gatherOp->getOpOperand (numOffsets + 2 ));
647- auto layoutPassThru = mlir::xegpu::getDistributeLayoutAttr (
648- gatherOp->getOpOperand (numOffsets + 3 ));
649-
650- SmallVector<xegpu::CachePolicyAttr, 3 > cacheHints{xegpu::CachePolicyAttr{},
651- xegpu::CachePolicyAttr{},
652- xegpu::CachePolicyAttr{}};
653- getOpCacheHints (gatherOp, cacheHints);
640+ auto layoutRes = xegpu::getDistributeLayoutAttr (gatherOp.getResult ());
641+ auto layoutIndices =
642+ xegpu::getDistributeLayoutAttr (gatherOp->getOpOperand (numOffsets + 1 ));
643+ auto layoutMask =
644+ xegpu::getDistributeLayoutAttr (gatherOp->getOpOperand (numOffsets + 2 ));
645+ auto layoutPassThru =
646+ xegpu::getDistributeLayoutAttr (gatherOp->getOpOperand (numOffsets + 3 ));
647+
648+ SmallVector<xegpu::CachePolicyAttr, 3 > cacheHints =
649+ getOpCacheHints (gatherOp);
654650 auto xeGatherOp = xegpu::LoadGatherOp::create (
655651 rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask (),
656652 /* chunk_size=*/ IntegerAttr{},
657653 /* l1_hint=*/ cacheHints[0 ],
658654 /* l2_hint=*/ cacheHints[1 ],
659655 /* l3_hint=*/ cacheHints[2 ]);
660- mlir::xegpu::setDistributeLayoutAttr (xeGatherOp->getOpResult (0 ), layoutRes);
661- mlir::xegpu::setDistributeLayoutAttr (xeGatherOp->getOpOperand (1 ),
662- layoutIndices);
663- mlir::xegpu::setDistributeLayoutAttr (xeGatherOp->getOpOperand (2 ),
664- layoutMask);
656+ xegpu::setDistributeLayoutAttr (xeGatherOp->getOpResult (0 ), layoutRes);
657+ xegpu::setDistributeLayoutAttr (xeGatherOp->getOpOperand (1 ), layoutIndices);
658+ xegpu::setDistributeLayoutAttr (xeGatherOp->getOpOperand (2 ), layoutMask);
665659
666660 auto selectOp =
667661 arith::SelectOp::create (rewriter, loc, gatherOp.getMask (),
668662 xeGatherOp.getResult (), gatherOp.getPassThru ());
669- mlir:: xegpu::setDistributeLayoutAttr (selectOp->getOpOperand (0 ), layoutMask);
670- mlir:: xegpu::setDistributeLayoutAttr (selectOp->getOpOperand (2 ),
671- layoutPassThru);
672- mlir:: xegpu::setDistributeLayoutAttr (selectOp->getOpResult (0 ), layoutRes);
663+ xegpu::setDistributeLayoutAttr (selectOp->getOpOperand (0 ), layoutMask);
664+ xegpu::setDistributeLayoutAttr (selectOp->getOpOperand (1 ), layoutRes);
665+ xegpu::setDistributeLayoutAttr (selectOp-> getOpOperand ( 2 ), layoutPassThru);
666+ xegpu::setDistributeLayoutAttr (selectOp->getOpResult (0 ), layoutRes);
673667
674668 rewriter.replaceOp (gatherOp, selectOp.getResult ());
675669 return success ();
@@ -696,27 +690,24 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
696690 Value flatMemref = memrefToIndexPtr (scatterOp, rewriter);
697691
698692 auto numOffsets = scatterOp.getOffsets ().size ();
699- auto layoutIndices = mlir::xegpu::getDistributeLayoutAttr (
700- scatterOp->getOpOperand (numOffsets + 1 ));
701- auto layoutMask = mlir::xegpu::getDistributeLayoutAttr (
702- scatterOp->getOpOperand (numOffsets + 2 ));
703- auto layoutVal = mlir::xegpu::getDistributeLayoutAttr (
704- scatterOp->getOpOperand (numOffsets + 3 ));
705- SmallVector<xegpu::CachePolicyAttr, 3 > cacheHints{xegpu::CachePolicyAttr{},
706- xegpu::CachePolicyAttr{},
707- xegpu::CachePolicyAttr{}};
708- getOpCacheHints (scatterOp, cacheHints);
693+ auto layoutIndices =
694+ xegpu::getDistributeLayoutAttr (scatterOp->getOpOperand (numOffsets + 1 ));
695+ auto layoutMask =
696+ xegpu::getDistributeLayoutAttr (scatterOp->getOpOperand (numOffsets + 2 ));
697+ auto layoutVal =
698+ xegpu::getDistributeLayoutAttr (scatterOp->getOpOperand (numOffsets + 3 ));
699+ SmallVector<xegpu::CachePolicyAttr, 3 > cacheHints =
700+ getOpCacheHints (scatterOp);
709701 auto storeOp = xegpu::StoreScatterOp::create (
710702 rewriter, loc, scatterOp.getValueToStore (), flatMemref, localOffsets,
711703 scatterOp.getMask (),
712704 /* chunk_size=*/ IntegerAttr{},
713705 /* l1_hint=*/ cacheHints[0 ],
714706 /* l2_hint=*/ cacheHints[1 ],
715707 /* l3_hint=*/ cacheHints[2 ]);
716- mlir::xegpu::setDistributeLayoutAttr (storeOp->getOpOperand (0 ), layoutVal);
717- mlir::xegpu::setDistributeLayoutAttr (storeOp->getOpOperand (2 ),
718- layoutIndices);
719- mlir::xegpu::setDistributeLayoutAttr (storeOp->getOpOperand (3 ), layoutMask);
708+ xegpu::setDistributeLayoutAttr (storeOp->getOpOperand (0 ), layoutVal);
709+ xegpu::setDistributeLayoutAttr (storeOp->getOpOperand (2 ), layoutIndices);
710+ xegpu::setDistributeLayoutAttr (storeOp->getOpOperand (3 ), layoutMask);
720711 rewriter.eraseOp (scatterOp);
721712 return success ();
722713 }
0 commit comments