From c3e598690dc120de36bbe14500936dc4202299e5 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 29 Aug 2025 20:45:52 +0000 Subject: [PATCH 1/4] Add pattern for reduction --- .../Transforms/XeGPUWgToSgDistribute.cpp | 211 +++++++++++++++++- .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 13 ++ 2 files changed, 219 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 0b7fe81facfce..54a98970a0fc0 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -726,7 +726,6 @@ struct UnrealizedConversionCastOpPattern } }; -// This pattern distributes arith.constant op into subgroup-level constants struct WgToSgArithConstantOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -756,8 +755,15 @@ struct WgToSgArithConstantOp : public OpConversionPattern { auto sgAttr = DenseElementsAttr::get(newType, singleVal); auto cstOp = arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr); - if (auto newLayout = layout.dropSgLayoutAndData()) - xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout); + if (auto sliceAttr = dyn_cast_if_present(layout)) { + if (sliceAttr.isForSubgroup()) + xegpu::setDistributeLayoutAttr(cstOp->getResult(0), + sliceAttr.dropSgLayoutAndData()); + } else if (auto layoutAttr = + dyn_cast_if_present(layout)) { + if (auto newLayout = layoutAttr.dropSgLayoutAndData()) + xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout); + } SmallVector newConsts(count, cstOp); rewriter.replaceOpWithMultiple(op, {newConsts}); @@ -815,6 +821,191 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern { } }; +// Pattern to distribute vector.multi_dim_reduction op to subgroup level. +struct WgToSgMultiDimReductionOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + // Only support reduction with layout and on a single dimension for now. + VectorType srcType = dyn_cast(op.getSource().getType()); + VectorType accType = dyn_cast(op.getAcc().getType()); + VectorType resType = dyn_cast(op.getResult().getType()); + Type elemTy = srcType.getElementType(); + if (!srcType || !accType || !resType) + return failure(); + + ArrayRef wgShape = resType.getShape(); + // Handle both LayoutAttr and SliceAttr for the op result. + auto layoutName = xegpu::getLayoutName(op->getResult(0)); + auto sliceAttr = op->getAttrOfType(layoutName); + if (!sliceAttr || sliceAttr.getRank() != 1) + return failure(); + + SmallVector dims = + llvm::to_vector(sliceAttr.getDims().asArrayRef()); + SmallVector sgShape = getSgShapeAndCount(wgShape, sliceAttr).first; + + int64_t reduceDim = dims[0]; + + // Step 1: Subgroup-level reduction + // Each subgroup reduces its local tile. 
+ SmallVector newReductions; + VectorType newType = VectorType::get(sgShape, srcType.getElementType()); + SmallVector shapeCastShape = sgShape; + if (reduceDim == 0) + shapeCastShape.insert(shapeCastShape.begin(), 1); + else + shapeCastShape.push_back(1); + for (auto [sgSrc, sgAcc] : + llvm::zip(adaptor.getSource(), adaptor.getAcc())) { + auto sgReduce = rewriter.create( + op.getLoc(), newType, op.getKind(), sgSrc, sgAcc, + op.getReductionDims()); + // Compute the shape for the shape cast: set reducedDim to 1, keep other + // dims as sgShape + auto shapeCastTy = + VectorType::get(shapeCastShape, srcType.getElementType()); + auto shapeCast = rewriter.create( + op.getLoc(), shapeCastTy, sgReduce.getResult()); + // TODO: Change it to shapeCast + newReductions.push_back(shapeCast.getResult()); + } + + rewriter.setInsertionPoint(op); + + // Get layout of the source tensor + SmallVector sgLayoutParent = + sliceAttr.getParent().getSgLayoutAsInt(); + + // Allocate SLM + auto bitWidth = elemTy.getIntOrFloatBitWidth(); + auto flattenFactor = bitWidth / 8; + auto slmSize = + resType.getNumElements() * sgLayoutParent[reduceDim] * flattenFactor; + auto slmTy = MemRefType::get(slmSize, rewriter.getI8Type(), {}, 3); + auto slm = rewriter.create(loc, slmTy); + + // Create a view for the SLM buffer using xegpu.create_mem_desc + SmallVector viewShape; + auto srcVecType = dyn_cast(adaptor.getSource()[0].getType()); + ArrayRef srcShape = + srcVecType ? srcVecType.getShape() : ArrayRef(); + for (size_t i = 0; i < srcShape.size(); ++i) { + if (static_cast(i) == reduceDim) { + // For the reduced dimension, use sgLayoutParent[i] + viewShape.push_back(sgLayoutParent[i]); + } else { + // For other dimensions, multiply sgLayoutParent[i] by sgShape[i] + viewShape.push_back(sgLayoutParent[i] * srcShape[i]); + } + } + + auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), viewShape, + elemTy, nullptr); + auto memDesc = + rewriter.create(loc, memDescType, slm); + + // Step 2: Store subgroup results to SLM (shared local memory) + // SLM layout: sgLayout same as srcLayout, sgData is shapeCastShape + SmallVector slmSgData = shapeCastShape; + + // Get subgroup id and delinearize + auto sgId = rewriter.create(loc, rewriter.getIndexType(), + nullptr); + + SmallVector srcSgLayoutDim(sgLayoutParent.size()); + + for (size_t i = 0; i < sgLayoutParent.size(); i++) { + srcSgLayoutDim[i] = + arith::ConstantIndexOp::create(rewriter, loc, sgLayoutParent[i]); + } + + auto sgIdVec = + affine::delinearizeIndex(rewriter, loc, sgId, srcSgLayoutDim); + if (failed(sgIdVec)) + return failure(); + SmallVector sgIds = *sgIdVec; + + // Calculate offsets for store_matrix + SmallVector slmStoreOffsets; + for (size_t i = 0; i < sgLayoutParent.size(); ++i) { + Value offset = rewriter.createOrFold( + loc, sgIds[i], + arith::ConstantIndexOp::create(rewriter, loc, slmSgData[i])); + slmStoreOffsets.push_back(offset); + } + + // Store subgroup result to SLM + rewriter.create( + loc, newReductions[0], memDesc.getResult(), + ArrayRef(slmStoreOffsets), + /*layout=*/nullptr); + + // Barrier to synchronize subgroups + rewriter.create(loc); + + // Step 3: Load from SLM for the second reduction + SmallVector slmLoadShape; + + for (size_t i = 0; i < viewShape.size(); ++i) { + if (static_cast(i) == reduceDim) { + slmLoadShape.push_back(viewShape[i]); + } else { + int64_t divisor = computeProduct(sgLayoutParent); + slmLoadShape.push_back(viewShape[i] / divisor); + } + } + + // Calculate offsets for create_nd_desc + SmallVector slmLoadOffsets; + for 
(size_t i = 0; i < sgLayoutParent.size(); ++i) { + Value offset = rewriter.createOrFold( + loc, sgIds[i], + arith::ConstantIndexOp::create(rewriter, loc, slmLoadShape[i])); + slmLoadOffsets.push_back(offset); + } + + auto load = rewriter.create( + loc, VectorType::get(slmLoadShape, elemTy), memDesc, + llvm::ArrayRef({slmLoadOffsets}), + /*layout=*/nullptr); + + // Step 4: Create a constant accumulator for the second reduction + // with same vallue as adaptor.getAcc()[0] and shape set to + // the non reduce dimension of shapeCastLoad + auto accShape = load.getType().getShape(); + SmallVector accShapeWithoutReduceDim; + for (size_t i = 0; i < accShape.size(); ++i) { + if (static_cast(i) != reduceDim) + accShapeWithoutReduceDim.push_back(accShape[i]); + } + auto accTy = VectorType::get(accShapeWithoutReduceDim, elemTy); + auto accConstOp = adaptor.getAcc()[0].getDefiningOp(); + Attribute accSplatValue; + if (auto denseAttr = dyn_cast_or_null( + accConstOp ? accConstOp.getValue() : nullptr)) { + accSplatValue = + denseAttr.isSplat() ? denseAttr.getSplatValue() : nullptr; + } + if (!accSplatValue) + return failure(); + auto accValue = rewriter.create( + loc, accTy, DenseElementsAttr::get(accTy, accSplatValue)); + // Step 5: Perform the second reduction + VectorType secondReduceVecType = + VectorType::get(accShapeWithoutReduceDim, srcType.getElementType()); + auto secondReduce = rewriter.create( + loc, secondReduceVecType, op.getKind(), load, accValue, + op.getReductionDims()); + rewriter.replaceOpWithMultiple(op, {secondReduce.getResult()}); + return success(); + } +}; + } // namespace namespace mlir { @@ -826,8 +1017,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern, WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp, - WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>( - patterns.getContext()); + WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, + WgToSgMultiDimReductionOp>(patterns.getContext()); } } // namespace xegpu } // namespace mlir @@ -987,6 +1178,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(layout); }); + target.addDynamicallyLegalOp( + [=](vector::MultiDimReductionOp op) -> bool { + // Only allow MultiDimReductionOp with a single reduction dimension + if (op.getReductionDims().size() != 1) + return true; + + // Check if the layout is legal + return isLegal(xegpu::getDistributeLayoutAttr(op.getResult())); + }); + target.addDynamicallyLegalOp( [=](UnrealizedConversionCastOp op) { return llvm::is_contained(existingCastOps, op.getOperation()); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 32157a7911f62..1b417d752edcc 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -321,4 +321,17 @@ gpu.module @test_distribution { xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32> gpu.return } + + //CHECK-LABEL: vector_reduce + gpu.func @vector_reduce(%src: memref<256x128xf32>) { + %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} dense<1.0> : vector<128xf32> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout> + %load = xegpu.load_nd %tdesc + : 
!xegpu.tensor_desc<256x128xf32, #xegpu.layout> + -> vector<256x128xf32> + %reduce = vector.multi_reduction , %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} [0] + : vector<256x128xf32> to vector<128xf32> + gpu.return + } } From 100341dec307fbab0612abc49fe38d27239d177a Mon Sep 17 00:00:00 2001 From: nbpatel Date: Wed, 3 Sep 2025 17:23:28 +0000 Subject: [PATCH 2/4] Add CHECKS --- .../Transforms/XeGPUWgToSgDistribute.cpp | 34 ++++++++----------- .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 26 ++++++++++++++ 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 54a98970a0fc0..fe5026203ad34 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -830,7 +830,6 @@ struct WgToSgMultiDimReductionOp matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - // Only support reduction with layout and on a single dimension for now. VectorType srcType = dyn_cast(op.getSource().getType()); VectorType accType = dyn_cast(op.getAcc().getType()); VectorType resType = dyn_cast(op.getResult().getType()); @@ -838,8 +837,10 @@ struct WgToSgMultiDimReductionOp if (!srcType || !accType || !resType) return failure(); + // Support only 2D vectors + if (srcType.getShape().size() != 2 && resType.getShape().size() != 1) + return failure(); ArrayRef wgShape = resType.getShape(); - // Handle both LayoutAttr and SliceAttr for the op result. auto layoutName = xegpu::getLayoutName(op->getResult(0)); auto sliceAttr = op->getAttrOfType(layoutName); if (!sliceAttr || sliceAttr.getRank() != 1) @@ -871,7 +872,6 @@ struct WgToSgMultiDimReductionOp VectorType::get(shapeCastShape, srcType.getElementType()); auto shapeCast = rewriter.create( op.getLoc(), shapeCastTy, sgReduce.getResult()); - // TODO: Change it to shapeCast newReductions.push_back(shapeCast.getResult()); } @@ -889,23 +889,23 @@ struct WgToSgMultiDimReductionOp auto slmTy = MemRefType::get(slmSize, rewriter.getI8Type(), {}, 3); auto slm = rewriter.create(loc, slmTy); - // Create a view for the SLM buffer using xegpu.create_mem_desc - SmallVector viewShape; + // Create a SLM buffer using xegpu.create_mem_desc + SmallVector memDescShape; auto srcVecType = dyn_cast(adaptor.getSource()[0].getType()); ArrayRef srcShape = srcVecType ? 
srcVecType.getShape() : ArrayRef(); for (size_t i = 0; i < srcShape.size(); ++i) { if (static_cast(i) == reduceDim) { // For the reduced dimension, use sgLayoutParent[i] - viewShape.push_back(sgLayoutParent[i]); + memDescShape.push_back(sgLayoutParent[i]); } else { // For other dimensions, multiply sgLayoutParent[i] by sgShape[i] - viewShape.push_back(sgLayoutParent[i] * srcShape[i]); + memDescShape.push_back(sgLayoutParent[i] * srcShape[i]); } } - auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), viewShape, - elemTy, nullptr); + auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), + memDescShape, elemTy, nullptr); auto memDesc = rewriter.create(loc, memDescType, slm); @@ -951,16 +951,16 @@ struct WgToSgMultiDimReductionOp // Step 3: Load from SLM for the second reduction SmallVector slmLoadShape; - for (size_t i = 0; i < viewShape.size(); ++i) { + for (size_t i = 0; i < memDescShape.size(); ++i) { if (static_cast(i) == reduceDim) { - slmLoadShape.push_back(viewShape[i]); + slmLoadShape.push_back(memDescShape[i]); } else { int64_t divisor = computeProduct(sgLayoutParent); - slmLoadShape.push_back(viewShape[i] / divisor); + slmLoadShape.push_back(memDescShape[i] / divisor); } } - // Calculate offsets for create_nd_desc + // Calculate offsets for load_matrix op SmallVector slmLoadOffsets; for (size_t i = 0; i < sgLayoutParent.size(); ++i) { Value offset = rewriter.createOrFold( @@ -975,8 +975,8 @@ struct WgToSgMultiDimReductionOp /*layout=*/nullptr); // Step 4: Create a constant accumulator for the second reduction - // with same vallue as adaptor.getAcc()[0] and shape set to - // the non reduce dimension of shapeCastLoad + // with same value as adaptor.getAcc()[0] and shape set to + // the non reduce dimension of load auto accShape = load.getType().getShape(); SmallVector accShapeWithoutReduceDim; for (size_t i = 0; i < accShape.size(); ++i) { @@ -1180,10 +1180,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() { target.addDynamicallyLegalOp( [=](vector::MultiDimReductionOp op) -> bool { - // Only allow MultiDimReductionOp with a single reduction dimension - if (op.getReductionDims().size() != 1) - return true; - // Check if the layout is legal return isLegal(xegpu::getDistributeLayoutAttr(op.getResult())); }); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 1b417d752edcc..50bcc4341291e 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -323,7 +323,33 @@ gpu.module @test_distribution { } //CHECK-LABEL: vector_reduce + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> gpu.func @vector_reduce(%src: memref<256x128xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<32xf32> + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout> + // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout> -> vector<32x32xf32> + // CHECK: %[[REDUCE:.*]] = vector.multi_reduction , {{%.*}}, %[[CST]] [0] : vector<32x32xf32> to vector<32xf32> + // CHECK: %[[SHAPECAST:.*]] = vector.shape_cast %[[REDUCE]] : vector<32xf32> to vector<1x32xf32> + // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3> + // CHECK: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32> + // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index + // 
CHECK: %[[C8:.*]] = arith.constant 8 : index + // CHECK: %[[C4:.*]] = arith.constant 4 : index + // CHECK: %[[C4_1:.*]] = arith.constant 4 : index + // CHECK: %[[ID_Y:.*]] = affine.apply #map()[%[[SGID]]] + // CHECK: %[[ID_X:.*]] = affine.apply #map1()[%[[SGID]]] + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[C32:.*]] = arith.constant 32 : index + // CHECK: %[[L_OFF_X:.*]] = index.mul %[[ID_X]], %[[C32]] + // CHECK: xegpu.store_matrix {{.*}}, %[[MDESC]][%[[ID_Y]], %[[L_OFF_X]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index + // CHECK: gpu.barrier + // CHECK: %[[C8_1:.*]] = arith.constant 8 : index + // CHECK: %[[OFF_Y:.*]] = index.mul %[[ID_Y]], %[[C8_1]] + // CHECK: %[[C4_2:.*]] = arith.constant 4 : index + // CHECK: %[[OFF_X:.*]] = index.mul %[[ID_X]], %[[C4_2]] + // CHECK: %[[LOAD:.*]] = xegpu.load_matrix %[[MDESC]][%[[OFF_Y]], %[[OFF_X]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x4xf32> + // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> + // CHECK: %[[REDUCE:.*]] = vector.multi_reduction , %[[LOAD]], %[[CST]] [0] : vector<8x4xf32> to vector<4xf32> %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} dense<1.0> : vector<128xf32> %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout> From b1e202693d5c665800e67d7ff6cd8b0fe2146d82 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Wed, 3 Sep 2025 21:09:11 +0000 Subject: [PATCH 3/4] Clean up --- .../Transforms/XeGPUWgToSgDistribute.cpp | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 81dcb5d931473..685a9da92e54c 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -981,15 +981,13 @@ struct WgToSgMultiDimReductionOp rewriter.setInsertionPoint(op); - // Get layout of the source tensor - SmallVector sgLayoutParent = - sliceAttr.getParent().getSgLayoutAsInt(); + SmallVector sgLayout = sliceAttr.getParent().getSgLayoutAsInt(); // Allocate SLM auto bitWidth = elemTy.getIntOrFloatBitWidth(); auto flattenFactor = bitWidth / 8; auto slmSize = - resType.getNumElements() * sgLayoutParent[reduceDim] * flattenFactor; + resType.getNumElements() * sgLayout[reduceDim] * flattenFactor; auto slmTy = MemRefType::get(slmSize, rewriter.getI8Type(), {}, 3); auto slm = rewriter.create(loc, slmTy); @@ -1000,11 +998,11 @@ struct WgToSgMultiDimReductionOp srcVecType ? 
srcVecType.getShape() : ArrayRef(); for (size_t i = 0; i < srcShape.size(); ++i) { if (static_cast(i) == reduceDim) { - // For the reduced dimension, use sgLayoutParent[i] - memDescShape.push_back(sgLayoutParent[i]); + // For the reduced dimension, use sgLayout[i] + memDescShape.push_back(sgLayout[i]); } else { - // For other dimensions, multiply sgLayoutParent[i] by sgShape[i] - memDescShape.push_back(sgLayoutParent[i] * srcShape[i]); + // For other dimensions, multiply sgLayout[i] by sgShape[i] + memDescShape.push_back(sgLayout[i] * srcShape[i]); } } @@ -1021,11 +1019,11 @@ struct WgToSgMultiDimReductionOp auto sgId = rewriter.create(loc, rewriter.getIndexType(), nullptr); - SmallVector srcSgLayoutDim(sgLayoutParent.size()); + SmallVector srcSgLayoutDim(sgLayout.size()); - for (size_t i = 0; i < sgLayoutParent.size(); i++) { + for (size_t i = 0; i < sgLayout.size(); i++) { srcSgLayoutDim[i] = - arith::ConstantIndexOp::create(rewriter, loc, sgLayoutParent[i]); + arith::ConstantIndexOp::create(rewriter, loc, sgLayout[i]); } auto sgIdVec = @@ -1036,7 +1034,7 @@ struct WgToSgMultiDimReductionOp // Calculate offsets for store_matrix SmallVector slmStoreOffsets; - for (size_t i = 0; i < sgLayoutParent.size(); ++i) { + for (size_t i = 0; i < sgLayout.size(); ++i) { Value offset = rewriter.createOrFold( loc, sgIds[i], arith::ConstantIndexOp::create(rewriter, loc, slmSgData[i])); @@ -1059,14 +1057,14 @@ struct WgToSgMultiDimReductionOp if (static_cast(i) == reduceDim) { slmLoadShape.push_back(memDescShape[i]); } else { - int64_t divisor = computeProduct(sgLayoutParent); + int64_t divisor = computeProduct(sgLayout); slmLoadShape.push_back(memDescShape[i] / divisor); } } // Calculate offsets for load_matrix op SmallVector slmLoadOffsets; - for (size_t i = 0; i < sgLayoutParent.size(); ++i) { + for (size_t i = 0; i < sgLayout.size(); ++i) { Value offset = rewriter.createOrFold( loc, sgIds[i], arith::ConstantIndexOp::create(rewriter, loc, slmLoadShape[i])); From ff3baed020510463408539c6e17a942f4a0f2353 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 4 Sep 2025 17:34:29 +0000 Subject: [PATCH 4/4] clean up --- mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 8934865c49cd3..fb1eff1ae8c07 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -370,8 +370,8 @@ gpu.module @test_distribution { // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> gpu.func @vector_reduce(%src: memref<256x128xf32>) { // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<32xf32> - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout> - // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout> -> vector<32x32xf32> + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32> + // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<32x32xf32> // CHECK: %[[REDUCE:.*]] = vector.multi_reduction , {{%.*}}, %[[CST]] [0] : vector<32x32xf32> to vector<32xf32> // CHECK: %[[SHAPECAST:.*]] = vector.shape_cast %[[REDUCE]] : vector<32xf32> to vector<1x32xf32> // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3> @@ -396,9 +396,9 @@ 
gpu.module @test_distribution { // CHECK: %[[REDUCE:.*]] = vector.multi_reduction , %[[LOAD]], %[[CST]] [0] : vector<8x4xf32> to vector<4xf32> %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} dense<1.0> : vector<128xf32> %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> - -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<256x128xf32, #xegpu.layout> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout> -> vector<256x128xf32> %reduce = vector.multi_reduction , %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} [0] : vector<256x128xf32> to vector<128xf32>
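
---

A note on the numbers the new pattern produces for the vector_reduce test: the two-stage scheme first reduces each subgroup's 32x32 tile locally, parks the partial results in SLM, and reduces across subgroups after a barrier. Below is a minimal standalone C++ sketch (no MLIR dependency) of that shape and offset arithmetic, assuming sg_layout = [8, 4] and sg_data = [32, 32] for the 256x128xf32 source — values inferred from the CHECK lines rather than taken verbatim from the test. All names in the sketch are illustrative and do not appear in the patch.

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Workgroup-level reduction from the vector_reduce test:
  //   vector<256x128xf32> --reduce over dim 0--> vector<128xf32>
  const int64_t reduceDim = 0;
  const int64_t elemBytes = 4; // f32

  // Assumed subgroup decomposition (inferred, not read from the test):
  // 8x4 subgroups, each owning a 32x32 tile of the source.
  const std::vector<int64_t> sgLayout = {8, 4};
  const std::vector<int64_t> sgSrcShape = {32, 32};
  const int64_t wgResultElems = 128; // workgroup result is vector<128xf32>

  // Step 1: each subgroup reduces its tile to 32 elements; the shape_cast
  // keeps the reduced dim as size 1 so the partial result can be addressed
  // in SLM -> vector<1x32xf32>.
  std::vector<int64_t> shapeCastShape = sgSrcShape;
  shapeCastShape[reduceDim] = 1;

  // SLM byte size: one row of partial results per subgroup along the
  // reduced dimension.
  const int64_t slmBytes = wgResultElems * sgLayout[reduceDim] * elemBytes;

  // SLM view (mem_desc) shape: sgLayout extent along the reduced dim,
  // full workgroup extent along the other dims.
  std::vector<int64_t> memDescShape(2);
  for (int64_t i = 0; i < 2; ++i)
    memDescShape[i] = (i == reduceDim) ? sgLayout[i]
                                       : sgLayout[i] * sgSrcShape[i];

  // Step 3: each subgroup loads a slice whose non-reduced extent is the SLM
  // extent divided by the total number of subgroups; the second
  // multi_reduction then collapses the reduced dim of that slice.
  const int64_t numSgs =
      std::accumulate(sgLayout.begin(), sgLayout.end(), int64_t{1},
                      std::multiplies<int64_t>());
  std::vector<int64_t> slmLoadShape(2);
  for (int64_t i = 0; i < 2; ++i)
    slmLoadShape[i] =
        (i == reduceDim) ? memDescShape[i] : memDescShape[i] / numSgs;

  // Offsets for one example subgroup id, delinearized over sgLayout:
  // the store strides by shapeCastShape, the load by slmLoadShape.
  const int64_t sgId = 13; // -> (row 3, col 1) for an 8x4 layout
  const int64_t row = sgId / sgLayout[1], col = sgId % sgLayout[1];

  std::cout << "shape_cast result:  " << shapeCastShape[0] << "x"
            << shapeCastShape[1] << "   (CHECK expects 1x32)\n"
            << "SLM size in bytes:  " << slmBytes
            << "        (CHECK expects memref<4096xi8, 3>)\n"
            << "mem_desc shape:     " << memDescShape[0] << "x"
            << memDescShape[1] << "  (CHECK expects 8x128)\n"
            << "second-stage input: " << slmLoadShape[0] << "x"
            << slmLoadShape[1] << "    (CHECK expects 8x4 -> 4)\n"
            << "sg " << sgId << " store offsets: [" << row * shapeCastShape[0]
            << ", " << col * shapeCastShape[1] << "]\n"
            << "sg " << sgId << " load offsets:  [" << row * slmLoadShape[0]
            << ", " << col * slmLoadShape[1] << "]\n";
  return 0;
}

The SLM round trip is the part that makes the cross-subgroup step possible: a subgroup-level vector.multi_reduction only sees its own tile, so the partial results have to be materialized in shared local memory and a gpu.barrier issued before any subgroup loads slices written by the others.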