[MLIR][XeGPU] Add support for cross-subgroup reduction from wg to sg #170936
@@ -1152,64 +1152,232 @@ struct WgToSgVectorShapeCastOp
   }
 };
-/// Pattern for lowering vector.multi_reduction op to subgroup level.
-/// Current limitation: the sg_layout in the reduced dimension being 1
-/// so that reduction is local to subgroup & no cross-subgroup communication is
-/// needed.
-/// TODO: Add cases to handle more general situations which require SLM access.
+// This pattern transforms vector.multi_dim_reduction ops to work at subgroup
+// level.
 struct WgToSgMultiDimReductionOp
     : public OpConversionPattern<vector::MultiDimReductionOp> {
   using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
   LogicalResult
   matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
     VectorType srcType = op.getSourceVectorType();
     VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
     if (!dstType)
       return failure();
-    auto srcShape = srcType.getShape();
+    auto originalSrcShape = srcType.getShape();
     xegpu::DistributeLayoutAttr layout =
         xegpu::getDistributeLayoutAttr(op.getResult());
     if (!layout || !layout.isForWorkgroup())
       return failure();
     auto reductionDims = llvm::to_vector(op.getReductionDims());
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Only single dimension reduction is supported");
Contributor: What prevents 2D reductions here?

Author: It's one of the requirements for the XeGPU canonical form; that pass should ensure only single-dimension reductions reach this pattern.

Contributor: But then we face a problem. If there is a 2D test case, we have to rewrite it as two 1D reductions first. From what I see, this pattern naturally supports intra-sg reduction and then handles the cross-sg results. If we were to consider the 2D case, the pattern already has most of the components for the hardcoded logic: do the intra-sg reduction, then the cross-sg one via SLM. We do not care how "2D" is represented at lower levels. When we go lower and start to actually care how an sg-local 2D reduction is executed, we have to do two 1D reductions, and we decide on the order based on the layout (we first reduce the dimension that does not require shuffles, if any). However, if we are forced to split a 2D reduction into two 1D reductions at the wg level, we lose the ability to reason about the better order, because we do not require a lane layout at the WG level and cannot use it when splitting. Please correct me if I missed something.

Contributor: The restriction/requirement is driven by the implementation, not by users. So if our implementation can be improved to lift the restriction, we should try.

Contributor: I agree with @akroviakov. We should handle multiple dims here, but for now this is fine.
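For reference, a hedged sketch of the two-1D-reduction split discussed above, mirroring the exact builder form this pattern already uses. The helper name and the hardcoded 3-D source with reduction dims {0, 1} are illustrative assumptions, not part of this PR, and the zero-accumulator trick assumes an additive combining kind (as in Steps 1 and 6 below):

```cpp
// Hypothetical helper (not in this PR): rewrite a reduction over dims {0, 1}
// of a 3-D source as two 1-D vector.multi_reduction ops, e.g.
// vector<4x8x16xf32> -> vector<8x16xf32> -> vector<16xf32>.
// The dim order is fixed here; as noted above, picking the better order
// needs lane-layout info that is not required at wg level.
static Value splitInto1DReductions(vector::MultiDimReductionOp op,
                                   ConversionPatternRewriter &rewriter) {
  Location loc = op.getLoc();
  VectorType srcTy = op.getSourceVectorType();
  Type elemTy = srcTy.getElementType();

  // First 1-D reduction over dim 0, with a zero accumulator.
  VectorType midTy = VectorType::get(srcTy.getShape().drop_front(), elemTy);
  auto zeroMid = arith::ConstantOp::create(
      rewriter, loc, midTy,
      DenseElementsAttr::get(midTy, rewriter.getZeroAttr(elemTy)));
  auto first = vector::MultiDimReductionOp::create(
      rewriter, loc, midTy, op.getKind(), op.getSource(), zeroMid.getResult(),
      SmallVector<int64_t>{0});

  // Second 1-D reduction over what was dim 1, feeding the original
  // accumulator so the overall result matches the 2-D reduction.
  VectorType dstTy = VectorType::get(srcTy.getShape().drop_front(2), elemTy);
  auto second = vector::MultiDimReductionOp::create(
      rewriter, loc, dstTy, op.getKind(), first.getResult(), op.getAcc(),
      SmallVector<int64_t>{0});
  return second.getResult();
}
```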
+    // Get sg_layout and sg_data from the parent layout
+    SmallVector<int64_t> sgLayout;
+    SmallVector<int64_t> sgData;
+    if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
+      sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
+      sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
+    } else
+      return rewriter.notifyMatchFailure(
+          op, "Reduction should have SliceAttr layout");
+    Type elemTy = dstType.getElementType();
+
+    // Step 1: perform local subgroup reductions with ZERO accumulator
+    SmallVector<Value> localReductions;
+    auto sources = adaptor.getSource();
+    auto accs = adaptor.getAcc();
+
+    SmallVector<Value> expandedAccs;
+    if (accs.size() == 1 && sources.size() > 1) {
Contributor: What is this case?
+      for (size_t i = 0; i < sources.size(); ++i)
+        expandedAccs.push_back(accs[0]);
+    } else
+      expandedAccs = llvm::to_vector(accs);
+
+    SmallVector<int64_t> sgShape =
+        getSgShapeAndCount(originalSrcShape, layout).first;
+    VectorType newDstType = VectorType::get({sgShape}, elemTy);
+    for (auto [sgSrc, sgAcc] : llvm::zip(sources, expandedAccs)) {
+      // Create ZERO accumulator for local reduction
+      auto zeroLocalAcc = arith::ConstantOp::create(
+          rewriter, loc, newDstType,
+          DenseElementsAttr::get(newDstType, rewriter.getZeroAttr(elemTy)));
+      // Local reduction with ZERO accumulator
+      auto localReduce = vector::MultiDimReductionOp::create(
+          rewriter, loc, newDstType, op.getKind(), sgSrc,
+          zeroLocalAcc.getResult(), reductionDims);
+      localReductions.push_back(localReduce.getResult());
+    }
-    SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
-                                        .getParent()
-                                        .getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
-                                      .getParent()
-                                      .getEffectiveSgDataAsInt();
-
-    // Check that the sgLayout in the reduced dimension is 1 and
-    // each sg gets the entire slice to reduce.
-    for (int64_t dim : reductionDims) {
-      if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
-        return rewriter.notifyMatchFailure(
-            op,
-            "sgLayout in each reduced dimension must be 1 and sgData in the "
-            "reduced dim must match srcShape in that dim");
-    }
+    // Check if cross-subgroup reduction is needed
+    int64_t reductionDim = reductionDims[0];
+    bool needsCrossSubgroupReduction = (sgLayout[reductionDim] > 1);
+
+    // If no cross-subgroup reduction needed, add accumulator and return
Contributor: The code could use some helper functions so the main function becomes shorter.
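One hedged example of such a refactor: the zero-filled accumulator constant built before both the local (Step 1) and the final (Step 6) reductions could be hoisted into a small file-local helper. The name and placement are assumptions, not part of this PR:

```cpp
// Hypothetical helper: build a zero-filled vector constant to use as the
// accumulator, as done before both the local and the final reductions.
static Value createZeroAcc(ConversionPatternRewriter &rewriter, Location loc,
                           VectorType type) {
  Type elemTy = type.getElementType();
  return arith::ConstantOp::create(
             rewriter, loc, type,
             DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy)))
      .getResult();
}
```

`zeroLocalAcc` above and `zeroFinalAcc` below would then each shrink to a one-line call.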
+    if (!needsCrossSubgroupReduction) {
+      SmallVector<Value> results;
+      for (auto localResult : localReductions) {
+        auto finalResult = arith::AddFOp::create(rewriter, loc, localResult,
+                                                 adaptor.getAcc()[0]);
+        if (auto defOp = finalResult.getResult().getDefiningOp())
+          xegpu::setDistributeLayoutAttr(defOp->getResult(0),
+                                         layout.dropSgLayoutAndData());
+        results.push_back(finalResult.getResult());
+      }
+      rewriter.replaceOpWithMultiple(op, {results});
+      return success();
+    }
-    SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
-    VectorType newDstType =
-        VectorType::get({sgShape}, dstType.getElementType());
+    // Step 2: Cross-subgroup reduction using SLM
+
+    // Calculate total elements in local result
+    int64_t localElements = computeProduct(sgShape);
-    SmallVector<Value> newReductions;
-    for (auto sgSrc : adaptor.getSource()) {
-      auto newOp = vector::MultiDimReductionOp::create(
-          rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
-          adaptor.getAcc()[0], op.getReductionDims());
-      xegpu::setDistributeLayoutAttr(newOp->getResult(0),
-                                     layout.dropSgLayoutAndData());
-      newReductions.push_back(newOp.getResult());
-    }
+    // Shape cast for SLM storage - store as [1, localElements]
+    SmallVector<int64_t> storeShape2D = {1, localElements};
+    VectorType storeType2D = VectorType::get(storeShape2D, elemTy);
+    auto storeShapeCast = vector::ShapeCastOp::create(
+        rewriter, loc, storeType2D, localReductions[0]);
+    Value storeData = storeShapeCast.getResult();
+    // Calculate SLM shape
+    int64_t totalReductionSubgroups =
+        sgLayout[static_cast<size_t>(reductionDims[0])];
+
+    // Total result elements across all subgroups in non-reduction dimensions
+    int64_t totalResultElements = localElements;
+    for (size_t i = 0; i < sgLayout.size(); ++i) {
+      if (!llvm::is_contained(reductionDims, static_cast<int64_t>(i)))
+        totalResultElements *= sgLayout[i];
+    }
Comment on lines +1258 to +1262

Contributor: This can be simplified with computeProduct, dividing by the reduction-dim size.
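A hedged sketch of that simplification, valid under the single-reduction-dim check done earlier; `computeProduct` is the same utility already used above for `localElements`:

```cpp
// Product of sgLayout over the non-reduction dims equals the total sgLayout
// product divided by the reduction-dim extent.
int64_t totalResultElements =
    localElements * (computeProduct(sgLayout) / sgLayout[reductionDim]);
```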
+    SmallVector<int64_t> slmShape2D = {totalReductionSubgroups,
+                                       totalResultElements};
+    // Allocate SLM
+    auto bitWidth = elemTy.getIntOrFloatBitWidth();
+    auto bytesPerElement = bitWidth / 8;
+    int64_t slmElements = slmShape2D[0] * slmShape2D[1];
+    auto slmSize = slmElements * bytesPerElement;
+    auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
+    auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
+
+    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
+                                               slmShape2D, elemTy, nullptr);
+    auto memDesc =
+        xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
+    // Step 4: Store local results to SLM
+    auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
+                                          rewriter.getIndexType(), nullptr);
+
+    // Convert sgLayout to Values for delinearizeIndex
+    SmallVector<Value> sgLayoutValues;
+    for (int64_t dim : sgLayout)
+      sgLayoutValues.push_back(
+          arith::ConstantIndexOp::create(rewriter, loc, dim));
+
+    auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
+                                                sgLayoutValues);
+    if (failed(sgIdsResult))
+      return failure();
+    SmallVector<Value> sgIds = *sgIdsResult;
+    // Row offset is simply the subgroup ID along the reduction dimension
+    Value rowOffset = sgIds[reductionDim];
+
+    // Column offset: linearize all non-reduction dimensions and multiply by
+    // localElements
+    Value colOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    int64_t currentStride = 1;
+    for (size_t i = 0; i < sgLayout.size(); ++i) {
+      if (static_cast<int64_t>(i) != reductionDim) { // Skip reduction dim
+        Value dimVal = sgIds[i];
+        Value strideVal =
+            arith::ConstantIndexOp::create(rewriter, loc, currentStride);
+        Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
+        colOffset = arith::AddIOp::create(rewriter, loc, colOffset, term);
+        currentStride *= sgLayout[i];
+      }
+    }
+    Value localElementsVal =
+        arith::ConstantIndexOp::create(rewriter, loc, localElements);
+    colOffset =
+        arith::MulIOp::create(rewriter, loc, colOffset, localElementsVal);
+
+    SmallVector<OpFoldResult> storeOffsets2D = {rowOffset, colOffset};
+    xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
+                                 storeOffsets2D, /*layout=*/nullptr);
+
+    gpu::BarrierOp::create(rewriter, loc);
Contributor: To sync producer and consumer subgroups on the data, both a barrier and a fence are needed.
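A hedged sketch of that suggestion; the exact `xegpu::FenceOp` builder arguments (memory-kind and scope enums) are an assumption and may differ from the dialect's current definition:

```cpp
// Make the SLM stores visible across the workgroup before anyone loads:
// fence shared-local memory, then barrier so all producers have arrived.
// Assumption: FenceOp takes memory-kind and fence-scope enums in this order.
xegpu::FenceOp::create(rewriter, loc, xegpu::MemorySpace::SLM,
                       xegpu::FenceScope::Workgroup);
gpu::BarrierOp::create(rewriter, loc);
```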
+    // Step 5: Load from SLM for final reduction
+    SmallVector<int64_t> loadShape2D = {totalReductionSubgroups,
+                                        localElements};
+    VectorType loadType2D = VectorType::get(loadShape2D, elemTy);
+
+    // Load offsets - each subgroup loads its column based on non-reduction
+    // position
+    Value loadOffsetY = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    Value loadOffsetX = colOffset;
+
+    SmallVector<OpFoldResult> loadOffsets2D = {loadOffsetY, loadOffsetX};
+
+    auto loadOp = xegpu::LoadMatrixOp::create(
+        rewriter, loc, loadType2D, memDesc.getResult(), loadOffsets2D,
+        /*layout=*/nullptr);
Contributor: Do we need a barrier here as well, to make sure everyone has finished loading the values?
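If the SLM buffer were ever rewritten after this point, the synchronization asked about above would be a plain barrier after the loads (a sketch, not part of this PR):

```cpp
// Ensure every subgroup has finished reading from SLM before any
// subsequent reuse of the buffer.
gpu::BarrierOp::create(rewriter, loc);
```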
+    // Step 6: Perform final reduction with ZERO accumulator
+    SmallVector<int64_t> finalReductionDims = {0};
+    SmallVector<int64_t> finalResultShape = {localElements};
+    VectorType finalResultType = VectorType::get(finalResultShape, elemTy);
+
+    // Create ZERO accumulator for final reduction
+    auto zeroFinalAcc = arith::ConstantOp::create(
+        rewriter, loc, finalResultType,
+        DenseElementsAttr::get(finalResultType, rewriter.getZeroAttr(elemTy)));
+
+    auto finalReduce = vector::MultiDimReductionOp::create(
+        rewriter, loc, finalResultType, op.getKind(), loadOp.getResult(),
+        zeroFinalAcc.getResult(), finalReductionDims);
+
+    // Step 7: Add the original accumulator at the end
+    Value originalAcc = adaptor.getAcc()[0];
+    Value accToAdd = originalAcc;
+
+    // Handle shape mismatch by shape casting
+    if (originalAcc.getType() != finalReduce.getResult().getType()) {
+      auto originalAccType = cast<VectorType>(originalAcc.getType());
+      auto finalResultType =
+          cast<VectorType>(finalReduce.getResult().getType());
+
+      // If they have the same number of elements, just shape cast
+      if (originalAccType.getNumElements() ==
+          finalResultType.getNumElements()) {
+        auto shapeCast = vector::ShapeCastOp::create(
+            rewriter, loc, finalResultType, originalAcc);
+        accToAdd = shapeCast.getResult();
+      }
+    }
+
+    auto finalResult = arith::AddFOp::create(rewriter, loc,
+                                             finalReduce.getResult(), accToAdd);
+
+    if (auto defOp = finalResult.getResult().getDefiningOp())
+      xegpu::setDistributeLayoutAttr(defOp->getResult(0),
+                                     layout.dropSgLayoutAndData());
-    rewriter.replaceOpWithMultiple(op, {newReductions});
+    rewriter.replaceOp(op, finalResult.getResult());
     return success();
   }
 };
Contributor: Please add a summary of your algorithm here.
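A possible summary, derived from the code above: (1) each subgroup reduces its local tile with a zero accumulator; (2) if sg_layout along the reduction dim is 1, the original accumulator is added and the op is replaced with the per-subgroup results; otherwise (3) each partial result is shape-cast to [1, localElements] and stored into an SLM matrix of shape [reduction-subgroups x total result elements], at a row given by the subgroup's position along the reduction dim and a column derived from its delinearized ID over the other dims; (4) a barrier synchronizes the workgroup; (5) each subgroup loads back its [reduction-subgroups x localElements] slice; (6) reduces it across dim 0 with a zero accumulator; and (7) the original accumulator is added exactly once at the end.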