@@ -685,6 +685,88 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
685685 }
686686};
687687
688+ // This pattern transforms the LoadGatherOp with explicit offsets to load
689+ // subgroup data, similar to WgToSgLoadNdOpWithOffset.
690+ struct WgToSgLoadGatherOpWithOffset
691+ : public OpConversionPattern<xegpu::LoadGatherOp> {
692+ using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
693+ LogicalResult
694+ matchAndRewrite (xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
695+ ConversionPatternRewriter &rewriter) const override {
696+
697+ if (!op.getOffsets ())
698+ return failure ();
699+
700+ Location loc = op.getLoc ();
701+ VectorType resultType = op.getResult ().getType ();
702+ ArrayRef<int64_t > wgShape = resultType.getShape ();
703+
704+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr (op.getResult ());
705+ if (!layout || !layout.getSgLayout ())
706+ return failure ();
707+
708+ SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
709+
710+ SmallVector<Value> newLoadOps;
711+ auto chunkSizeAttr = rewriter.getI64IntegerAttr (op.getChunkSize ().value_or (1 ));
712+ VectorType newTy = VectorType::get (sgShape, resultType.getElementType ());
713+ for (auto [offsets, mask] :
714+ llvm::zip (adaptor.getOffsets (), adaptor.getMask ())) {
715+ auto newLoadOp = rewriter.create <xegpu::LoadGatherOp>(
716+ loc, newTy, op.getSource (), offsets, mask,
717+ chunkSizeAttr, op.getL1HintAttr (), op.getL2HintAttr (),
718+ op.getL3HintAttr ());
719+ xegpu::setLayoutAttr (newLoadOp->getResult (0 ),
720+ layout.dropSgLayoutAndData ());
721+ newLoadOps.push_back (newLoadOp);
722+ }
723+ rewriter.replaceOpWithMultiple (op, {newLoadOps});
724+ return success ();
725+ }
726+ };
727+
728+ // This pattern transforms the StoreScatterOp with explicit offsets to store
729+ // subgroup data, similar to WgToSgStoreNdOpWithOffset.
730+ struct WgToSgStoreScatterOpWithOffset
731+ : public OpConversionPattern<xegpu::StoreScatterOp> {
732+ using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
733+ LogicalResult
734+ matchAndRewrite (xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
735+ ConversionPatternRewriter &rewriter) const override {
736+
737+ if (!op.getOffsets ())
738+ return failure ();
739+
740+ Location loc = op.getLoc ();
741+ VectorType valueType = dyn_cast<VectorType>(op.getValue ().getType ());
742+ if (!valueType)
743+ return failure ();
744+
745+ ArrayRef<int64_t > wgShape = valueType.getShape ();
746+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr (op.getValue ());
747+ if (!layout || !layout.getSgLayout ())
748+ return failure ();
749+
750+ auto chunkSizeOpt = op.getChunkSize ();
751+ int64_t chunkSize = chunkSizeOpt ? static_cast <int64_t >(*chunkSizeOpt) : 1 ;
752+ auto chunkSizeAttr = rewriter.getI64IntegerAttr (chunkSize);
753+ for (auto [val, offs, mask] : llvm::zip (
754+ adaptor.getValue (), adaptor.getOffsets (), adaptor.getMask ())) {
755+ rewriter.create <xegpu::StoreScatterOp>(
756+ loc, val, op.getDest (), offs, mask, chunkSizeAttr,
757+ op.getL1HintAttr (), op.getL2HintAttr (), op.getL3HintAttr ());
758+ // Update the layout_result_0 attribute to drop sg_layout and sg_data.
759+ if (auto layoutAttr =
760+ op->getAttrOfType <xegpu::LayoutAttr>(" layout_result_0" )) {
761+ if (auto newLayout = layoutAttr.dropSgLayoutAndData ())
762+ op->setAttr (" layout_result_0" , newLayout);
763+ }
764+ }
765+ rewriter.eraseOp (op);
766+ return success ();
767+ }
768+ };
769+
688770} // namespace
689771
690772namespace mlir {
@@ -694,7 +776,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
694776 WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
695777 WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
696778 WgToSgElementwiseOp, WgToSgVectorBroadcastOp,
697- WgToSgConvertLayoutOp, WgToSgArithConstantOp>(
779+ WgToSgConvertLayoutOp, WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
780+ WgToSgStoreScatterOpWithOffset>(
698781 patterns.getContext ());
699782}
700783} // namespace xegpu
@@ -815,6 +898,21 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
815898 return isLegal (xegpu::getLayoutAttr (op.getResult ()));
816899 });
817900
901+ target.addDynamicallyLegalOp <xegpu::LoadGatherOp>(
902+ [=](xegpu::LoadGatherOp op) -> bool {
903+ auto layout = xegpu::getLayoutAttr (op.getResult ());
904+ return isLegal (layout);
905+ });
906+
907+ target.addDynamicallyLegalOp <xegpu::StoreScatterOp>(
908+ [=](xegpu::StoreScatterOp op) -> bool {
909+ // Check if the layout attribute is present on the result.
910+ auto layout = op->getAttrOfType <xegpu::LayoutAttr>(" layout_result_0" );
911+ if (!layout)
912+ return true ;
913+ return isLegal (layout);
914+ });
915+
818916 target.addDynamicallyLegalOp <xegpu::ConvertLayoutOp>(
819917 [=](xegpu::ConvertLayoutOp op) -> bool {
820918 return isLegal (op.getInputLayout ()) && isLegal (op.getTargetLayout ());
0 commit comments