diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9f627c7e1e6d8..685a9da92e54c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -726,7 +726,6 @@ struct UnrealizedConversionCastOpPattern
   }
 };
 
-// This pattern distributes arith.constant op into subgroup-level constants
 struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
 
@@ -756,8 +755,15 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     auto sgAttr = DenseElementsAttr::get(newType, singleVal);
     auto cstOp =
         arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (auto newLayout = layout.dropSgLayoutAndData())
-      xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+    if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+      if (sliceAttr.isForSubgroup())
+        xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
+                                       sliceAttr.dropSgLayoutAndData());
+    } else if (auto layoutAttr =
+                   dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
+      if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+        xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+    }
 
     SmallVector<Value> newConsts(count, cstOp);
     rewriter.replaceOpWithMultiple(op, {newConsts});
@@ -919,6 +925,189 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
   }
 };
 
+// Pattern to distribute vector.multi_dim_reduction op to subgroup level.
+struct WgToSgMultiDimReductionOp
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType srcType = dyn_cast<VectorType>(op.getSource().getType());
+    VectorType accType = dyn_cast<VectorType>(op.getAcc().getType());
+    VectorType resType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!srcType || !accType || !resType)
+      return failure();
+    Type elemTy = srcType.getElementType();
+
+    // Support only 2D sources reduced to 1D results.
+    if (srcType.getShape().size() != 2 || resType.getShape().size() != 1)
+      return failure();
+    ArrayRef<int64_t> wgShape = resType.getShape();
+    auto layoutName = xegpu::getLayoutName(op->getResult(0));
+    auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
+    if (!sliceAttr || sliceAttr.getRank() != 1)
+      return failure();
+
+    SmallVector<int64_t> dims =
+        llvm::to_vector(sliceAttr.getDims().asArrayRef());
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, sliceAttr).first;
+
+    int64_t reduceDim = dims[0];
+
+    // Step 1: Subgroup-level reduction.
+    // Each subgroup reduces its local tile.
+    SmallVector<Value> newReductions;
+    VectorType newType = VectorType::get(sgShape, srcType.getElementType());
+    SmallVector<int64_t> shapeCastShape = sgShape;
+    if (reduceDim == 0)
+      shapeCastShape.insert(shapeCastShape.begin(), 1);
+    else
+      shapeCastShape.push_back(1);
+    for (auto [sgSrc, sgAcc] :
+         llvm::zip(adaptor.getSource(), adaptor.getAcc())) {
+      auto sgReduce = rewriter.create<vector::MultiDimReductionOp>(
+          op.getLoc(), newType, op.getKind(), sgSrc, sgAcc,
+          op.getReductionDims());
+      // Compute the shape for the shape cast: set reduceDim to 1, keep other
+      // dims as sgShape.
+      auto shapeCastTy =
+          VectorType::get(shapeCastShape, srcType.getElementType());
+      auto shapeCast = rewriter.create<vector::ShapeCastOp>(
+          op.getLoc(), shapeCastTy, sgReduce.getResult());
+      newReductions.push_back(shapeCast.getResult());
+    }
+
+    rewriter.setInsertionPoint(op);
+
+    SmallVector<int64_t> sgLayout = sliceAttr.getParent().getSgLayoutAsInt();
+
+    // Allocate SLM.
+    auto bitWidth = elemTy.getIntOrFloatBitWidth();
+    auto flattenFactor = bitWidth / 8;
+    auto slmSize =
+        resType.getNumElements() * sgLayout[reduceDim] * flattenFactor;
+    auto slmTy = MemRefType::get(slmSize, rewriter.getI8Type(), {}, 3);
+    auto slm = rewriter.create<memref::AllocaOp>(loc, slmTy);
+
+    // Create a SLM buffer using xegpu.create_mem_desc.
+    SmallVector<int64_t> memDescShape;
+    auto srcVecType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
+    ArrayRef<int64_t> srcShape =
+        srcVecType ? srcVecType.getShape() : ArrayRef<int64_t>();
+    for (size_t i = 0; i < srcShape.size(); ++i) {
+      if (static_cast<int64_t>(i) == reduceDim) {
+        // For the reduced dimension, use sgLayout[i].
+        memDescShape.push_back(sgLayout[i]);
+      } else {
+        // For other dimensions, multiply sgLayout[i] by srcShape[i].
+        memDescShape.push_back(sgLayout[i] * srcShape[i]);
+      }
+    }
+
+    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
+                                               memDescShape, elemTy, nullptr);
+    auto memDesc =
+        rewriter.create<xegpu::CreateMemDescOp>(loc, memDescType, slm);
+
+    // Step 2: Store subgroup results to SLM (shared local memory).
+    // SLM layout: sgLayout same as srcLayout, sgData is shapeCastShape.
+    SmallVector<int64_t> slmSgData = shapeCastShape;
+
+    // Get subgroup id and delinearize it.
+    auto sgId = rewriter.create<gpu::SubgroupIdOp>(loc, rewriter.getIndexType(),
+                                                   nullptr);
+
+    SmallVector<Value> srcSgLayoutDim(sgLayout.size());
+
+    for (size_t i = 0; i < sgLayout.size(); i++) {
+      srcSgLayoutDim[i] =
+          arith::ConstantIndexOp::create(rewriter, loc, sgLayout[i]);
+    }
+
+    auto sgIdVec =
+        affine::delinearizeIndex(rewriter, loc, sgId, srcSgLayoutDim);
+    if (failed(sgIdVec))
+      return failure();
+    SmallVector<Value> sgIds = *sgIdVec;
+
+    // Calculate offsets for store_matrix.
+    SmallVector<OpFoldResult> slmStoreOffsets;
+    for (size_t i = 0; i < sgLayout.size(); ++i) {
+      Value offset = rewriter.createOrFold<index::MulOp>(
+          loc, sgIds[i],
+          arith::ConstantIndexOp::create(rewriter, loc, slmSgData[i]));
+      slmStoreOffsets.push_back(offset);
+    }
+
+    // Store the subgroup result to SLM.
+    rewriter.create<xegpu::StoreMatrixOp>(
+        loc, newReductions[0], memDesc.getResult(),
+        ArrayRef<OpFoldResult>(slmStoreOffsets),
+        /*layout=*/nullptr);
+
+    // Barrier to synchronize subgroups.
+    rewriter.create<gpu::BarrierOp>(loc);
+
+    // Step 3: Load from SLM for the second reduction.
+    SmallVector<int64_t> slmLoadShape;
+
+    for (size_t i = 0; i < memDescShape.size(); ++i) {
+      if (static_cast<int64_t>(i) == reduceDim) {
+        slmLoadShape.push_back(memDescShape[i]);
+      } else {
+        int64_t divisor = computeProduct(sgLayout);
+        slmLoadShape.push_back(memDescShape[i] / divisor);
+      }
+    }
+
+    // Calculate offsets for the load_matrix op.
+    SmallVector<OpFoldResult> slmLoadOffsets;
+    for (size_t i = 0; i < sgLayout.size(); ++i) {
+      Value offset = rewriter.createOrFold<index::MulOp>(
+          loc, sgIds[i],
+          arith::ConstantIndexOp::create(rewriter, loc, slmLoadShape[i]));
+      slmLoadOffsets.push_back(offset);
+    }
+
+    auto load = rewriter.create<xegpu::LoadMatrixOp>(
+        loc, VectorType::get(slmLoadShape, elemTy), memDesc,
+        llvm::ArrayRef<OpFoldResult>(slmLoadOffsets),
+        /*layout=*/nullptr);
+
+    // Step 4: Create a constant accumulator for the second reduction
+    // with the same value as adaptor.getAcc()[0] and its shape set to
+    // the non-reduced dimensions of the load.
+    auto accShape = load.getType().getShape();
+    SmallVector<int64_t> accShapeWithoutReduceDim;
+    for (size_t i = 0; i < accShape.size(); ++i) {
+      if (static_cast<int64_t>(i) != reduceDim)
+        accShapeWithoutReduceDim.push_back(accShape[i]);
+    }
+    auto accTy = VectorType::get(accShapeWithoutReduceDim, elemTy);
+    auto accConstOp = adaptor.getAcc()[0].getDefiningOp<arith::ConstantOp>();
+    Attribute accSplatValue;
+    if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(
+            accConstOp ? accConstOp.getValue() : nullptr)) {
+      accSplatValue =
+          denseAttr.isSplat() ? denseAttr.getSplatValue<Attribute>() : nullptr;
+    }
+    if (!accSplatValue)
+      return failure();
+    auto accValue = rewriter.create<arith::ConstantOp>(
+        loc, accTy, DenseElementsAttr::get(accTy, accSplatValue));
+    // Step 5: Perform the second reduction.
+    VectorType secondReduceVecType =
+        VectorType::get(accShapeWithoutReduceDim, srcType.getElementType());
+    auto secondReduce = rewriter.create<vector::MultiDimReductionOp>(
+        loc, secondReduceVecType, op.getKind(), load, accValue,
+        op.getReductionDims());
+    rewriter.replaceOpWithMultiple(op, {secondReduce.getResult()});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -932,7 +1121,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
                WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
                WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
                WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
-               WgToSgStoreMatrixOp>(patterns.getContext());
+               WgToSgStoreMatrixOp, WgToSgMultiDimReductionOp>(
+      patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1107,6 +1297,12 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
+  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+      [=](vector::MultiDimReductionOp op) -> bool {
+        // Check if the layout is legal.
+        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
       [=](UnrealizedConversionCastOp op) {
         return llvm::is_contained(existingCastOps, op.getOperation());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index afb2bf876c18f..fb1eff1ae8c07 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -365,4 +365,43 @@ gpu.module @test_distribution {
     xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
     gpu.return
   }
+
+  //CHECK-LABEL: vector_reduce
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @vector_reduce(%src: memref<256x128xf32>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<32xf32>
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<32x32xf32>
+    // CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, {{%.*}}, %[[CST]] [0] : vector<32x32xf32> to vector<32xf32>
+    // CHECK: %[[SHAPECAST:.*]] = vector.shape_cast %[[REDUCE]] : vector<32xf32> to vector<1x32xf32>
+    // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
+    // CHECK: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
+    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK: %[[C4:.*]] = arith.constant 4 : index
+    // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
+    // CHECK: %[[ID_Y:.*]] = affine.apply #map()[%[[SGID]]]
+    // CHECK: %[[ID_X:.*]] = affine.apply #map1()[%[[SGID]]]
+    // CHECK: %[[C1:.*]] = arith.constant 1 : index
+    // CHECK: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK: %[[L_OFF_X:.*]] = index.mul %[[ID_X]], %[[C32]]
+    // CHECK: xegpu.store_matrix {{.*}}, %[[MDESC]][%[[ID_Y]], %[[L_OFF_X]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
+    // CHECK: gpu.barrier
+    // CHECK: %[[C8_1:.*]] = arith.constant 8 : index
+    // CHECK: %[[OFF_Y:.*]] = index.mul %[[ID_Y]], %[[C8_1]]
+    // CHECK: %[[C4_2:.*]] = arith.constant 4 : index
+    // CHECK: %[[OFF_X:.*]] = index.mul %[[ID_X]], %[[C4_2]]
+    // CHECK: %[[LOAD:.*]] = xegpu.load_matrix %[[MDESC]][%[[OFF_Y]], %[[OFF_X]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x4xf32>
+    // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+    // CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST]] [0] : vector<8x4xf32> to vector<4xf32>
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} dense<1.0> : vector<128xf32>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+    %load = xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+      -> vector<256x128xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0]
+      : vector<256x128xf32> to vector<128xf32>
+    gpu.return
+  }
 }