@@ -765,6 +765,110 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   }
 };
 
+// This pattern transforms a LoadGatherOp with explicit offsets into
+// subgroup-level loads.
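+// The offsets and mask operands arrive 1:N distributed from the type
+// converter; one subgroup-sized load is emitted per distributed pair.
+// Illustrative sketch (assumed layout): with sg_layout = [4], a
+// vector<128xf32> workgroup gather becomes four vector<32xf32> subgroup
+// gathers.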
+struct WgToSgLoadGatherOpWithOffset
+    : public OpConversionPattern<xegpu::LoadGatherOp> {
+  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getOffsets())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType)
+      return failure();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
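+    // Illustrative sketch (assumed shapes): wgShape = [256] with
+    // sg_data = [32] gives sgShape = [32], so each subgroup produces a
+    // 32-element result vector.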
+
+    // The offsets and mask must already have been distributed to
+    // subgroup-sized vectors of matching shape.
+    auto offsetsVecType =
+        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
+    auto maskVecType =
+        dyn_cast<VectorType>(adaptor.getMask().front().getType());
+    if (!offsetsVecType || !maskVecType ||
+        offsetsVecType.getShape() != maskVecType.getShape()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "offsets have not been distributed");
+    }
+
+    SmallVector<Value> newLoadOps;
+    auto chunkSizeAttr =
+        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
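+    // chunk_size gives the number of contiguous elements accessed per
+    // offset; value_or(1) keeps the default single-element-per-offset
+    // gather when the attribute is absent.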
+    VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
+    for (auto [offsets, mask] :
+         llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
+      auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
+          loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+      xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
+                                     layout.dropSgLayoutAndData());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return success();
+  }
+};
+
+// This pattern transforms a StoreScatterOp with explicit offsets into
+// subgroup-level stores.
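+// This mirrors WgToSgLoadGatherOpWithOffset: one subgroup-sized store is
+// emitted per distributed (value, offsets, mask) triple.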
+struct WgToSgStoreScatterOpWithOffset
+    : public OpConversionPattern<xegpu::StoreScatterOp> {
+  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getOffsets())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
+    if (!valueType)
+      return failure();
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getValue());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    // The offsets and mask must already have been distributed to
+    // subgroup-sized vectors of matching shape.
+    auto offsetsVecType =
+        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
+    auto maskVecType =
+        dyn_cast<VectorType>(adaptor.getMask().front().getType());
+    if (!offsetsVecType || !maskVecType ||
+        offsetsVecType.getShape() != maskVecType.getShape()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "offsets have not been distributed");
+    }
+
+    auto chunkSizeAttr =
+        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
+    for (auto [val, offs, mask] : llvm::zip(
+             adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
+      auto newStoreOp = rewriter.create<xegpu::StoreScatterOp>(
+          loc, val, op.getDest(), offs, mask, chunkSizeAttr,
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+      // Drop sg_layout and sg_data on the newly created subgroup-level
+      // store; setting the attribute on `op` would be lost when it is
+      // erased below.
+      if (auto newLayout = layout.dropSgLayoutAndData())
+        newStoreOp->setAttr("layout", newLayout);
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
   using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
   LogicalResult
@@ -826,8 +930,9 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
                WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
                WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
-               WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
-      patterns.getContext());
+               WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
+               WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
+               WgToSgStoreMatrixOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -952,6 +1057,21 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
       });
 
+  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
+      [=](xegpu::LoadGatherOp op) -> bool {
+        auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
+        return isLegal(layout);
+      });
+
+  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
+      [=](xegpu::StoreScatterOp op) -> bool {
+        // StoreScatterOp has no results, so check the layout attribute on
+        // the op itself. Without a workgroup layout the op is already legal.
+        auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
+        if (!layout)
+          return true;
+        return isLegal(layout);
+      });
+
   target.addDynamicallyLegalOp<vector::BroadcastOp>(
       [=](vector::BroadcastOp op) -> bool {
         return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));