diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 045c192787f10..af90ed8f5deaf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,9 +15,12 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <utility>
 
@@ -52,6 +55,25 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
   return map;
 }
 
+/// Given a sequential and distributed vector type, returns the distributed
+/// dimension. This function expects that only a single dimension is
+/// distributed.
+static int getDistributedDim(VectorType sequentialType,
+                             VectorType distributedType) {
+  assert(sequentialType.getRank() == distributedType.getRank() &&
+         "sequential and distributed vector types must have the same rank");
+  int64_t distributedDim = -1;
+  for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
+    if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
+      // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
+      // support distributing multiple dimensions in the future.
+      assert(distributedDim == -1 && "found multiple distributed dims");
+      distributedDim = i;
+    }
+  }
+  return distributedDim;
+}
+
 namespace {
 
 /// Helper struct to create the load / store operations that permit transit
@@ -1076,6 +1098,196 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
   }
 };
 
+/// Sink out insert_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
+///   ...
+///   %src = ... : vector<4x32xf32>
+///   %dest = ... : vector<8x32xf32>
+///   %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
+///     strides = [1, 1] : vector<4x32xf32> into vector<8x32xf32>
+///   gpu.yield %insert : vector<8x32xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
+///   vector<8x1xf32>) {
+///   ...
+///   %src = ... : vector<4x32xf32>
+///   %dest = ... : vector<8x32xf32>
+///   gpu.yield %src, %dest : vector<4x32xf32>, vector<8x32xf32>
+/// }
+/// %insert = vector.insert_strided_slice %0#0, %0#1,
+///   offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
+/// ```
+/// NOTE: Current support assumes that both src and dest vectors are
+/// distributed to lanes and sinking the insert op does not require any cross
+/// lane communication.
+struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto insertOp =
+        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Distributed type must be 2D or higher.
+    // TODO: Support 1D distributed types.
+    if (distributedType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          insertOp, "result vector type must be 2D or higher");
+    // Find the distributed dimension of the dest vector. There should be
+    // exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    int64_t destDistributedDim =
+        getDistributedDim(yieldedType, distributedType);
+    assert(destDistributedDim != -1 && "could not find distributed dimension");
+
+    VectorType srcType = insertOp.getSourceVectorType();
+    VectorType destType = insertOp.getDestVectorType();
+    // Currently we require that both source (kD) and dest (nD) vectors are
+    // distributed. This requires that distributedDim (d) is contained in the
+    // last k dims of the dest vector (d >= n - k).
+    // TODO: Add support for the case where the source vector is not
+    // distributed.
+    int64_t sourceDistributedDim =
+        destDistributedDim - (destType.getRank() - srcType.getRank());
+    if (sourceDistributedDim < 0)
+      return rewriter.notifyMatchFailure(
+          insertOp,
+          "distributed dimension must be in the last k dims of dest vector");
+    // Distributed dimension must be fully inserted.
+    if (srcType.getDimSize(sourceDistributedDim) !=
+        destType.getDimSize(destDistributedDim))
+      return rewriter.notifyMatchFailure(
+          insertOp, "distributed dimension must be fully inserted");
+    SmallVector<int64_t> newSourceDistShape(
+        insertOp.getSourceVectorType().getShape());
+    newSourceDistShape[sourceDistributedDim] =
+        distributedType.getDimSize(destDistributedDim);
+    auto newSourceTy =
+        VectorType::get(newSourceDistShape, distributedType.getElementType());
+    VectorType newDestTy = distributedType;
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+        {newSourceTy, newDestTy}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
+    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
+    // Create a new insert strided slice op that inserts the distributed
+    // source into the distributed dest.
+    Value newInsert = rewriter.create<vector::InsertStridedSliceOp>(
+        insertOp.getLoc(), distributedDest.getType(), distributedSource,
+        distributedDest, insertOp.getOffsets(), insertOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
+    return success();
+  }
+};
+
+/// Sink out extract_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
+///   ...
+///   %src = ... : vector<64x32xf32>
+///   %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
+///     strides = [1] : vector<64x32xf32> to vector<16x32xf32>
+///   gpu.yield %extract : vector<16x32xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32>) {
+///   ...
+///   %src = ... : vector<64x32xf32>
+///   gpu.yield %src : vector<64x32xf32>
+/// }
+/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
+///   strides = [1] : vector<64x1xf32> to vector<16x1xf32>
+/// ```
+/// NOTE: Current support assumes that the extraction happens only on non
+/// distributed dimensions (does not require cross lane communication).
+struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto extractOp =
+        operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Distributed type must be 2D or higher.
+    // TODO: Support 1D distributed types.
+    if (distributedType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          extractOp, "result vector type must be 2D or higher");
+
+    // Find the distributed dimension. There should be exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
+    assert(distributedDim != -1 && "could not find distributed dimension");
+
+    int64_t numOfExtractedDims =
+        static_cast<int64_t>(extractOp.getSizes().size());
+    // If the distributed dim is included in the extracted dims, then we make
+    // sure distributed dim is fully extracted. If distributed dim is not
+    // included in extracted dims, it is guaranteed to be fully extracted
+    // (i.e. distributed dim comes after all the extracted dims).
+    // TODO: Partial extraction from the distributed dimension requires cross
+    // lane communication.
+    if (distributedDim < numOfExtractedDims) {
+      int64_t distributedDimOffset =
+          llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
+              .getInt();
+      int64_t distributedDimSize =
+          llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
+              .getInt();
+      if (distributedDimOffset != 0 ||
+          distributedDimSize != yieldedType.getDimSize(distributedDim))
+        return rewriter.notifyMatchFailure(
+            extractOp, "distributed dimension must be fully extracted");
+    }
+    SmallVector<int64_t> newDistributedShape(
+        extractOp.getSourceVectorType().getShape());
+    newDistributedShape[distributedDim] =
+        distributedType.getDimSize(distributedDim);
+    auto newDistributedType =
+        VectorType::get(newDistributedShape, distributedType.getElementType());
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
+        extractOp.getSizes(), [](Attribute attr) { return attr; });
+    // Update the distributed sizes to match the distributed type.
+    if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
+      distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+          distributedType.getDimSize(distributedDim));
+
+    // Create a new extract strided slice op that extracts from the
+    // distributed vector.
+    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+    Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
+        extractOp.getLoc(), distributedType, distributedVec,
+        extractOp.getOffsets(),
+        ArrayAttr::get(rewriter.getContext(), distributedSizes),
+        extractOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+                                newExtract);
+    return success();
+  }
+};
+
 /// Pattern to move out vector.extract of single element vector. Those don't
 /// need to be distributed and can just be propagated outside of the region.
 struct WarpOpExtract : public WarpDistributionPattern {
@@ -1122,15 +1334,7 @@ struct WarpOpExtract : public WarpDistributionPattern {
     auto distributedType =
         cast<VectorType>(warpOp.getResult(operandNumber).getType());
     auto yieldedType = cast<VectorType>(operand->get().getType());
-    int64_t distributedDim = -1;
-    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
-      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
-        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
-        // support distributing multiple dimensions in the future.
-        assert(distributedDim == -1 && "found multiple distributed dims");
-        distributedDim = i;
-      }
-    }
+    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
     assert(distributedDim != -1 && "could not find distributed dimension");
     (void)distributedDim;
 
@@ -1764,7 +1968,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
   patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
                WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
                WarpOpConstant,
-               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
+               WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
       patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 38771f2593449..7cfbcdf101d11 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1296,6 +1296,86 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
   return %r : vector<4x96xf32>
 }
 
+// -----
+// CHECK-PROP-LABEL: func.func @vector_extract_strided_slice_2d_distr_inner(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<64x1xf32>) {
+// CHECK-PROP: %[[VEC:.*]] = "some_def"() : () -> vector<64x32xf32>
+// CHECK-PROP: gpu.yield %[[VEC]] : vector<64x32xf32>
+// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract_strided_slice %[[W]]
+// CHECK-PROP-SAME: {offsets = [8], sizes = [24], strides = [1]} : vector<64x1xf32> to vector<24x1xf32>
+// CHECK-PROP: return %[[EXTRACT]] : vector<24x1xf32>
+func.func @vector_extract_strided_slice_2d_distr_inner(%laneid: index) -> (vector<24x1xf32>) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<24x1xf32>) {
+    %0 = "some_def"() : () -> (vector<64x32xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [8], sizes = [24], strides = [1]}
+      : vector<64x32xf32> to vector<24x32xf32>
+    gpu.yield %1 : vector<24x32xf32>
+  }
+  return %r : vector<24x1xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_extract_strided_slice_2d_distr_outer(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<1x64xf32>) {
+// CHECK-PROP: %[[VEC:.*]] = "some_def"() : () -> vector<32x64xf32>
+// CHECK-PROP: gpu.yield %[[VEC]] : vector<32x64xf32>
+// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract_strided_slice %[[W]]
+// CHECK-PROP-SAME: {offsets = [0, 12], sizes = [1, 8], strides = [1, 1]} : vector<1x64xf32> to vector<1x8xf32>
+// CHECK-PROP: return %[[EXTRACT]] : vector<1x8xf32>
+func.func @vector_extract_strided_slice_2d_distr_outer(%laneid: index) -> (vector<1x8xf32>) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x8xf32>) {
+    %0 = "some_def"() : () -> (vector<32x64xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [0, 12], sizes = [32, 8], strides = [1, 1]}
+      : vector<32x64xf32> to vector<32x8xf32>
+    gpu.yield %1 : vector<32x8xf32>
+  }
+  return %r : vector<1x8xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_insert_strided_slice_1d_to_2d(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index)
+// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}} -> (vector<1xf32>, vector<64x1xf32>) {
+// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<32xf32>
+// CHECK-PROP: %[[DEST:.*]] = "some_def"() : () -> vector<64x32xf32>
+// CHECK-PROP: gpu.yield %[[SRC]], %[[DEST]] : vector<32xf32>, vector<64x32xf32>
+// CHECK-PROP: %[[INSERT:.*]] = vector.insert_strided_slice %[[W]]#0, %[[W]]#1
+// CHECK-PROP-SAME: {offsets = [18, 0], strides = [1]} : vector<1xf32> into vector<64x1xf32>
+// CHECK-PROP: return %[[INSERT]] : vector<64x1xf32>
+func.func @vector_insert_strided_slice_1d_to_2d(%laneid: index) -> (vector<64x1xf32>) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
+    %0 = "some_def"() : () -> (vector<32xf32>)
+    %1 = "some_def"() : () -> (vector<64x32xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [18, 0], strides = [1]}
+      : vector<32xf32> into vector<64x32xf32>
+    gpu.yield %2 : vector<64x32xf32>
+  }
+  return %r : vector<64x1xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_insert_strided_slice_2d_to_2d(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index)
+// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> (vector<16x1xf32>, vector<64x1xf32>) {
+// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<16x32xf32>
+// CHECK-PROP: %[[DEST:.*]] = "some_def"() : () -> vector<64x32xf32>
+// CHECK-PROP: gpu.yield %[[SRC]], %[[DEST]] : vector<16x32xf32>, vector<64x32xf32>
+// CHECK-PROP: %[[INSERT:.*]] = vector.insert_strided_slice %[[W]]#0, %[[W]]#1 {offsets = [36, 0], strides = [1, 1]} :
+// CHECK-PROP-SAME: vector<16x1xf32> into vector<64x1xf32>
+// CHECK-PROP: return %[[INSERT]] : vector<64x1xf32>
+func.func @vector_insert_strided_slice_2d_to_2d(%laneid: index) -> (vector<64x1xf32>) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
+    %0 = "some_def"() : () -> (vector<16x32xf32>)
+    %1 = "some_def"() : () -> (vector<64x32xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [36, 0], strides = [1, 1]}
+      : vector<16x32xf32> into vector<64x32xf32>
+    gpu.yield %2 : vector<64x32xf32>
+  }
+  return %r : vector<64x1xf32>
+}
+
 // -----
 
 // Make sure that all operands of the transfer_read op are properly propagated.