diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index bb0f339a26e43..ea9594fc68d31 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -15,13 +15,19 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" +#include #include using namespace mlir; @@ -939,8 +945,40 @@ struct WarpOpForwardOperand : public WarpDistributionPattern { } }; +static VectorType +tryFindDistributedType(TypedValue source, + WarpExecuteOnLane0Op warpOp, + const DistributionMapFn &distributionMapFn) { + VectorType distributedType = source.getType(); + // Check if the source is yielded from the warp op. + gpu::YieldOp yieldOp = cast( + warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + auto *it = llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) { + return operand.get() == source; + }); + + if (it != yieldOp->getOpOperands().end()) { + // If the source is yielded from the warp op, we can use the matching + // warp result type as the distributed source type. + distributedType = + cast(warpOp->getResultTypes()[it->getOperandNumber()]); + } else { + // If the source is not yielded from the warp op, we need to compute + // the distributed source type based on the distribution map and the + // warp size. + AffineMap map = distributionMapFn(source); + VectorType computed = + getDistributedType(source.getType(), map, warpOp.getWarpSize()); + if (!computed) + return source.getType(); + distributedType = computed; + } + return distributedType; +} + struct WarpOpBroadcast : public WarpDistributionPattern { - using Base::Base; + WarpOpBroadcast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1) + : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = @@ -953,18 +991,23 @@ struct WarpOpBroadcast : public WarpDistributionPattern { auto destVecType = cast(warpOp->getResultTypes()[operandNumber]); Value broadcastSrc = broadcastOp.getSource(); - Type broadcastSrcType = broadcastSrc.getType(); + Type srcDistributedType = broadcastSrc.getType(); + + if (isa(srcDistributedType)) + srcDistributedType = + tryFindDistributedType(cast>(broadcastSrc), + warpOp, distributionMapFn); // Check that the broadcast actually spans a set of values uniformly across // all threads. In other words, check that each thread can reconstruct // their own broadcast. // For that we simply check that the broadcast we want to build makes sense. 
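+    // For example (illustrative shapes): if the original broadcast produced
+    // vector<1x32xf32> and the warp result is distributed to vector<1x1xf32>
+    // per lane, a scalar f32 source is still broadcastable to the distributed
+    // type, so each lane can rebuild its own slice locally.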
-    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
+    if (vector::isBroadcastableTo(srcDistributedType, destVecType) !=
         vector::BroadcastableToResult::Success)
       return failure();
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
+        rewriter, warpOp, {broadcastSrc}, {srcDistributedType}, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value broadcasted = vector::BroadcastOp::create(
         rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
@@ -972,49 +1015,83 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
                                 broadcasted);
     return success();
   }
+
+private:
+  DistributionMapFn distributionMapFn;
 };
 
 /// Pattern to move shape cast out of the warp op. shape cast is basically a
 /// no-op for warp distribution; we need to handle the shape though.
 struct WarpOpShapeCast : public WarpDistributionPattern {
-  using Base::Base;
+
+  WarpOpShapeCast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
         getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
     if (!operand)
       return failure();
-
     auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
     unsigned int operandNumber = operand->getOperandNumber();
-    auto castDistributedType =
+    VectorType sourceType = oldCastOp.getSourceVectorType();
+    VectorType distributedResultType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
-    VectorType castOriginalType = oldCastOp.getSourceVectorType();
-    VectorType castResultType = castDistributedType;
-
-    // We expect the distributed type to have a smaller rank than the original
-    // type. Prepend with size-one dimensions to make them the same.
-    unsigned castDistributedRank = castDistributedType.getRank();
-    unsigned castOriginalRank = castOriginalType.getRank();
-    if (castDistributedRank < castOriginalRank) {
-      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
-      llvm::append_range(shape, castDistributedType.getShape());
-      castDistributedType =
-          VectorType::get(shape, castDistributedType.getElementType());
+    VectorType distributedSourceType = sourceType;
+    bool isResultDistributed = distributedResultType.getNumElements() <
+                               oldCastOp.getResultVectorType().getNumElements();
+
+    // If the result is not distributed, the distributed source type is the
+    // same as the source type. If the result is distributed, we need to
+    // compute the distributed source type according to the following rules:
+    // 1. If the source type is yielded from the warp op, we can use the
+    //    matching warp result type as the distributed source type.
+    // 2. If the source type is not yielded from the warp op, we need to
+    //    compute the distributed source type based on the distribution map
+    //    and the warp size.
+    if (isResultDistributed) {
+      // Check if the source is yielded from the warp op.
+      gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
+          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      auto *it =
+          llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
+            return operand.get() == oldCastOp.getSource();
+          });
+
+      if (it != yieldOp->getOpOperands().end()) {
+        // If the source is yielded from the warp op, we can use the matching
+        // warp result type as the distributed source type.
+ distributedSourceType = + cast(warpOp->getResultTypes()[it->getOperandNumber()]); + } else { + // If the source is not yielded from the warp op, we need to compute + // the distributed source type based on the distribution map and the + // warp size. + AffineMap map = distributionMapFn(oldCastOp.getSource()); + distributedSourceType = + getDistributedType(sourceType, map, warpOp.getWarpSize()); + if (!distributedSourceType) + return rewriter.notifyMatchFailure( + oldCastOp, + "cannot compute distributed source type for shape cast"); + } } SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType}, + rewriter, warpOp, {oldCastOp.getSource()}, {distributedSourceType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); Value newCast = vector::ShapeCastOp::create( - rewriter, oldCastOp.getLoc(), castResultType, + rewriter, oldCastOp.getLoc(), distributedResultType, newWarpOp->getResult(newRetIndices[0])); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast); return success(); } + +private: + DistributionMapFn distributionMapFn; }; /// Sink out vector.create_mask op feeding into a warp op yield. @@ -1995,6 +2072,114 @@ struct WarpOpReduction : public WarpDistributionPattern { DistributedReductionFn distributedReductionFn; }; +struct VectorMultiDimReductionDistribution : public WarpDistributionPattern { + VectorMultiDimReductionDistribution(MLIRContext *context, + PatternBenefit benefit = 1) + : WarpDistributionPattern(context, benefit) {} + LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *yieldOperand = + getWarpResult(warpOp, llvm::IsaPred); + if (!yieldOperand) + return failure(); + auto reductionOp = + cast(yieldOperand->get().getDefiningOp()); + unsigned operandNumber = yieldOperand->getOperandNumber(); + VectorType sourceType = reductionOp.getSourceVectorType(); + VectorType distributedResultType = + cast(warpOp.getResult(operandNumber).getType()); + Type elementType = distributedResultType.getElementType(); + // Only 2D vectors are supported. + if (sourceType.getRank() != 2) + return rewriter.notifyMatchFailure(warpOp, + "Only 2D reductions are supported."); + ArrayRef reductionDims = reductionOp.getReductionDims(); + // Only 1 reduction dimension supported. + if (reductionDims.size() != 1) + return rewriter.notifyMatchFailure( + warpOp, "Only 1 reduction dimension is supported."); + + // Col reduction. + if (reductionDims[0] == 0) { + // Yield the source vector and the accumulator. + if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0) + return rewriter.notifyMatchFailure( + warpOp, "Source vector dimension must be divisible by warp size."); + SmallVector shape(sourceType.getShape()); + shape[1] = shape[1] / warpOp.getWarpSize(); + auto sourceDistributedType = VectorType::get(shape, elementType); + SmallVector newRetIndices; + auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()}, + {sourceDistributedType, distributedResultType}, newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + // Create new reduction op. 
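+      // After distribution each lane owns a [rows x cols/warpSize] slice, so
+      // every owned column still holds all rows and can be reduced locally:
+      // extract the column, reduce it with vector.reduction against the
+      // matching accumulator element, and insert it into the result vector.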
+      // Create a constant zero value for storing the reduction result.
+      auto zeroAttr =
+          rewriter.getZeroAttr(distributedResultType.getElementType());
+      Value result = arith::ConstantOp::create(
+          rewriter, reductionOp->getLoc(), distributedResultType,
+          DenseElementsAttr::get(distributedResultType, zeroAttr));
+      int nCols = sourceDistributedType.getShape()[1];
+      Value source = newWarpOp.getResult(newRetIndices[0]);
+      Value acc = newWarpOp.getResult(newRetIndices[1]);
+      for (int i = 0; i < nCols; ++i) {
+        Value col = vector::ExtractStridedSliceOp::create(
+            rewriter, reductionOp.getLoc(), source, {0, i},
+            {sourceDistributedType.getShape()[0], 1}, {1, 1});
+        col = vector::ShapeCastOp::create(
+            rewriter, reductionOp.getLoc(),
+            VectorType::get({sourceDistributedType.getShape()[0]}, elementType),
+            col);
+        Value accCol =
+            vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
+        Value colReduce = vector::ReductionOp::create(
+            rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
+        // Insert the reduced column into the result.
+        result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+                                          colReduce, result, i);
+      }
+      // Replace the warp op result with the accumulated per-column result.
+      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
+      return success();
+    }
+    // Row reduction.
+    // Create a constant zero value for storing the reduction result.
+    rewriter.setInsertionPointAfter(reductionOp);
+    auto zeroAttr =
+        rewriter.getZeroAttr(distributedResultType.getElementType());
+    Value result = arith::ConstantOp::create(
+        rewriter, reductionOp->getLoc(), distributedResultType,
+        DenseElementsAttr::get(distributedResultType, zeroAttr));
+    int nRows = sourceType.getShape()[0];
+    // For each row, do a vector reduction.
+    for (int i = 0; i < nRows; ++i) {
+      Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+                                               reductionOp.getSource(), i);
+      Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+                                            reductionOp.getAcc(), i);
+      Value rowReduce = vector::ReductionOp::create(
+          rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc);
+      result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+                                        rowReduce, result, i);
+    }
+    // Replace the reduction op result (which feeds the warp yield) with the
+    // final result.
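+    // Note that the per-row vector.reduction ops created here still live
+    // inside the warp region; they are distributed across lanes later by the
+    // vector.reduction distribution patterns registered through
+    // populateDistributeReduction.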
+ rewriter.replaceAllUsesWith(reductionOp.getResult(), result); + + return success(); + } +}; + } // namespace void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern( @@ -2015,16 +2200,15 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns( const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit, PatternBenefit readBenefit) { patterns.add(patterns.getContext(), readBenefit); - patterns - .add( - patterns.getContext(), benefit); + patterns.add( + patterns.getContext(), benefit); patterns.add(patterns.getContext(), warpShuffleFromIdxFn, benefit); - patterns.add(patterns.getContext(), distributionMapFn, - benefit); + patterns.add( + patterns.getContext(), distributionMapFn, benefit); } void mlir::vector::populateDistributeReduction( diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index bef88042fc663..10c2759493477 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -62,10 +62,17 @@ struct Layout { SmallVector layout; Layout() = default; Layout(std::initializer_list list) : layout(list) {} + Layout(SmallVector &list) : layout(list) {} void print(llvm::raw_ostream &os) const; size_t size() const { return layout.size(); } + int64_t operator[](size_t idx) const; }; +int64_t Layout::operator[](size_t idx) const { + assert(idx < layout.size() && "Index out of bounds"); + return layout[idx]; +} + void Layout::print(llvm::raw_ostream &os) const { os << llvm::interleaved_array(layout); } @@ -324,6 +331,13 @@ class LayoutInfoPropagation ArrayRef operands, ArrayRef results); + void visitVectorBroadCastOp(vector::BroadcastOp broadcast, + ArrayRef operands, + ArrayRef results); + void visitShapeCastOp(vector::ShapeCastOp shapeCast, + ArrayRef operands, + ArrayRef results); + public: LayoutInfoPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable) @@ -383,6 +397,12 @@ LogicalResult LayoutInfoPropagation::visitOperation( .Case([&](auto reductionOp) { visitVectorMultiReductionOp(reductionOp, operands, results); }) + .Case([&](auto broadcastOp) { + visitVectorBroadCastOp(broadcastOp, operands, results); + }) + .Case([&](auto shapeCastOp) { + visitShapeCastOp(shapeCastOp, operands, results); + }) // All other ops. .Default([&](Operation *op) { for (const LayoutInfoLattice *resultInfo : results) { @@ -437,6 +457,83 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp( propagateIfChanged(operands[1], operands[1]->meet(resultLayout)); } +void LayoutInfoPropagation::visitVectorBroadCastOp( + vector::BroadcastOp broadcast, ArrayRef operands, + ArrayRef results) { + // The layout of the result must be present. + LayoutInfo resultLayout = results[0]->getValue(); + if (!resultLayout.isAssigned()) + return; + // Only consider 1D -> 2D broadcasts or 2D -> 2D broadcasts. + VectorType resultTy = broadcast.getResultVectorType(); + VectorType sourceTy = dyn_cast(broadcast.getSourceType()); + if (!sourceTy) { + broadcast.emitWarning("Expecting source type to be a vector type."); + return; + } + + // Only conside 2D -> 2D broadcast. 
+  if (sourceTy.getRank() != 2 || resultTy.getRank() != 2) {
+    broadcast.emitWarning("Expecting source type to be 2D vector and "
+                          "result type to be 2D vector.");
+    return;
+  }
+  SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
+  if (broadcastUnitDims.size() != 1) {
+    broadcast.emitWarning("Expecting source type to be 2D vector only with "
+                          "one broadcasted dimension.");
+    return;
+  }
+  // Propagate the result layout to the source operand.
+  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+}
+
+void LayoutInfoPropagation::visitShapeCastOp(
+    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // The layout of the result must be present.
+  LayoutInfo resultLayout = results[0]->getValue();
+  if (!resultLayout.isAssigned())
+    return;
+  VectorType sourceTy = shapeCast.getSourceVectorType();
+  VectorType resultTy = shapeCast.getResultVectorType();
+  // Expecting source rank to be 1D or 2D.
+  if (sourceTy.getRank() != 1 && sourceTy.getRank() != 2) {
+    shapeCast.emitWarning("Expecting source type to be 1D or 2D vector.");
+    return;
+  }
+  // Expecting result rank to be 1D or 2D.
+  if (resultTy.getRank() != 1 && resultTy.getRank() != 2) {
+    shapeCast.emitWarning("Expecting result type to be 1D or 2D vector.");
+    return;
+  }
+  // For 2D -> 2D shape cast, propagate the result layout to the source.
+  if (sourceTy.getRank() == 2 && resultTy.getRank() == 2) {
+    // Propagate the result layout to the source operand.
+    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+    return;
+  }
+  auto resultLayoutArray = resultLayout.getLayoutAsArrayRef();
+  if (resultLayoutArray[0] != 1 && resultLayoutArray[1] != 1) {
+    shapeCast.emitWarning(
+        "Expecting result layout to be of form [1, subgroupSize] "
+        "or [subgroupSize, 1].");
+    return;
+  }
+  int64_t distributedDim = resultLayoutArray[0] == 1 ? 1 : 0;
+  // If the result shape can be evenly distributed along the distributed
+  // dimension, then the source layout should be [subgroupSize][1]. Otherwise,
+  // data is shared across lanes (broadcasted). In that case, just assign
+  // [1][1] for now (TODO: use a slice layout for this case).
+  LayoutInfo sourceLayout =
+      resultTy.getShape()[distributedDim] % xegpu::targetinfo::subgroupSize == 0
+          ? LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
+                       LaneData({1}))
+          : LayoutInfo(LaneLayout({1}), LaneData({1}));
+  // Propagate the source layout to the source operand.
+  propagateIfChanged(operands[0], operands[0]->meet(sourceLayout));
+}
+
 /// Propagate the layout of the result tensor to the source tensor descriptor in
 /// UpdateNdOffsetOp.
 void LayoutInfoPropagation::visitUpdateNdOffsetOp(
@@ -529,16 +626,64 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
       bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
   int outElemTyBitWidth =
       bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
-
-  // NOTE: We do not expect widening or narrowing bitcasts at this stage. Emit
-  // a warning and return.
-  if (inElemTyBitWidth != outElemTyBitWidth) {
-    bitcast.emitWarning("Widening or narrowing bitcasts are not expected at "
-                        "layout propagation stage.");
+  // If the element bit widths are the same, then the layout does not change.
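+  // Otherwise the lane data is rescaled by the bit-width ratio below; e.g. for
+  // a vector<16x8xi32> -> vector<16x16xf16> bitcast with result lane data
+  // [1, 2], each lane needs source lane data [1, 1].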
+  if (inElemTyBitWidth == outElemTyBitWidth) {
+    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
     return;
   }
+  int64_t rank = bitcast.getSourceVectorType().getRank();
+  // A bitcast is `narrowing` if the input element type bit width is larger
+  // than the output element type bit width, e.g. f32 -> f16 is a narrowing
+  // bitcast.
+  bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
+  int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
+                                 : outElemTyBitWidth / inElemTyBitWidth;
+  const LaneLayout &sourceLaneLayout =
+      resultLayout.getLayout(); // Source lane layout is unchanged.
+  ArrayRef<int64_t> currData = resultLayout.getDataAsArrayRef();
+
+  // TODO: Currently we assume that bitcasts do not require cross-lane
+  // communication, so each lane must own the required number of elements to
+  // perform the bitcast locally.
+  // For 1D vectors, decide how many elements each lane owns based on whether
+  // the bitcast is narrowing or widening.
+  if (rank == 1) {
+    if ((currData[0] * outElemTyBitWidth) % inElemTyBitWidth != 0) {
+      bitcast.emitWarning(
+          "Narrowing bitcast with cross lane communication is not supported.");
+      return;
+    }
+    LaneData sourceLaneData = isNarrowing
+                                  ? LaneData({currData[0] / bitCastRatio})
+                                  : LaneData({currData[0] * bitCastRatio});
 
-  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+    propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(
+                                        sourceLaneLayout, sourceLaneData)));
+  }
+  // For nD vectors, each lane is not allowed to own multiple elements in any
+  // dimension other than the innermost dimension.
+  // TODO: Add support for other cases depending on the use case.
+  SmallVector<int64_t> sourceLaneDataStorage(currData.begin(),
+                                             currData.end() - 1);
+  if (llvm::any_of(sourceLaneDataStorage, [](int64_t d) { return d != 1; })) {
+    bitcast.emitWarning(
+        "Each lane must not own multiple elements in any dimension other than "
+        "the innermost dimension.");
+    return;
+  }
+  // Check if the bitcast requires cross lane communication.
+  if ((currData[rank - 1] * outElemTyBitWidth) % inElemTyBitWidth != 0) {
+    bitcast.emitWarning(
+        "Narrowing bitcast with cross lane communication is not supported.");
+    return;
+  }
+  // Decide lane data based on whether the bitcast is narrowing or widening.
+  int64_t innerMostLaneData = isNarrowing ?
currData[rank - 1] / bitCastRatio + : currData[rank - 1] * bitCastRatio; + sourceLaneDataStorage.push_back(innerMostLaneData); + LaneData sourceLaneData(sourceLaneDataStorage); + + propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo( + sourceLaneLayout, sourceLaneData))); } /// Propagate the layout of the result to the tensor descriptor and mask diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 2088c3c7fc5ec..61eece55a9bac 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -27,6 +27,7 @@ #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -34,6 +35,9 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/Support/LogicalResult.h" +#include namespace mlir { namespace xegpu { @@ -146,6 +150,15 @@ static bool hasPackedLayout(xegpu::LayoutAttr layout) { return laneData.asArrayRef()[0] != 1; } +static bool hasTransposedLayout(xegpu::LayoutAttr layout) { + if (layout == xegpu::LayoutAttr()) + return false; + DenseI32ArrayAttr laneLayout = layout.getLaneLayout(); + if (!laneLayout || laneLayout.size() != 2) + return false; + return laneLayout.asArrayRef()[0] > 1 && laneLayout.asArrayRef()[1] == 1; +} + /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body /// of the original GPUFuncOp to the new GPUFuncOp such that entire body is /// contained within a WarpExecuteOnLane0Op. @@ -500,6 +513,9 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern { xegpu::removeLayoutAttrs(newLoadOp); // Set the packed attribute if the layout requires it. 
newLoadOp.setPacked(hasPackedLayout(layout)); + if (hasTransposedLayout(layout)) + newLoadOp.setTranspose( + DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0})); Value distributedVal = newWarpOp.getResult(operandIdx); // There can be a conflict between the vector type distributed by the // warp op and (xegpu-specific) distributed type supported by the load @@ -811,6 +827,135 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern { } }; +struct MemrefExtractAlignedPointerAsIndexDistribution final + : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *operand = getWarpResult( + warpOp, llvm::IsaPred); + if (!operand) + return rewriter.notifyMatchFailure( + warpOp, + "warp result is not a xegpu::MemrefExtractAlignedPointerAsIndex op"); + auto extractOp = + operand->get().getDefiningOp(); + unsigned operandIdx = operand->getOperandNumber(); + SmallVector newRetIndices; + gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, extractOp.getSource(), + TypeRange{extractOp.getSource().getType()}, newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, newWarpOp.getLoc(), extractOp.getType(), + newWarpOp.getResult(newRetIndices[0])); + Value distributedVal = newWarpOp.getResult(operandIdx); + rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult()); + return success(); + } +}; + +struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *operand = + getWarpResult(warpOp, llvm::IsaPred); + if (!operand) + return rewriter.notifyMatchFailure( + warpOp, "warp result is not a vector::BitCast op"); + auto bitcastOp = operand->get().getDefiningOp(); + unsigned operandIdx = operand->getOperandNumber(); + VectorType distributedSourceType = + getDistVecTypeBasedOnLaneLayout( + xegpu::getLayoutAttr(bitcastOp.getSource()), + bitcastOp.getSourceVectorType()) + .value_or(VectorType()); + if (!distributedSourceType) + return rewriter.notifyMatchFailure( + bitcastOp, "Failed to distribute the source vector type in " + "vector::BitCast op"); + VectorType distributedResultType = + cast(warpOp.getResult(operandIdx).getType()); + if (distributedSourceType.getRank() != 2 || + distributedResultType.getRank() != 2) + return rewriter.notifyMatchFailure( + bitcastOp, "the source or result vector of the bitcast op " + "are not 2D vectors"); + SmallVector newRetIndices; + gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, bitcastOp.getSource(), + TypeRange{distributedSourceType}, newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + auto newBitcastOp = vector::BitCastOp::create( + rewriter, newWarpOp.getLoc(), distributedResultType, + newWarpOp.getResult(newRetIndices[0])); + Value distributedVal = newWarpOp.getResult(operandIdx); + rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult()); + return success(); + } +}; + +struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult 
matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *operand = + getWarpResult(warpOp, llvm::IsaPred); + if (!operand) + return rewriter.notifyMatchFailure( + warpOp, "warp result is not a vector::Transpose op"); + auto transposeOp = operand->get().getDefiningOp(); + unsigned operandIdx = operand->getOperandNumber(); + xegpu::LayoutAttr sourceLayout = + xegpu::getLayoutAttr(transposeOp.getVector()); + xegpu::LayoutAttr resultLayout = + xegpu::getLayoutAttr(transposeOp.getResult()); + if (!sourceLayout || !resultLayout) + return rewriter.notifyMatchFailure( + transposeOp, + "the source or result vector of the transpose op lacks layout " + "attribute"); + ArrayRef sourceLaneLayout = sourceLayout.getLaneLayout().asArrayRef(); + ArrayRef resultLaneLayout = resultLayout.getLaneLayout().asArrayRef(); + ArrayRef sourceLaneData = sourceLayout.getLaneData().asArrayRef(); + ArrayRef resultLaneData = resultLayout.getLaneData().asArrayRef(); + if (sourceLaneLayout.size() != 2 || resultLaneLayout.size() != 2) + return rewriter.notifyMatchFailure( + transposeOp, "the source or result vector of the transpose op " + "does not have 2D layout"); + auto is2DTranspose = [](ArrayRef input, ArrayRef output) { + return input.size() == 2 && output.size() == 2 && input[0] == output[1] && + input[1] == output[0]; + }; + + if (!is2DTranspose(sourceLaneLayout, resultLaneLayout) || + !is2DTranspose(sourceLaneData, resultLaneData)) + return rewriter.notifyMatchFailure( + transposeOp, + "the source or result vector layouts must be transposes of each " + "other"); + FailureOr distributedSourceTypeOrFailure = + getDistVecTypeBasedOnLaneLayout(sourceLayout, + transposeOp.getSourceVectorType()); + if (failed(distributedSourceTypeOrFailure)) + return rewriter.notifyMatchFailure( + transposeOp, "Failed to distribute the source vector type in " + "vector::Transpose op"); + SmallVector newRetIndices; + gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, transposeOp.getVector(), + TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + auto newTransposeOp = vector::TransposeOp::create( + rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]), + transposeOp.getPermutation()); + Value distributedVal = newWarpOp.getResult(operandIdx); + rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult()); + return success(); + } +}; + } // namespace namespace { @@ -825,7 +970,9 @@ void xegpu::populateXeGPUSubgroupDistributePatterns( RewritePatternSet &patterns) { patterns.add( + UpdateNdOffsetDistribution, GpuBarrierDistribution, + VectorTransposeDistribution, VectorBitcastDistribution, + MemrefExtractAlignedPointerAsIndexDistribution>( patterns.getContext()); } @@ -903,14 +1050,47 @@ void XeGPUSubgroupDistributePass::runOnOperation() { int64_t warpSz) { return Value(); }; vector::populatePropagateWarpVectorDistributionPatterns( patterns, distributionFn, shuffleFn); + + auto warpReduction = [](Location loc, OpBuilder &builder, Value input, + vector::CombiningKind kind, uint32_t size) { + // First reduce on a single thread to get per lane reduction value. + Value laneVal = builder.create(loc, kind, input); + // Parallel reduction using butterfly shuffles. 
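+    // E.g. with size = 4: step 1 (XOR 1) combines lanes {0,1} and {2,3};
+    // step 2 (XOR 2) combines {0,2} and {1,3}, so after log2(size) steps
+    // every lane holds the reduction over all lanes.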
+    for (uint64_t i = 1; i < size; i <<= 1) {
+      Value shuffled =
+          builder
+              .create<gpu::ShuffleOp>(loc, laneVal, i,
+                                      /*width=*/size,
+                                      /*mode=*/gpu::ShuffleMode::XOR)
+              .getShuffleResult();
+      laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
+    }
+    return laneVal;
+  };
+
+  vector::populateDistributeReduction(patterns, warpReduction);
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
     signalPassFailure();
     return;
   }
 
-  // Step 4: Finllay, clean up UnrealizedConversionCastOps that were inserted
+  // Step 4: Finally, clean up UnrealizedConversionCastOps that were inserted
   // due to tensor desc type mismatches created by using upstream distribution
-  // patterns (scf.for)
+  // patterns (scf.for). This cleanup should only be done if all the ops are
+  // distributed successfully; if some ops are still not distributed and remain
+  // inside any WarpExecuteOnLane0Op, we skip this simplification step to avoid
+  // breaking the IR.
+  bool foundWarpOp = false;
+  getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
+    // Look for WarpOps that are not trivially dead.
+    if (isOpTriviallyDead(warpOp))
+      return WalkResult::advance();
+    foundWarpOp = true;
+    return WalkResult::interrupt();
+  });
+  if (foundWarpOp)
+    return;
+
   getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
     // We are only interested in UnrealizedConversionCastOps there were added
     // for resolving SIMT type mismatches.
@@ -929,7 +1109,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
           "Unrealized conversion cast must have tensor descriptor types");
 
     // tensor_desc -> tensor_desc Type of conversions.
-    // This occurs iside scf.for body to resolve the block argument type to
+    // This occurs inside scf.for body to resolve the block argument type to
    // SIMT type.
     if (inputDescType.getLayout()) {
       auto argument = mlir::dyn_cast<BlockArgument>(input);
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 0214d84f2c16f..4cbe4db271ad6 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -181,6 +181,23 @@ func.func @vector_bitcast_i16_to_f16(%arg0: memref<8x16xi16>, %arg1: memref<16x1
   return
 }
 
+// -----
+// CHECK-LABEL: func.func @vector_bitcast_i32_to_f16(
+// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_result_0 = #xegpu.layout} : vector<16x8xi32> to vector<16x16xf16>
+func.func @vector_bitcast_i32_to_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x8xi32>, %arg2: memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x8xi32> -> !xegpu.tensor_desc<16x8xi32>
+  %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x8xi32> -> vector<16x8xi32>
+  %4 = vector.bitcast %3 : vector<16x8xi32> to vector<16x16xf16>
+  %5 = vector.transpose %4, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+  %6 = xegpu.dpas %2, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %7 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %6, %7 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
 // -----
 // CHECK-LABEL: func.func @binary_op_one_use(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout>,
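Note: the updated WarpOpBroadcast and WarpOpShapeCast constructors consume the same DistributionMapFn (a std::function<AffineMap(Value)>) that the other map-driven patterns passed to populatePropagateWarpVectorDistributionPatterns already receive. A minimal sketch of such a callback is shown below; the function name and the policy of distributing only the innermost vector dimension are illustrative assumptions, not part of this patch.

// Illustrative only: map the innermost dimension of vector values onto the
// warp lanes and leave non-vector values undistributed (empty map).
static AffineMap exampleDistributionMap(Value val) {
  auto vecType = dyn_cast<VectorType>(val.getType());
  if (!vecType)
    return AffineMap::get(val.getContext());
  int64_t rank = vecType.getRank();
  return AffineMap::get(rank, /*symbolCount=*/0,
                        getAffineDimExpr(rank - 1, val.getContext()));
}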