@@ -726,7 +726,6 @@ struct UnrealizedConversionCastOpPattern
  }
};

-// This pattern distributes arith.constant op into subgroup-level constants
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

@@ -756,8 +755,15 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
    auto sgAttr = DenseElementsAttr::get(newType, singleVal);
    auto cstOp =
        arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (auto newLayout = layout.dropSgLayoutAndData())
-      xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+    if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+      if (sliceAttr.isForSubgroup())
+        xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
+                                       sliceAttr.dropSgLayoutAndData());
+    } else if (auto layoutAttr =
+                   dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
+      if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+        xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+    }
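+    // Illustrative sketch (not taken from the patch): for a constant
+    // annotated with a plain layout such as
+    //   #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32],
+    //                 lane_layout = [1, 16], lane_data = [1, 1]>
+    // dropSgLayoutAndData() keeps only the lane-level fields, while a
+    // #xegpu.slice attribute is unwrapped the same way only while it still
+    // carries subgroup-level information (isForSubgroup()).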
    SmallVector<Value> newConsts(count, cstOp);

    rewriter.replaceOpWithMultiple(op, {newConsts});
@@ -815,6 +821,191 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
  }
};

+// Pattern to distribute vector.multi_dim_reduction op to subgroup level.
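+// A rough before/after sketch (illustrative only; shapes and layout values
+// are assumptions, not taken from a test):
+//
+//   // Workgroup level, parent layout sg_layout = [8, 4], sg_data = [32, 32],
+//   // sliced along dims = [0]:
+//   %r = vector.multi_dim_reduction <add>, %src, %acc [0]
+//       : vector<256x128xf32> to vector<128xf32>
+//
+// becomes, per subgroup: a local multi_dim_reduction over the 32x32 tile, a
+// shape_cast to 1x32, a store_matrix into SLM, a barrier, a load_matrix of
+// the partial results, and a second multi_dim_reduction producing this
+// subgroup's slice of the final vector<128xf32>.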
+struct WgToSgMultiDimReductionOp
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Only support reductions with a layout and a single reduced dimension
+    // for now.
+    VectorType srcType = dyn_cast<VectorType>(op.getSource().getType());
+    VectorType accType = dyn_cast<VectorType>(op.getAcc().getType());
+    VectorType resType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!srcType || !accType || !resType)
+      return failure();
+    Type elemTy = srcType.getElementType();
+    ArrayRef<int64_t> wgShape = resType.getShape();
+    // The result is expected to carry a rank-1 SliceAttr, i.e. the parent
+    // layout sliced along the reduced dimension.
+    auto layoutName = xegpu::getLayoutName(op->getResult(0));
+    auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
+    if (!sliceAttr || sliceAttr.getRank() != 1)
+      return failure();
+
+    SmallVector<int64_t> dims =
+        llvm::to_vector(sliceAttr.getDims().asArrayRef());
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, sliceAttr).first;
+
+    int64_t reduceDim = dims[0];
+
+    // Step 1: Subgroup-level reduction.
+    // Each subgroup reduces its local tile, then shape-casts the result to
+    // reintroduce a unit dimension along the reduced axis so the partial
+    // results can be laid out in SLM.
+    SmallVector<Value> newReductions;
+    VectorType newType = VectorType::get(sgShape, srcType.getElementType());
+    SmallVector<int64_t> shapeCastShape = sgShape;
+    if (reduceDim == 0)
+      shapeCastShape.insert(shapeCastShape.begin(), 1);
+    else
+      shapeCastShape.push_back(1);
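+    // Worked example (assumed shapes, matching the sketch above): with
+    // sgShape = [32] and reduceDim = 0, shapeCastShape becomes [1, 32]; each
+    // subgroup's partial result is a vector<1x32xf32> tile.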
+    for (auto [sgSrc, sgAcc] :
+         llvm::zip(adaptor.getSource(), adaptor.getAcc())) {
+      auto sgReduce = rewriter.create<vector::MultiDimReductionOp>(
+          op.getLoc(), newType, op.getKind(), sgSrc, sgAcc,
+          op.getReductionDims());
+      // Keep a unit dimension in place of the reduced one so the partial
+      // result can be addressed in SLM.
+      auto shapeCastTy =
+          VectorType::get(shapeCastShape, srcType.getElementType());
+      auto shapeCast = rewriter.create<vector::ShapeCastOp>(
+          op.getLoc(), shapeCastTy, sgReduce.getResult());
+      newReductions.push_back(shapeCast.getResult());
+    }
+
+    rewriter.setInsertionPoint(op);
+
+    // Get the subgroup layout of the source tensor.
+    SmallVector<int64_t> sgLayoutParent =
+        sliceAttr.getParent().getSgLayoutAsInt();
+
+    // Allocate SLM large enough to hold one partial result per subgroup
+    // along the reduced dimension.
+    auto bitWidth = elemTy.getIntOrFloatBitWidth();
+    auto flattenFactor = bitWidth / 8;
+    auto slmSize =
+        resType.getNumElements() * sgLayoutParent[reduceDim] * flattenFactor;
+    auto slmTy = MemRefType::get(slmSize, rewriter.getI8Type(), {}, 3);
+    auto slm = rewriter.create<memref::AllocaOp>(loc, slmTy);
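+    // Worked example (assumed shapes): f32 gives flattenFactor = 4; with a
+    // vector<128xf32> result and sgLayoutParent[reduceDim] = 8, slmSize is
+    // 128 * 8 * 4 = 4096 bytes in address space 3 (shared local memory).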
+
+    // Create a view for the SLM buffer using xegpu.create_mem_desc.
+    SmallVector<int64_t> viewShape;
+    auto srcVecType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
+    ArrayRef<int64_t> srcShape =
+        srcVecType ? srcVecType.getShape() : ArrayRef<int64_t>();
+    for (size_t i = 0; i < srcShape.size(); ++i) {
+      if (static_cast<int64_t>(i) == reduceDim) {
+        // Along the reduced dimension, one row of partial results per
+        // subgroup: use sgLayoutParent[i].
+        viewShape.push_back(sgLayoutParent[i]);
+      } else {
+        // For other dimensions, multiply sgLayoutParent[i] by srcShape[i].
+        viewShape.push_back(sgLayoutParent[i] * srcShape[i]);
+      }
+    }
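+    // Worked example (assumed shapes): sgLayoutParent = [8, 4], per-subgroup
+    // srcShape = [32, 32], reduceDim = 0 -> viewShape = [8, 4 * 32] =
+    // [8, 128], i.e. 8 rows of partial results across 4 subgroup columns.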
+
+    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), viewShape,
+                                               elemTy, nullptr);
+    auto memDesc =
+        rewriter.create<xegpu::CreateMemDescOp>(loc, memDescType, slm);
+
+    // Step 2: Store subgroup results to SLM (shared local memory).
+    // SLM layout: sgLayout is the same as the source layout, sgData is
+    // shapeCastShape.
+    SmallVector<int64_t> slmSgData = shapeCastShape;
+
+    // Get the subgroup id and delinearize it over the subgroup layout.
+    auto sgId = rewriter.create<gpu::SubgroupIdOp>(loc, rewriter.getIndexType(),
+                                                   nullptr);
+
+    SmallVector<Value> srcSgLayoutDim(sgLayoutParent.size());
+
+    for (size_t i = 0; i < sgLayoutParent.size(); i++) {
+      srcSgLayoutDim[i] =
+          arith::ConstantIndexOp::create(rewriter, loc, sgLayoutParent[i]);
+    }
+
+    auto sgIdVec =
+        affine::delinearizeIndex(rewriter, loc, sgId, srcSgLayoutDim);
+    if (failed(sgIdVec))
+      return failure();
+    SmallVector<Value> sgIds = *sgIdVec;
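+    // Worked example (assumed layout): with sgLayoutParent = [8, 4], linear
+    // subgroup id 13 delinearizes to sgIds = [3, 1] (row 3, column 1).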
+
+    // Calculate the per-subgroup offsets for store_matrix.
+    SmallVector<OpFoldResult> slmStoreOffsets;
+    for (size_t i = 0; i < sgLayoutParent.size(); ++i) {
+      Value offset = rewriter.createOrFold<index::MulOp>(
+          loc, sgIds[i],
+          arith::ConstantIndexOp::create(rewriter, loc, slmSgData[i]));
+      slmStoreOffsets.push_back(offset);
+    }
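+    // Worked example (assumed shapes): slmSgData = [1, 32], so subgroup
+    // [3, 1] stores its vector<1x32xf32> tile at offsets [3, 32] of the
+    // 8x128 SLM view.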
+
+    // Store the subgroup result to SLM.
+    rewriter.create<xegpu::StoreMatrixOp>(
+        loc, newReductions[0], memDesc.getResult(),
+        ArrayRef<OpFoldResult>(slmStoreOffsets),
+        /*layout=*/nullptr);
+
+    // Barrier to synchronize subgroups.
+    rewriter.create<gpu::BarrierOp>(loc);
+
+    // Step 3: Load the partial results back from SLM for the second
+    // reduction. Each subgroup loads the full extent of the reduced
+    // dimension and its own slice of the remaining dimensions.
+    SmallVector<int64_t> slmLoadShape;
+
+    for (size_t i = 0; i < viewShape.size(); ++i) {
+      if (static_cast<int64_t>(i) == reduceDim) {
+        slmLoadShape.push_back(viewShape[i]);
+      } else {
+        int64_t divisor = computeProduct(sgLayoutParent);
+        slmLoadShape.push_back(viewShape[i] / divisor);
+      }
+    }
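+    // Worked example (assumed shapes): viewShape = [8, 128] and 8 * 4 = 32
+    // subgroups give slmLoadShape = [8, 128 / 32] = [8, 4]; the second
+    // reduction then collapses the 8 partial rows into a vector<4xf32>.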
+
+    // Calculate the per-subgroup offsets for load_matrix.
+    SmallVector<OpFoldResult> slmLoadOffsets;
+    for (size_t i = 0; i < sgLayoutParent.size(); ++i) {
+      Value offset = rewriter.createOrFold<index::MulOp>(
+          loc, sgIds[i],
+          arith::ConstantIndexOp::create(rewriter, loc, slmLoadShape[i]));
+      slmLoadOffsets.push_back(offset);
+    }
+
+    auto load = rewriter.create<xegpu::LoadMatrixOp>(
+        loc, VectorType::get(slmLoadShape, elemTy), memDesc,
+        llvm::ArrayRef<OpFoldResult>({slmLoadOffsets}),
+        /*layout=*/nullptr);
+
+    // Step 4: Create a constant accumulator for the second reduction with the
+    // same splat value as adaptor.getAcc()[0], shaped like the non-reduced
+    // dimensions of the loaded vector.
+    auto accShape = load.getType().getShape();
+    SmallVector<int64_t> accShapeWithoutReduceDim;
+    for (size_t i = 0; i < accShape.size(); ++i) {
+      if (static_cast<int64_t>(i) != reduceDim)
+        accShapeWithoutReduceDim.push_back(accShape[i]);
+    }
+    auto accTy = VectorType::get(accShapeWithoutReduceDim, elemTy);
+    auto accConstOp = adaptor.getAcc()[0].getDefiningOp<arith::ConstantOp>();
+    Attribute accSplatValue;
+    if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(
+            accConstOp ? accConstOp.getValue() : nullptr)) {
+      accSplatValue =
+          denseAttr.isSplat() ? denseAttr.getSplatValue<Attribute>() : nullptr;
+    }
+    if (!accSplatValue)
+      return failure();
+    auto accValue = rewriter.create<arith::ConstantOp>(
+        loc, accTy, DenseElementsAttr::get(accTy, accSplatValue));
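+    // Worked example (assumed values): a splat 0.0 accumulator and
+    // accShapeWithoutReduceDim = [4] yield
+    //   arith.constant dense<0.0> : vector<4xf32>.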
+    // Step 5: Perform the second reduction across the partial results.
+    VectorType secondReduceVecType =
+        VectorType::get(accShapeWithoutReduceDim, srcType.getElementType());
+    auto secondReduce = rewriter.create<vector::MultiDimReductionOp>(
+        loc, secondReduceVecType, op.getKind(), load, accValue,
+        op.getReductionDims());
+    rewriter.replaceOpWithMultiple(op, {secondReduce.getResult()});
+    return success();
+  }
+};
+
} // namespace

namespace mlir {
@@ -826,8 +1017,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
      WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
      WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
      WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
-      WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
-      patterns.getContext());
+      WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp,
+      WgToSgMultiDimReductionOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -987,6 +1178,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
        return isLegal(layout);
      });

+  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+      [=](vector::MultiDimReductionOp op) -> bool {
+        // Only distribute MultiDimReductionOp with a single reduction
+        // dimension; anything else remains legal and is left untouched.
+        if (op.getReductionDims().size() != 1)
+          return true;
+
+        // Otherwise the op is legal only once its layout has been
+        // distributed.
+        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+      });
+
  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
      [=](UnrealizedConversionCastOp op) {
        return llvm::is_contained(existingCastOps, op.getOperation());