@@ -830,16 +830,17 @@ struct WgToSgMultiDimReductionOp
830830 matchAndRewrite (vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
831831 ConversionPatternRewriter &rewriter) const override {
832832 Location loc = op.getLoc ();
833- // Only support reduction with layout and on a single dimension for now.
834833 VectorType srcType = dyn_cast<VectorType>(op.getSource ().getType ());
835834 VectorType accType = dyn_cast<VectorType>(op.getAcc ().getType ());
836835 VectorType resType = dyn_cast<VectorType>(op.getResult ().getType ());
837836 Type elemTy = srcType.getElementType ();
838837 if (!srcType || !accType || !resType)
839838 return failure ();
840839
840+ // Support only 2D vectors
841+ if (srcType.getShape ().size () != 2 && resType.getShape ().size () != 1 )
842+ return failure ();
841843 ArrayRef<int64_t > wgShape = resType.getShape ();
842- // Handle both LayoutAttr and SliceAttr for the op result.
843844 auto layoutName = xegpu::getLayoutName (op->getResult (0 ));
844845 auto sliceAttr = op->getAttrOfType <xegpu::SliceAttr>(layoutName);
845846 if (!sliceAttr || sliceAttr.getRank () != 1 )
@@ -871,7 +872,6 @@ struct WgToSgMultiDimReductionOp
871872 VectorType::get (shapeCastShape, srcType.getElementType ());
872873 auto shapeCast = rewriter.create <vector::ShapeCastOp>(
873874 op.getLoc (), shapeCastTy, sgReduce.getResult ());
874- // TODO: Change it to shapeCast
875875 newReductions.push_back (shapeCast.getResult ());
876876 }
877877
@@ -889,23 +889,23 @@ struct WgToSgMultiDimReductionOp
889889 auto slmTy = MemRefType::get (slmSize, rewriter.getI8Type (), {}, 3 );
890890 auto slm = rewriter.create <memref::AllocaOp>(loc, slmTy);
891891
892- // Create a view for the SLM buffer using xegpu.create_mem_desc
893- SmallVector<int64_t > viewShape ;
892+ // Create a SLM buffer using xegpu.create_mem_desc
893+ SmallVector<int64_t > memDescShape ;
894894 auto srcVecType = dyn_cast<VectorType>(adaptor.getSource ()[0 ].getType ());
895895 ArrayRef<int64_t > srcShape =
896896 srcVecType ? srcVecType.getShape () : ArrayRef<int64_t >();
897897 for (size_t i = 0 ; i < srcShape.size (); ++i) {
898898 if (static_cast <int64_t >(i) == reduceDim) {
899899 // For the reduced dimension, use sgLayoutParent[i]
900- viewShape .push_back (sgLayoutParent[i]);
900+ memDescShape .push_back (sgLayoutParent[i]);
901901 } else {
902902 // For other dimensions, multiply sgLayoutParent[i] by sgShape[i]
903- viewShape .push_back (sgLayoutParent[i] * srcShape[i]);
903+ memDescShape .push_back (sgLayoutParent[i] * srcShape[i]);
904904 }
905905 }
906906
907- auto memDescType = xegpu::MemDescType::get (rewriter.getContext (), viewShape,
908- elemTy, nullptr );
907+ auto memDescType = xegpu::MemDescType::get (rewriter.getContext (),
908+ memDescShape, elemTy, nullptr );
909909 auto memDesc =
910910 rewriter.create <xegpu::CreateMemDescOp>(loc, memDescType, slm);
911911
@@ -951,16 +951,16 @@ struct WgToSgMultiDimReductionOp
951951 // Step 3: Load from SLM for the second reduction
952952 SmallVector<int64_t > slmLoadShape;
953953
954- for (size_t i = 0 ; i < viewShape .size (); ++i) {
954+ for (size_t i = 0 ; i < memDescShape .size (); ++i) {
955955 if (static_cast <int64_t >(i) == reduceDim) {
956- slmLoadShape.push_back (viewShape [i]);
956+ slmLoadShape.push_back (memDescShape [i]);
957957 } else {
958958 int64_t divisor = computeProduct (sgLayoutParent);
959- slmLoadShape.push_back (viewShape [i] / divisor);
959+ slmLoadShape.push_back (memDescShape [i] / divisor);
960960 }
961961 }
962962
963- // Calculate offsets for create_nd_desc
963+ // Calculate offsets for load_matrix op
964964 SmallVector<OpFoldResult> slmLoadOffsets;
965965 for (size_t i = 0 ; i < sgLayoutParent.size (); ++i) {
966966 Value offset = rewriter.createOrFold <index::MulOp>(
@@ -975,8 +975,8 @@ struct WgToSgMultiDimReductionOp
975975 /* layout=*/ nullptr );
976976
977977 // Step 4: Create a constant accumulator for the second reduction
978- // with same vallue as adaptor.getAcc()[0] and shape set to
979- // the non reduce dimension of shapeCastLoad
978+ // with same value as adaptor.getAcc()[0] and shape set to
979+ // the non reduce dimension of load
980980 auto accShape = load.getType ().getShape ();
981981 SmallVector<int64_t > accShapeWithoutReduceDim;
982982 for (size_t i = 0 ; i < accShape.size (); ++i) {
@@ -1180,10 +1180,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
11801180
11811181 target.addDynamicallyLegalOp <vector::MultiDimReductionOp>(
11821182 [=](vector::MultiDimReductionOp op) -> bool {
1183- // Only allow MultiDimReductionOp with a single reduction dimension
1184- if (op.getReductionDims ().size () != 1 )
1185- return true ;
1186-
11871183 // Check if the layout is legal
11881184 return isLegal (xegpu::getDistributeLayoutAttr (op.getResult ()));
11891185 });
0 commit comments