Skip to content

Commit f2e6ca8

Browse files
authored
[MLIR][Vector] Add warp distribution for vector.step op (llvm#155425)
This PR adds a distribution pattern for [`vector.step`](https://mlir.llvm.org/docs/Dialects/Vector/#vectorstep-vectorstepop) op. The result of the step op is a vector containing a sequence `[0,1,...,N-1]`. For the warp distribution, we consider a vector with `N == warp_size` (think SIMD). Distributing it to SIMT means that each lane is represented by a thread/lane id scalar. More complex cases with the support for warp size multiples (e.g., `[0,1,...,2*N-1]`) require additional layout information to be handled properly. Such support may be added later. The lane id scalar is wrapped into a `vector<1xindex>` to emulate the sequence distribution result. Other than that, the distribution is similar to that of `arith.constant`.
1 parent 492089e commit f2e6ca8

File tree

2 files changed

+79
-1
lines changed

2 files changed

+79
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,52 @@ struct WarpOpConstant : public WarpDistributionPattern {
705705
}
706706
};
707707

708+
/// Sink out step op feeding into a warp op yield.
709+
/// Vector step op is treated similar to arith.constant, apart from
710+
/// the result that represents a sequence [0, vec_size).
711+
/// Due to the to vec_size == warp_size limitation,
712+
/// we can simply wrap the lane id into a vector (i.e., broadcast).
713+
/// Supporting vec_size != warp_size may involve preserving the step
714+
/// result and using additional arith ops (the exact details are TBD).
715+
/// ```
716+
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
717+
/// ...
718+
/// %cst = vector.step : vector<32xindex>
719+
/// gpu.yield %cst : vector<1xindex>
720+
/// }
721+
/// ```
722+
/// To
723+
/// ```
724+
/// gpu.warp_execute_on_lane_0(%arg0) {
725+
/// ...
726+
/// }
727+
/// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex>
728+
struct WarpOpStep final : public WarpDistributionPattern {
729+
using Base::Base;
730+
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
731+
PatternRewriter &rewriter) const override {
732+
OpOperand *yieldOperand =
733+
getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
734+
if (!yieldOperand)
735+
return failure();
736+
const unsigned operandIdx = yieldOperand->getOperandNumber();
737+
auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
738+
VectorType resTy = stepOp.getResult().getType();
739+
if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize()))
740+
return rewriter.notifyMatchFailure(
741+
warpOp,
742+
llvm::formatv("Expected result size ({0}) to be of warp size ({1})",
743+
resTy.getNumElements(), warpOp.getWarpSize()));
744+
VectorType newVecTy =
745+
cast<VectorType>(warpOp.getResult(operandIdx).getType());
746+
rewriter.setInsertionPointAfter(warpOp);
747+
Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
748+
newVecTy, warpOp.getLaneid());
749+
rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
750+
return success();
751+
}
752+
};
753+
708754
/// Sink out transfer_read op feeding into a warp op yield.
709755
/// ```
710756
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
@@ -2016,7 +2062,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
20162062
.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
20172063
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
20182064
WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
2019-
WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
2065+
WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
20202066
patterns.getContext(), benefit);
20212067
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
20222068
benefit);

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1824,3 +1824,35 @@ func.func @warp_propagate_duplicated_operands_in_yield(%laneid: index) {
18241824
// CHECK-PROP : }
18251825
// CHECK-PROP : %[T1:.*] = math.exp %[[W]] : vector<1xf32>
18261826
// CHECK-PROP : "some_use"(%[[T1]]) : (vector<1xf32>) -> ()
1827+
1828+
// -----
1829+
1830+
// Positive case: a warp-sized vector.step (32 elements, warp size 32) is
// distributed to SIMT form — each lane's slice of the sequence [0, 32) is its
// own lane id, materialized as a vector<1xindex> broadcast of gpu.lane_id.
func.func @warp_step_distribute(%buffer: memref<128xindex>) {
  %laneid = gpu.lane_id
  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xindex>) {
    %seq = vector.step : vector<32xindex>
    gpu.yield %seq : vector<32xindex>
  }
  vector.transfer_write %r, %buffer[%laneid] : vector<1xindex>, memref<128xindex>
  return
}

// CHECK-PROP-LABEL: func.func @warp_step_distribute(
// CHECK-PROP: %[[LANE_ID:.*]] = gpu.lane_id
// CHECK-PROP: %[[LANE_ID_VEC:.*]] = vector.broadcast %[[LANE_ID]] : index to vector<1xindex>
// CHECK-PROP: vector.transfer_write %[[LANE_ID_VEC]], %{{.*}} : vector<1xindex>, memref<128xindex>
1844+
1845+
// -----
1846+
1847+
// Negative case: the step result (64 elements) exceeds the warp size (32), so
// the pattern must not apply — vector.step stays inside the warp op and no
// lane-id broadcast is emitted.
func.func @negative_warp_step_more_than_warp_size(%laneid: index, %buffer: memref<128xindex>) {
  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xindex>) {
    %seq = vector.step : vector<64xindex>
    gpu.yield %seq : vector<64xindex>
  }
  vector.transfer_write %r, %buffer[%laneid] : vector<2xindex>, memref<128xindex>
  return
}

// CHECK-PROP-LABEL: @negative_warp_step_more_than_warp_size
// CHECK-PROP-NOT: vector.broadcast
// CHECK-PROP: vector.step : vector<64xindex>

0 commit comments

Comments
 (0)