From eaaca7f54a9333b1841283b4483cb9c8f91f9f6b Mon Sep 17 00:00:00 2001
From: Charitha Saumya
Date: Tue, 19 Aug 2025 16:42:52 +0000
Subject: [PATCH 01/10] save

---
 .../Vector/Transforms/VectorDistribute.cpp    | 242 +++++++++++++++---
 1 file changed, 213 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index be0d28a91cba7..2d9fcaee37282 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,13 +15,19 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include
 #include <utility>
 
 using namespace mlir;
@@ -939,8 +945,40 @@ struct WarpOpForwardOperand : public WarpDistributionPattern {
   }
 };
 
+static VectorType
+tryFindDistributedType(TypedValue<VectorType> source,
+                       WarpExecuteOnLane0Op warpOp,
+                       const DistributionMapFn &distributionMapFn) {
+  VectorType distributedType = source.getType();
+  // Check if the source is yielded from the warp op.
+  gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  auto *it = llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
+    return operand.get() == source;
+  });
+
+  if (it != yieldOp->getOpOperands().end()) {
+    // If the source is yielded from the warp op, we can use the matching
+    // warp result type as the distributed source type.
+    distributedType =
+        cast<VectorType>(warpOp->getResultTypes()[it->getOperandNumber()]);
+  } else {
+    // If the source is not yielded from the warp op, we need to compute
+    // the distributed source type based on the distribution map and the
+    // warp size.
+    AffineMap map = distributionMapFn(source);
+    VectorType computed =
+        getDistributedType(source.getType(), map, warpOp.getWarpSize());
+    if (!computed)
+      return source.getType();
+    distributedType = computed;
+  }
+  return distributedType;
+}
+
 struct WarpOpBroadcast : public WarpDistributionPattern {
-  using Base::Base;
+  WarpOpBroadcast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
@@ -953,18 +991,23 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
     auto destVecType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
     Value broadcastSrc = broadcastOp.getSource();
-    Type broadcastSrcType = broadcastSrc.getType();
+    Type srcDistributedType = broadcastSrc.getType();
+
+    if (isa<VectorType>(srcDistributedType))
+      srcDistributedType =
+          tryFindDistributedType(cast<TypedValue<VectorType>>(broadcastSrc),
+                                 warpOp, distributionMapFn);
 
     // Check that the broadcast actually spans a set of values uniformly across
    // all threads. In other words, check that each thread can reconstruct
    // their own broadcast.
     // For that we simply check that the broadcast we want to build makes sense.
-    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
+    if (vector::isBroadcastableTo(srcDistributedType, destVecType) !=
         vector::BroadcastableToResult::Success)
       return failure();
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
+        rewriter, warpOp, {broadcastSrc}, {srcDistributedType}, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value broadcasted = vector::BroadcastOp::create(
         rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
@@ -972,49 +1015,83 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
                                 broadcasted);
     return success();
   }
+
+private:
+  DistributionMapFn distributionMapFn;
 };
 
 /// Pattern to move shape cast out of the warp op. shape cast is basically a
 /// no-op for warp distribution; we need to handle the shape though.
 struct WarpOpShapeCast : public WarpDistributionPattern {
-  using Base::Base;
+
+  WarpOpShapeCast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
         getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
     if (!operand)
       return failure();
-
     auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
     unsigned int operandNumber = operand->getOperandNumber();
-    auto castDistributedType =
+    VectorType sourceType = oldCastOp.getSourceVectorType();
+    VectorType distributedResultType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
-    VectorType castOriginalType = oldCastOp.getSourceVectorType();
-    VectorType castResultType = castDistributedType;
-
-    // We expect the distributed type to have a smaller rank than the original
-    // type. Prepend with size-one dimensions to make them the same.
-    unsigned castDistributedRank = castDistributedType.getRank();
-    unsigned castOriginalRank = castOriginalType.getRank();
-    if (castDistributedRank < castOriginalRank) {
-      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
-      llvm::append_range(shape, castDistributedType.getShape());
-      castDistributedType =
-          VectorType::get(shape, castDistributedType.getElementType());
+    VectorType distributedSourceType = sourceType;
+    bool isResultDistributed = distributedResultType.getNumElements() <
+                               oldCastOp.getResultVectorType().getNumElements();
+
+    // If the result is not distributed, source distribted type is the same
+    // as the source type. If the result is distributed, we need to compute the
+    // distributed source type according to following rules:
+    // 1. If the source type is yielded from the warp op, we can use the
+    // matching warp result type as the distributed source type.
+    // 2. If the source type is not yielded from the warp op, we need
+    // to compute the distributed source type based on the distribution map
+    // and the warp size.
+    if (isResultDistributed) {
+      // Check if the source is yielded from the warp op.
+      gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
+          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      auto *it =
+          llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
+            return operand.get() == oldCastOp.getSource();
+          });
+
+      if (it != yieldOp->getOpOperands().end()) {
+        // If the source is yielded from the warp op, we can use the matching
+       // warp result type as the distributed source type.
+        distributedSourceType =
+            cast<VectorType>(warpOp->getResultTypes()[it->getOperandNumber()]);
+      } else {
+        // If the source is not yielded from the warp op, we need to compute
+        // the distributed source type based on the distribution map and the
+        // warp size.
+        AffineMap map = distributionMapFn(oldCastOp.getSource());
+        distributedSourceType =
+            getDistributedType(sourceType, map, warpOp.getWarpSize());
+        if (!distributedSourceType)
+          return rewriter.notifyMatchFailure(
+              oldCastOp,
+              "cannot compute distributed source type for shape cast");
+      }
     }
 
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
+        rewriter, warpOp, {oldCastOp.getSource()}, {distributedSourceType},
         newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value newCast = vector::ShapeCastOp::create(
-        rewriter, oldCastOp.getLoc(), castResultType,
+        rewriter, oldCastOp.getLoc(), distributedResultType,
         newWarpOp->getResult(newRetIndices[0]));
     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
     return success();
   }
+
+private:
+  DistributionMapFn distributionMapFn;
 };
 
 /// Sink out vector.create_mask op feeding into a warp op yield.
@@ -1996,6 +2073,114 @@ struct WarpOpReduction : public WarpDistributionPattern {
   DistributedReductionFn distributedReductionFn;
 };
 
+struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
+  VectorMultiDimReductionDistribution(MLIRContext *context,
+                                      PatternBenefit benefit = 1)
+      : WarpDistributionPattern(context, benefit) {}
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
+    if (!yieldOperand)
+      return failure();
+    auto reductionOp =
+        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandNumber = yieldOperand->getOperandNumber();
+    VectorType sourceType = reductionOp.getSourceVectorType();
+    VectorType distributedResultType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    Type elementType = distributedResultType.getElementType();
+    // Only 2D vectors are supported.
+    if (sourceType.getRank() != 2)
+      return rewriter.notifyMatchFailure(warpOp,
+                                         "Only 2D reductions are supported.");
+    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
+    // Only 1 reduction dimension supported.
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Only 1 reduction dimension is supported.");
+
+    // Col reduction.
+    if (reductionDims[0] == 0) {
+      // Yield the source vector and the accumulator.
+      if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Source vector dimension must be divisible by warp size.");
+      SmallVector<int64_t> shape(sourceType.getShape());
+      shape[1] = shape[1] / warpOp.getWarpSize();
+      auto sourceDistributedType = VectorType::get(shape, elementType);
+      SmallVector<size_t> newRetIndices;
+      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
+          {sourceDistributedType, distributedResultType}, newRetIndices);
+      rewriter.setInsertionPointAfter(newWarpOp);
+      // Create new reduction op.
+      // auto newOp = vector::MultiDimReductionOp::create(
+      //     rewriter, reductionOp.getLoc(), distributedResultType,
+      //     reductionOp.getKind(),
+      //     /** source = **/ newWarpOp.getResult(newRetIndices[0]),
+      //     /** accumulator = **/ newWarpOp.getResult(newRetIndices[1]),
+      //     reductionDims);
+      // Create a constant zero value for storing the reduction result.
+      // rewriter.setInsertionPointAfter(reductionOp);
+      auto zeroAttr =
+          rewriter.getZeroAttr(distributedResultType.getElementType());
+      Value result = arith::ConstantOp::create(
+          rewriter, reductionOp->getLoc(), distributedResultType,
+          DenseElementsAttr::get(distributedResultType, zeroAttr));
+      int nCols = sourceDistributedType.getShape()[1];
+      Value source = newWarpOp.getResult(newRetIndices[0]);
+      Value acc = newWarpOp.getResult(newRetIndices[1]);
+      for (int i = 0; i < nCols; ++i) {
+        Value col = vector::ExtractStridedSliceOp::create(
+            rewriter, reductionOp.getLoc(), source, {0, i},
+            {sourceDistributedType.getShape()[0], 1}, {1, 1});
+        col = vector::ShapeCastOp::create(
+            rewriter, reductionOp.getLoc(),
+            VectorType::get({sourceDistributedType.getShape()[0]}, elementType),
+            col);
+        Value accCol =
+            vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
+        Value colReduce = vector::ReductionOp::create(
+            rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
+        // Insert the reduced column into the result.
+        result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+                                          colReduce, result, i);
+      }
+      // Replace the warp op result with the new reduction op.
+      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
+      return success();
+    }
+    // Row reduction.
+    // Create a constant zero value for storing the reduction result.
+    rewriter.setInsertionPointAfter(reductionOp);
+    auto zeroAttr =
+        rewriter.getZeroAttr(distributedResultType.getElementType());
+    Value result = arith::ConstantOp::create(
+        rewriter, reductionOp->getLoc(), distributedResultType,
+        DenseElementsAttr::get(distributedResultType, zeroAttr));
+    // Value result = arith::ConstantOp::create(
+    //     rewriter, reductionOp.getLoc(),
+    //     rewriter.getIntegerAttr(reductionOp.getType(), 0));
+    int nRows = sourceType.getShape()[0];
+    // For each row, do a vector reduction.
+    for (int i = 0; i < nRows; ++i) {
+      Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+                                               reductionOp.getSource(), i);
+      Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+                                            reductionOp.getAcc(), i);
+      Value rowReduce = vector::ReductionOp::create(
+          rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc);
+      result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+                                        rowReduce, result, i);
+    }
+   // Replace the warp op result with the final result.
+    rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
@@ -2016,16 +2201,15 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add(patterns.getContext(), readBenefit);
-  patterns
-      .add(
-          patterns.getContext(), benefit);
+  patterns.add(
+      patterns.getContext(), benefit);
   patterns.add(patterns.getContext(), warpShuffleFromIdxFn, benefit);
-  patterns.add(patterns.getContext(), distributionMapFn,
-               benefit);
+  patterns.add(
+      patterns.getContext(), distributionMapFn, benefit);
 }
 
 void mlir::vector::populateDistributeReduction(

From 56c3441e9443660788e51064f8206c5e4ac9fbaf Mon Sep 17 00:00:00 2001
From: Charitha Saumya
Date: Tue, 19 Aug 2025 23:07:10 +0000
Subject: [PATCH 02/10] save

---
 .../Vector/Transforms/VectorDistribute.cpp    | 110 +++++-------------
 .../Vector/vector-warp-distribute.mlir        | 111 ++++++++++++++++++
 2 files changed, 143 insertions(+), 78 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2d9fcaee37282..6410a895fc9ae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -945,40 +945,8 @@ struct WarpOpForwardOperand : public WarpDistributionPattern {
   }
 };
 
-static VectorType
-tryFindDistributedType(TypedValue<VectorType> source,
-                       WarpExecuteOnLane0Op warpOp,
-                       const DistributionMapFn &distributionMapFn) {
-  VectorType distributedType = source.getType();
-  // Check if the source is yielded from the warp op.
-  gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  auto *it = llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
-    return operand.get() == source;
-  });
-
-  if (it != yieldOp->getOpOperands().end()) {
-    // If the source is yielded from the warp op, we can use the matching
-    // warp result type as the distributed source type.
-    distributedType =
-        cast<VectorType>(warpOp->getResultTypes()[it->getOperandNumber()]);
-  } else {
-    // If the source is not yielded from the warp op, we need to compute
-    // the distributed source type based on the distribution map and the
-    // warp size.
-    AffineMap map = distributionMapFn(source);
-    VectorType computed =
-        getDistributedType(source.getType(), map, warpOp.getWarpSize());
-    if (!computed)
-      return source.getType();
-    distributedType = computed;
-  }
-  return distributedType;
-}
-
 struct WarpOpBroadcast : public WarpDistributionPattern {
-  WarpOpBroadcast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
-      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
@@ -991,23 +959,18 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
     auto destVecType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
     Value broadcastSrc = broadcastOp.getSource();
-    Type srcDistributedType = broadcastSrc.getType();
-
-    if (isa<VectorType>(srcDistributedType))
-      srcDistributedType =
-          tryFindDistributedType(cast<TypedValue<VectorType>>(broadcastSrc),
-                                 warpOp, distributionMapFn);
+    Type broadcastSrcType = broadcastSrc.getType();
 
     // Check that the broadcast actually spans a set of values uniformly across
    // all threads. In other words, check that each thread can reconstruct
    // their own broadcast.
     // For that we simply check that the broadcast we want to build makes sense.
-    if (vector::isBroadcastableTo(srcDistributedType, destVecType) !=
+    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
         vector::BroadcastableToResult::Success)
       return failure();
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {broadcastSrc}, {srcDistributedType}, newRetIndices);
+        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value broadcasted = vector::BroadcastOp::create(
         rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
@@ -1015,9 +978,6 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
                                 broadcasted);
     return success();
   }
-
-private:
-  DistributionMapFn distributionMapFn;
 };
 
 /// Pattern to move shape cast out of the warp op. shape cast is basically a
@@ -2100,37 +2060,37 @@ struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
       return rewriter.notifyMatchFailure(
           warpOp, "Only 1 reduction dimension is supported.");
 
+    // Create a constant vector to store the result of the reduction per lane.
+    TypedAttr zeroAttr =
+        rewriter.getZeroAttr(distributedResultType.getElementType());
+    Value result = arith::ConstantOp::create(
+        rewriter, reductionOp->getLoc(), distributedResultType,
+        DenseElementsAttr::get(distributedResultType, zeroAttr));
+
     // Col reduction.
     if (reductionDims[0] == 0) {
-      // Yield the source vector and the accumulator.
+      // Source vector must be distributable to lanes in the col dimension.
       if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
         return rewriter.notifyMatchFailure(
             warpOp, "Source vector dimension must be divisible by warp size.");
+      // Compute source distributed type.
       SmallVector<int64_t> shape(sourceType.getShape());
       shape[1] = shape[1] / warpOp.getWarpSize();
       auto sourceDistributedType = VectorType::get(shape, elementType);
+
+      // Yield the source and acc vectors from the WarpOp.
      SmallVector<size_t> newRetIndices;
       auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
           rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
           {sourceDistributedType, distributedResultType}, newRetIndices);
       rewriter.setInsertionPointAfter(newWarpOp);
-      // Create new reduction op.
-      // auto newOp = vector::MultiDimReductionOp::create(
-      //     rewriter, reductionOp.getLoc(), distributedResultType,
-      //     reductionOp.getKind(),
-      //     /** source = **/ newWarpOp.getResult(newRetIndices[0]),
-      //     /** accumulator = **/ newWarpOp.getResult(newRetIndices[1]),
-      //     reductionDims);
-      // Create a constant zero value for storing the reduction result.
-      // rewriter.setInsertionPointAfter(reductionOp);
-      auto zeroAttr =
-          rewriter.getZeroAttr(distributedResultType.getElementType());
-      Value result = arith::ConstantOp::create(
-          rewriter, reductionOp->getLoc(), distributedResultType,
-          DenseElementsAttr::get(distributedResultType, zeroAttr));
+
       int nCols = sourceDistributedType.getShape()[1];
       Value source = newWarpOp.getResult(newRetIndices[0]);
       Value acc = newWarpOp.getResult(newRetIndices[1]);
+      // For each column owned by a lane, extract the column (of size nRows x
+      // 1), shape cast to 1D (nRows), do a vector.reduction and, insert the
+     // result back to the result vector.
      for (int i = 0; i < nCols; ++i) {
         Value col = vector::ExtractStridedSliceOp::create(
             rewriter, reductionOp.getLoc(), source, {0, i},
             {sourceDistributedType.getShape()[0], 1}, {1, 1});
         col = vector::ShapeCastOp::create(
             rewriter, reductionOp.getLoc(),
             VectorType::get({sourceDistributedType.getShape()[0]}, elementType),
             col);
         Value accCol =
             vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
         Value colReduce = vector::ReductionOp::create(
             rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
-        // Insert the reduced column into the result.
         result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
                                           colReduce, result, i);
       }
       // Replace the warp op result with the new reduction op.
       rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
       return success();
     }
-    // Row reduction.
-    // Create a constant zero value for storing the reduction result.
+    // For row reductions, we simply rewrite the MultiReductionOp in terms of
+    // multiple ReductionOps. Actual distribution is done by the WarpOpReduction
+    // pattern.
     rewriter.setInsertionPointAfter(reductionOp);
-    auto zeroAttr =
-        rewriter.getZeroAttr(distributedResultType.getElementType());
-    Value result = arith::ConstantOp::create(
-        rewriter, reductionOp->getLoc(), distributedResultType,
-        DenseElementsAttr::get(distributedResultType, zeroAttr));
-    // Value result = arith::ConstantOp::create(
-    //     rewriter, reductionOp.getLoc(),
-    //     rewriter.getIntegerAttr(reductionOp.getType(), 0));
     int nRows = sourceType.getShape()[0];
-    // For each row, do a vector reduction.
+    // For each row of the source, extract the row vector, do a reduction and,
+    // insert the result back to the result.
     for (int i = 0; i < nRows; ++i) {
       Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
                                                reductionOp.getSource(), i);
       Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
                                             reductionOp.getAcc(), i);
       Value rowReduce = vector::ReductionOp::create(
           rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc);
       result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
                                         rowReduce, result, i);
     }
@@ -2201,15 +2154,16 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add(patterns.getContext(), readBenefit);
-  patterns.add(
-      patterns.getContext(), benefit);
+  patterns
+      .add(
+          patterns.getContext(), benefit);
   patterns.add(patterns.getContext(), warpShuffleFromIdxFn, benefit);
-  patterns.add(
-      patterns.getContext(), distributionMapFn, benefit);
+  patterns.add(patterns.getContext(),
+               distributionMapFn, benefit);
 }
 
 void mlir::vector::populateDistributeReduction(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 4d2c964a6df3c..bf70fbbd27244 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -850,6 +850,83 @@ func.func @vector_reduction_acc(%laneid: index) -> (f32) {
   return %r : f32
 }
 
+// -----
+// CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce
+// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) {
+// CHECK-PROP: %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32>
+// CHECK-PROP: %[[ACC:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK-PROP: gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP: %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP: %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32>
+// CHECK-PROP: %[[REDUCE0:.*]] = vector.reduction <add>, %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32
+// CHECK-PROP: %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP: %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP: %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32>
+// CHECK-PROP: %[[REDUCE1:.*]] = vector.reduction <add>, %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32
+// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32>
+// CHECK-PROP: return %[[R]] : vector<2xf32>
+func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<32x64xf32>)
+    %acc = "some_def"() : () -> (vector<64xf32>)
+    %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<32x64xf32> to vector<64xf32>
+    gpu.yield %1 : vector<64xf32>
+  }
+  return %r : vector<2xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_multi_reduction_row_reduce
+// CHECK-PROP: %[[C16:.*]] = arith.constant 16 : i32
+// CHECK-PROP: %[[C8:.*]] = arith.constant 8 : i32
+// CHECK-PROP: %[[C4:.*]] = arith.constant 4 : i32
+// CHECK-PROP: %[[C2:.*]] = arith.constant 2 : i32
+// CHECK-PROP: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-PROP: %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-PROP: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) {
+// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32>
+// CHECK-PROP: gpu.yield %[[SRC]] : vector<2x32xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP: %[[SR:.*]], %{{.*}} = gpu.shuffle xor %[[T1]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP: %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32
+// CHECK-PROP: %[[SR0:.*]], %{{.*}} = gpu.shuffle xor %[[T2]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP: %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32
+// CHECK-PROP: %[[SR2:.*]], %{{.*}} = gpu.shuffle xor %[[T3]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP: %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32
+// CHECK-PROP: %[[SR4:.*]], %{{.*}} = gpu.shuffle xor %[[T4]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP: %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32
+// CHECK-PROP: %[[SR6:.*]], %{{.*}} = gpu.shuffle xor %[[T5]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP: %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32
+// CHECK-PROP: %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32
+//
+// CHECK-PROP: %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP: %[[SR8:.*]], %{{.*}} = gpu.shuffle xor %[[T8]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP: %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32
+// CHECK-PROP: %[[SR10:.*]], %{{.*}} = gpu.shuffle xor %[[T9]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP: %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32
+// CHECK-PROP: %[[SR12:.*]], %{{.*}} = gpu.shuffle xor %[[T10]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP: %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32
+// CHECK-PROP: %[[SR14:.*]], %{{.*}} = gpu.shuffle xor %[[T11]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP: %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32
+// CHECK-PROP: %[[SR16:.*]], %{{.*}} = gpu.shuffle xor %[[T12]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP: %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32
+// CHECK-PROP: %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32
+// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
+// CHECK-PROP: return %[[R]] : vector<2xf32>
+func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> {
+  %zero = arith.constant dense<0.0> : vector<2xf32>
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<2x32xf32>)
+    %1 = vector.multi_reduction <add>, %0, %zero [1] : vector<2x32xf32> to vector<2xf32>
+    gpu.yield %1 : vector<2xf32>
+  }
+  return %r : vector<2xf32>
+}
+
 // -----
 
 // CHECK-PROP-LABEL:   func @warp_duplicate_yield(
@@ -1567,6 +1644,40 @@ func.func @warp_propagate_shape_cast(%laneid: index, %src: memref<32x4x32xf32>)
 // CHECK-PROP: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<1x1x4xf32> to vector<4xf32>
 // CHECK-PROP: return %[[CAST]] : vector<4xf32>
 
+// -----
+func.func @warp_propagate_shape_cast_2d_to_2d(%laneid: index, %src: memref<64x32xf32>) -> vector<32x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<32x2xf32>) {
+    %2 = vector.transfer_read %src[%c0, %c0], %cst : memref<64x32xf32>, vector<64x32xf32>
+    %3 = vector.shape_cast %2 : vector<64x32xf32> to vector<32x64xf32>
+    gpu.yield %3 : vector<32x64xf32>
+  }
+  return %r : vector<32x2xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_2d_to_2d
+// CHECK-PROP: %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [false, true]} : memref<64x32xf32>, vector<2x32xf32>
+// CHECK-PROP: %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<2x32xf32> to vector<32x2xf32>
+// CHECK-PROP: return %[[CAST]] : vector<32x2xf32>
+
+// -----
+func.func @warp_propagate_shape_cast_non_distributed_result(%laneid: index, %src: memref<64xf32>) -> vector<8x4x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x4x2xf32>) {
+    %2 = vector.transfer_read %src[%c0], %cst : memref<64xf32>, vector<64xf32>
+    %3 = vector.shape_cast %2 : vector<64xf32> to vector<8x4x2xf32>
+    gpu.yield %3 : vector<8x4x2xf32>
+  }
+  return %r : vector<8x4x2xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_non_distributed_result
+// CHECK-PROP: %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [true]} : memref<64xf32>, vector<64xf32>
+// CHECK-PROP: %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<64xf32> to vector<8x4x2xf32>
+// CHECK-PROP: return %[[CAST]] : vector<8x4x2xf32>
+
 // -----
 
 func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index) -> vector<1xf32> {

From 01880b561e94c6cb752e6eddb16957e00dbdc97f Mon Sep 17 00:00:00 2001
From: Charitha Saumya
Date: Tue, 19 Aug 2025 23:26:49 +0000
Subject: [PATCH 03/10] save

---
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 6410a895fc9ae..8dc1418e09006 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2033,6 +2033,12 @@ struct WarpOpReduction : public WarpDistributionPattern {
   DistributedReductionFn distributedReductionFn;
 };
 
+// This patterns distribute the `vector.multi_reduction` operation across
+// lanes in a warp. Currently only 2D to 1D reductions are supported and assumes
+// that source vector is distributed in column dimension (i.e. Each lane owns
+// complete column(s) of the source vector.
+// TODO: Add support for the case where source rows are distributed accross
+// lanes. Requires DistributionMapFn to express the data distribution.
 struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
   VectorMultiDimReductionDistribution(MLIRContext *context,

From 53da9928117634d6eb929f81cbfa59ed4c06d884 Mon Sep 17 00:00:00 2001
From: Charitha Saumya
Date: Tue, 19 Aug 2025 23:28:13 +0000
Subject: [PATCH 04/10] save

---
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 8dc1418e09006..c88c001f34843 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2036,9 +2036,9 @@ struct WarpOpReduction : public WarpDistributionPattern {
 // This patterns distribute the `vector.multi_reduction` operation across
 // lanes in a warp. Currently only 2D to 1D reductions are supported and assumes
 // that source vector is distributed in column dimension (i.e. Each lane owns
-// complete column(s) of the source vector.
-// TODO: Add support for the case where source rows are distributed accross
-// lanes. Requires DistributionMapFn to express the data distribution.
+// complete column(s) of the source vector).
+// TODO: Add support for the case where source rows are distributed across
+// lanes. Requires `DistributionMapFn` to express the data distribution.
 struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {

From affd4aadb2e0f3f7cd19b0805b34067c1fa65371 Mon Sep 17 00:00:00 2001
From: Charitha Saumya
Date: Tue, 19 Aug 2025 23:48:08 +0000
Subject: [PATCH 05/10] save

---
 .../Vector/vector-warp-distribute.mlir        | 74 +++++++++----------
 1 file changed, 37 insertions(+), 37 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index bf70fbbd27244..bf0191655d654 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -879,44 +879,44 @@ func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
 
 // -----
 // CHECK-PROP-LABEL: func.func @vector_multi_reduction_row_reduce
-// CHECK-PROP: %[[C16:.*]] = arith.constant 16 : i32
-// CHECK-PROP: %[[C8:.*]] = arith.constant 8 : i32
-// CHECK-PROP: %[[C4:.*]] = arith.constant 4 : i32
-// CHECK-PROP: %[[C2:.*]] = arith.constant 2 : i32
-// CHECK-PROP: %[[C1:.*]] = arith.constant 1 : i32
-// CHECK-PROP: %[[C32:.*]] = arith.constant 32 : i32
-// CHECK-PROP: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) {
-// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32>
-// CHECK-PROP: gpu.yield %[[SRC]] : vector<2x32xf32>
-// CHECK-PROP: }
-// CHECK-PROP: %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32>
-// CHECK-PROP: %[[SR:.*]], %{{.*}} = gpu.shuffle xor %[[T1]], %[[C1]], %[[C32]] : f32
-// CHECK-PROP: %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32
-// CHECK-PROP: %[[SR0:.*]], %{{.*}} = gpu.shuffle xor %[[T2]], %[[C2]], %[[C32]] : f32
-// CHECK-PROP: %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32
-// CHECK-PROP: %[[SR2:.*]], %{{.*}} = gpu.shuffle xor %[[T3]], %[[C4]], %[[C32]] : f32
-// CHECK-PROP: %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32
-// CHECK-PROP: %[[SR4:.*]], %{{.*}} = gpu.shuffle xor %[[T4]], %[[C8]], %[[C32]] : f32
-// CHECK-PROP: %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32
-// CHECK-PROP: %[[SR6:.*]], %{{.*}} = gpu.shuffle xor %[[T5]], %[[C16]], %[[C32]] : f32
-// CHECK-PROP: %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32
-// CHECK-PROP: %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32
-//
-// CHECK-PROP: %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32>
-// CHECK-PROP: %[[SR8:.*]], %{{.*}} = gpu.shuffle xor %[[T8]], %[[C1]], %[[C32]] : f32
-// CHECK-PROP: %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32
-// CHECK-PROP: %[[SR10:.*]], %{{.*}} = gpu.shuffle xor %[[T9]], %[[C2]], %[[C32]] : f32
-// CHECK-PROP: %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32
-// CHECK-PROP: %[[SR12:.*]], %{{.*}} = gpu.shuffle xor %[[T10]], %[[C4]], %[[C32]] : f32
-// CHECK-PROP: %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32
-// CHECK-PROP: %[[SR14:.*]], %{{.*}} = gpu.shuffle xor %[[T11]], %[[C8]], %[[C32]] : f32
-// CHECK-PROP: %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32
-// CHECK-PROP: %[[SR16:.*]], %{{.*}} = gpu.shuffle xor %[[T12]], %[[C16]], %[[C32]] : f32
-// CHECK-PROP: %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32
-// CHECK-PROP: %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32
-// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
-// CHECK-PROP: return %[[R]] : vector<2xf32>
+// CHECK-PROP-DAG: %[[C16:.*]] = arith.constant 16 : i32
+// CHECK-PROP-DAG: %[[C8:.*]] = arith.constant 8 : i32
+// CHECK-PROP-DAG: %[[C4:.*]] = arith.constant 4 : i32
+// CHECK-PROP-DAG: %[[C2:.*]] = arith.constant 2 : i32
+// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-PROP-DAG: %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-PROP-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) {
+// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32>
+// CHECK-PROP: gpu.yield %[[SRC]] : vector<2x32xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP: %[[SR:.*]], %{{.*}} = gpu.shuffle xor %[[T1]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP: %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32
+// CHECK-PROP: %[[SR0:.*]], %{{.*}} = gpu.shuffle xor %[[T2]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP: %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32
+// CHECK-PROP: %[[SR2:.*]], %{{.*}} = gpu.shuffle xor %[[T3]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP: %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32
+// CHECK-PROP: %[[SR4:.*]], %{{.*}} = gpu.shuffle xor %[[T4]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP: %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32
+// CHECK-PROP: %[[SR6:.*]], %{{.*}} = gpu.shuffle xor %[[T5]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP: %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32
+// CHECK-PROP: %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32
+//
+// CHECK-PROP: %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP: %[[SR8:.*]], %{{.*}} = gpu.shuffle xor %[[T8]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP: %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32
+// CHECK-PROP: %[[SR10:.*]], %{{.*}} = gpu.shuffle xor %[[T9]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP: %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32
+// CHECK-PROP: %[[SR12:.*]], %{{.*}} = gpu.shuffle xor %[[T10]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP: %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32
+// CHECK-PROP: %[[SR14:.*]], %{{.*}} = gpu.shuffle xor %[[T11]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP: %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32
+// CHECK-PROP: %[[SR16:.*]], %{{.*}} = gpu.shuffle xor %[[T12]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP: %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32
+// CHECK-PROP: %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32
+// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
+// CHECK-PROP: return %[[R]] : vector<2xf32>
 func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> {
   %zero = arith.constant dense<0.0> : vector<2xf32>
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
     %0 = "some_def"() : () -> (vector<2x32xf32>)
     %1 = vector.multi_reduction <add>, %0, %zero [1] : vector<2x32xf32> to vector<2xf32>
     gpu.yield %1 : vector<2xf32>
   }
   return %r : vector<2xf32>
 }

From df59c20f5d8020ab9ba78f1c360334c738a60404 Mon Sep 17 00:00:00 2001
From: Charitha Saumya
Date: Wed, 20 Aug 2025 00:01:32 +0000
Subject: [PATCH 06/10] save

---
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index c88c001f34843..b0b52919c69ce 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,19 +15,13 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/raw_ostream.h"
-#include
 #include <utility>
 
 using namespace mlir;

From 55797318492b6a38801aa27bf9ec97d26523322e Mon Sep 17 00:00:00 2001
From: Charitha Saumya
Date: Wed, 20 Aug 2025 23:31:35 +0000
Subject: [PATCH 07/10] save

---
 .../Vector/Transforms/VectorDistribute.cpp    | 52 +++++++++++++------
 1 file changed, 37 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index b0b52919c69ce..ab0f1b55d04da 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2033,9 +2033,8 @@ struct WarpOpReduction : public WarpDistributionPattern {
 // complete column(s) of the source vector).
 // TODO: Add support for the case where source rows are distributed across
// lanes. Requires `DistributionMapFn` to express the data distribution.
-struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
-  VectorMultiDimReductionDistribution(MLIRContext *context,
-                                      PatternBenefit benefit = 1)
+struct WarpOpMultiReduction : public WarpDistributionPattern {
+  WarpOpMultiReduction(MLIRContext *context, PatternBenefit benefit = 1)
       : WarpDistributionPattern(context, benefit) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -2046,18 +2045,46 @@ struct WarpOpMultiReduction : public WarpDistributionPattern {
     auto reductionOp =
         cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
     unsigned operandNumber = yieldOperand->getOperandNumber();
     VectorType sourceType = reductionOp.getSourceVectorType();
-    VectorType distributedResultType =
-        cast<VectorType>(warpOp.getResult(operandNumber).getType());
-    Type elementType = distributedResultType.getElementType();
+
     // Only 2D vectors are supported.
     if (sourceType.getRank() != 2)
       return rewriter.notifyMatchFailure(warpOp,
                                          "Only 2D reductions are supported.");
     ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
-    // Only 1 reduction dimension supported.
+    // Only 1 reduction dimension supported. This also ensures that result is
+    // also vector type.
     if (reductionDims.size() != 1)
       return rewriter.notifyMatchFailure(
           warpOp, "Only 1 reduction dimension is supported.");
+    int64_t reductionDim = reductionDims[0];
+    auto resultType = cast<VectorType>(reductionOp.getType());
+    auto distributedResultType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    Type elementType = distributedResultType.getElementType();
+
+    // Currently we make the following assumptions.
+    // 1. The source vector is distributed in the column dimension. Each lane
+    // owns complete column(s) of the source vector.
+    // 2. If the reduction dim == 0, it's a lane-local col reduction. In this
+    // case each lane owns its portion of the result (i.e. result is also
+    // distributed).
+    // 3. If reduction dim == 1, it's a row reduction that requires cross-lane
+    // shuffles. In this case result is not distributed and broadcasted instead.
+    // TODO: These assumptions are fairly restrictive. For example, source
+    // vector can have row distributed layout. Improve support for such cases.
+    if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Source vector dimension must be divisible by warp size.");
+    bool isResultDistributed =
+        distributedResultType.getNumElements() < resultType.getNumElements();
+    if (reductionDim == 0 && !isResultDistributed)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "Expecting result vector to be distributed in a col reduction.");
+    if (reductionDim == 1 && isResultDistributed)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "Expecting result vector to be broadcasted in a row reduction.");
 
     // Create a constant vector to store the result of the reduction per lane.
     TypedAttr zeroAttr =
         rewriter.getZeroAttr(distributedResultType.getElementType());
     Value result = arith::ConstantOp::create(
         rewriter, reductionOp->getLoc(), distributedResultType,
         DenseElementsAttr::get(distributedResultType, zeroAttr));
-    // Col reduction.
-    if (reductionDims[0] == 0) {
-      // Source vector must be distributable to lanes in the col dimension.
-      if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
-        return rewriter.notifyMatchFailure(
-            warpOp, "Source vector dimension must be divisible by warp size.");
-     // Compute source distributed type.
+    // Col reduction.
+    if (reductionDim == 0) {
+      // Compute source distributed type assuming each lane owns cols.
       SmallVector<int64_t> shape(sourceType.getShape());
       shape[1] = shape[1] / warpOp.getWarpSize();
       auto sourceDistributedType = VectorType::get(shape, elementType);
 
       // Yield the source and acc vectors from the WarpOp.
       SmallVector<size_t> newRetIndices;
       auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
           rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
           {sourceDistributedType, distributedResultType}, newRetIndices);
       rewriter.setInsertionPointAfter(newWarpOp);
@@ -2158,7 +2180,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
   patterns
       .add(
-          WarpOpInsertStridedSlice, VectorMultiDimReductionDistribution>(
+          WarpOpInsertStridedSlice, WarpOpMultiReduction>(
           patterns.getContext(), benefit);
   patterns.add(patterns.getContext(), warpShuffleFromIdxFn, benefit);

From 07c0364d64109faf740023107ab68ec0f242d9ca Mon Sep 17 00:00:00 2001
From: Charitha Saumya
Date: Wed, 20 Aug 2025 23:36:26 +0000
Subject: [PATCH 08/10] save

---
 .../Vector/Transforms/VectorDistribute.cpp    |  3 +-
 .../Vector/vector-warp-distribute.mlir        | 32 ++++++++++---------
 2 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index ab0f1b55d04da..aecb6a11a7b36 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2034,8 +2034,7 @@ struct WarpOpReduction : public WarpDistributionPattern {
 struct WarpOpMultiReduction : public WarpDistributionPattern {
-  WarpOpMultiReduction(MLIRContext *context, PatternBenefit benefit = 1)
-      : WarpDistributionPattern(context, benefit) {}
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *yieldOperand =
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index bf0191655d654..95b8a48404f20 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -852,21 +852,23 @@ func.func @vector_reduction_acc(%laneid: index) -> (f32) {
 
 // -----
 // CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce
-// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) {
-// CHECK-PROP: %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32>
-// CHECK-PROP: %[[ACC:.*]] = "some_def"() : () -> vector<64xf32>
-// CHECK-PROP: gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32>
-// CHECK-PROP: }
-// CHECK-PROP: %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
-// CHECK-PROP: %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32>
-// CHECK-PROP: %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32>
-// CHECK-PROP: %[[REDUCE0:.*]] = vector.reduction <add>, %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32
-// CHECK-PROP: %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
-// CHECK-PROP: %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32>
-// CHECK-PROP: %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32>
-// CHECK-PROP: %[[REDUCE1:.*]] = vector.reduction <add>, %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32
-// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32>
-// CHECK-PROP: return %[[R]] : vector<2xf32>
+// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) {
+// CHECK-PROP: %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32>
+// CHECK-PROP: %[[ACC:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK-PROP: gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0
+// CHECK-PROP-SAME: {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP: %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP: %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32>
+// CHECK-PROP: %[[REDUCE0:.*]] = vector.reduction <add>, %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32
+// CHECK-PROP: %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0
+// CHECK-PROP-SAME: {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP: %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP: %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32>
+// CHECK-PROP: %[[REDUCE1:.*]] = vector.reduction <add>, %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32
+// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32>
+// CHECK-PROP: return %[[R]] : vector<2xf32>
 func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
     %0 = "some_def"() : () -> (vector<32x64xf32>)

From 4ed74d89b7808c17bed035a906acf15d2a96c51f Mon Sep 17 00:00:00 2001
From: Charitha Saumya
Date: Thu, 28 Aug 2025 22:21:49 +0000
Subject: [PATCH 09/10] save work

---
 .../Vector/Transforms/VectorDistribute.cpp    | 58 ++++++++++++++++---
 1 file changed, 51 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2cd743d1ee8e8..4a6ea07c1c236 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1031,7 +1031,7 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
       return failure();
 
     auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
-    unsigned int operandNumber = operand->getOperandNumber();
+    unsigned operandNumber = operand->getOperandNumber();
     VectorType sourceType = oldCastOp.getSourceVectorType();
     VectorType distributedResultType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
@@ -2069,12 +2069,56 @@ struct WarpOpReduction : public WarpDistributionPattern {
   DistributedReductionFn distributedReductionFn;
 };
 
-// This patterns distribute the `vector.multi_reduction` operation across
-// lanes in a warp. Currently only 2D to 1D reductions are supported and assumes
-// that source vector is distributed in column dimension (i.e. Each lane owns
-// complete column(s) of the source vector).
-// TODO: Add support for the case where source rows are distributed across
-// lanes. Requires `DistributionMapFn` to express the data distribution.
+/// This pattern distributes the `vector.multi_reduction` operation across
+/// lanes in a warp. Currently only 2D to 1D reductions are supported, and it
+/// is assumed that the source vector is distributed in the column dimension
+/// (i.e. each lane owns complete column(s) of the source vector).
+/// TODO: Add support for the case where source rows are distributed across
+/// lanes. Requires `DistributionMapFn` to express the data distribution.
+/// Example 1 (Col reduction):
+/// ```
+///   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+///     %0 = "some_def"() : () -> (vector<16x32xf32>)
+///     %acc = "some_def"() : () -> (vector<32xf32>)
+///     %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32>
+///          to vector<32xf32>
+///     gpu.yield %1 : vector<32xf32>
+///   }
+/// ```
+/// is lowered to:
+/// ```
+///   %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
+///          vector<1xf32>) {
+///     %0 = "some_def"() : () -> (vector<16x32xf32>)
+///     %acc = "some_def"() : () -> (vector<32xf32>)
+///     gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
+///   }
+///   %c = arith.constant dense<0.0> : vector<1xf32>
+///   %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
+///   %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
+///   %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
+/// ```
+/// Example 2 (Row reduction):
+/// ```
+///   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+///     %0 = "some_def"() : () -> (vector<2x32xf32>)
+///     %acc = "some_def"() : () -> (vector<2xf32>)
+///     %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32>
+///          to vector<2xf32>
+///     gpu.yield %1 : vector<2xf32>
+///   }
+/// ```
+/// is lowered to:
+/// ```
+///   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+///     %0 = "some_def"() : () -> (vector<2x32xf32>)
+///     %acc = "some_def"() : () -> (vector<2xf32>)
+///     %1 = arith.constant dense<0.0> : vector<2xf32>
+///     %2 = vector.extract %0[0] : vector<32xf32> from vector<2x32xf32>
+///     %3 = ("warp.reduction %2") : f32
+///     %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
+///     ... repeat for row 1
+///     gpu.yield %1 : vector<2xf32>
+///   }
+/// ```
 struct WarpOpMultiReduction : public WarpDistributionPattern {
   using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

From 116e4bceb48ebe85e35fd61dd9e52867897b0a39 Mon Sep 17 00:00:00 2001
From: Charitha Saumya
Date: Tue, 2 Sep 2025 20:49:19 +0000
Subject: [PATCH 10/10] save work

---
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 4a6ea07c1c236..dddfcaf4f273d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <utility>
 
 using namespace mlir;
@@ -1039,7 +1040,7 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
     bool isResultDistributed = distributedResultType.getNumElements() <
                                oldCastOp.getResultVectorType().getNumElements();
 
-    // If the result is not distributed, source distribted type is the same
+    // If the result is not distributed, source distributed type is the same
     // as the source type. If the result is distributed, we need to compute the
     // distributed source type according to following rules:
     // 1. If the source type is yielded from the warp op, we can use the
     // matching warp result type as the distributed source type.
     // 2. If the source type is not yielded from the warp op, we need
     // to compute the distributed source type based on the distribution map
     // and the warp size.
     if (isResultDistributed) {
      // Check if the source is yielded from the warp op.
      gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
           warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-      auto *it =
+      OpOperand *it =
           llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
             return operand.get() == oldCastOp.getSource();
           });
@@ -2155,7 +2156,9 @@ struct WarpOpMultiReduction : public WarpDistributionPattern {
     // case each lane owns its portion of the result (i.e. result is also
     // distributed).
     // 3. If reduction dim == 1, it's a row reduction that requires cross-lane
-    // shuffles. In this case result is not distributed and broadcasted instead.
+    // shuffles. In this case, the reduction result is not distributed across
+    // lanes. Instead each lane owns a complete copy of the result
+    // (broadcasted).
     // TODO: These assumptions are fairly restrictive. For example, source
     // vector can have row distributed layout. Improve support for such cases.
     if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
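
---

A note on exercising this series: the patterns above only run once they are registered with a pattern driver. The sketch below shows one plausible way to wire `populatePropagateWarpVectorDistributionPatterns` (the entry point modified in PATCH 01/02/07) into a pass. It is modeled on MLIR's in-tree test pass; the function name, the innermost-dimension distribution map, and the `gpu.shuffle idx`-based callback are illustrative assumptions, not part of the patches.

```cpp
// Sketch only: driving the warp-distribution patterns touched in this series.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Assumed helper name; a real pass would hang this off runOnOperation().
static void runWarpDistribution(Operation *root) {
  MLIRContext *ctx = root->getContext();

  // Assumption: distribute the innermost vector dimension across lanes. This
  // is the layout WarpOpMultiReduction expects for its column-reduction case.
  vector::DistributionMapFn mapFn = [](Value val) {
    auto vecType = dyn_cast<VectorType>(val.getType());
    if (!vecType)
      return AffineMap();
    int64_t rank = vecType.getRank();
    return AffineMap::get(rank, 0,
                          {getAffineDimExpr(rank - 1, val.getContext())},
                          val.getContext());
  };

  // Assumption: resolve cross-lane reads with gpu.shuffle idx, as the in-tree
  // test pass does.
  vector::WarpShuffleFromIdxFn shuffleFn = [](Location loc, OpBuilder &b,
                                              Value val, Value srcIdx,
                                              int64_t warpSz) {
    Value idx = arith::IndexCastOp::create(b, loc, b.getI32Type(), srcIdx);
    Value width = arith::ConstantOp::create(b, loc, b.getI32IntegerAttr(warpSz));
    return gpu::ShuffleOp::create(b, loc, val, idx, width,
                                  gpu::ShuffleMode::IDX)
        .getShuffleResult();
  };

  RewritePatternSet patterns(ctx);
  vector::populatePropagateWarpVectorDistributionPatterns(patterns, mapFn,
                                                          shuffleFn);
  (void)applyPatternsGreedily(root, std::move(patterns));
}
```

With this wiring, `WarpOpMultiReduction` fires whenever a `vector.multi_reduction` is yielded from `gpu.warp_execute_on_lane_0`, producing the column- and row-reduction lowerings checked by the tests added in PATCH 02.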