Skip to content

Commit b087820

Browse files
committed
address comments
1 parent 3ae71df commit b087820

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,26 +1102,26 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
11021102
/// ```
11031103
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
11041104
/// ...
1105-
/// %src = ... : vector<4x16xf32>
1106-
/// %dest = ... : vector<8x16xf32>
1105+
/// %src = ... : vector<4x32xf32>
1106+
/// %dest = ... : vector<8x32xf32>
11071107
/// %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
1108-
/// strides = [1, 1] : vector<4x16xf32> into vector<8x16xf32>
1109-
/// gpu.yield %insert : vector<8x16xf32>
1108+
/// strides = [1, 1] : vector<4x32xf32> into vector<8x32xf32>
1109+
/// gpu.yield %insert : vector<8x32xf32>
11101110
/// }
11111111
/// ```
11121112
/// To
11131113
/// ```
11141114
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
11151115
/// vector<8x1xf32>) {
11161116
/// ...
1117-
/// %src = ... : vector<4x16xf32>
1118-
/// %dest = ... : vector<8x16xf32>
1117+
/// %src = ... : vector<4x32xf32>
1118+
/// %dest = ... : vector<8x32xf32>
11191119
///   gpu.yield %src, %dest : vector<4x32xf32>, vector<8x32xf32>
11201120
/// }
11211121
/// %insert = vector.insert_strided_slice %0#0, %0#1,
11221122
/// offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
11231123
/// ```
1124-
/// NOTE: Current support assume that both src and dest vectors are distributed
1124+
/// NOTE: Current support assumes that both src and dest vectors are distributed
11251125
/// to lanes and sinking the insert op does not require any cross lane
11261126
/// communication.
11271127
struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
@@ -1159,7 +1159,8 @@ struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
11591159
destDistributedDim - (destType.getRank() - srcType.getRank());
11601160
if (sourceDistributedDim < 0)
11611161
return rewriter.notifyMatchFailure(
1162-
insertOp, "distributed dimension must be in the last k dims");
1162+
insertOp,
1163+
"distributed dimension must be in the last k dims of dest vector");
11631164
// Distributed dimension must be fully inserted.
11641165
if (srcType.getDimSize(sourceDistributedDim) !=
11651166
destType.getDimSize(destDistributedDim))
@@ -1197,21 +1198,21 @@ struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
11971198
/// ```
11981199
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
11991200
/// ...
1200-
/// %src = ... : vector<32x16xf32>
1201+
/// %src = ... : vector<64x32xf32>
12011202
/// %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
1202-
/// strides = [1] : vector<32x16xf32> to vector<16x16xf32>
1203-
/// gpu.yield %extract : vector<16x16xf32>
1203+
/// strides = [1] : vector<64x32xf32> to vector<16x32xf32>
1204+
/// gpu.yield %extract : vector<16x32xf32>
12041205
/// }
12051206
/// ```
12061207
/// To
1207-
/// ````
1208-
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<32x1xf32>) {
1208+
/// ```
1209+
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32>) {
12091210
/// ...
1210-
/// %src = ... : vector<32x16xf32>
1211-
/// gpu.yield %src : vector<32x16xf32>
1211+
/// %src = ... : vector<64x32xf32>
1212+
/// gpu.yield %src : vector<64x32xf32>
12121213
/// }
12131214
/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
1214-
/// strides = [1] : vector<32x1xf32> to vector<16x1xf32>
1215+
/// strides = [1] : vector<64x1xf32> to vector<16x1xf32>
12151216
/// ```
12161217
/// NOTE: Current support assumes that the extraction happens only on non
12171218
/// distributed dimensions (does not require cross lane communication).

0 commit comments

Comments
 (0)