@@ -1102,26 +1102,26 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
11021102// / ```
11031103// / %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
11041104// / ...
1105- // / %src = ... : vector<4x16xf32 >
1106- // / %dest = ... : vector<8x16xf32 >
1105+ // / %src = ... : vector<4x32xf32 >
1106+ // / %dest = ... : vector<8x32xf32 >
11071107// / %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
1108- // / strides = [1, 1] : vector<4x16xf32 > into vector<8x16xf32 >
1109- // / gpu.yield %insert : vector<8x16xf32 >
1108+ // / strides = [1, 1] : vector<4x32xf32 > into vector<8x32xf32 >
1109+ // / gpu.yield %insert : vector<8x32xf32 >
11101110// / }
11111111// / ```
11121112// / To
11131113// / ```
11141114// / %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
11151115// / vector<8x1xf32>) {
11161116// / ...
1117- // / %src = ... : vector<4x16xf32 >
1118- // / %dest = ... : vector<8x16xf32 >
1117+ // / %src = ... : vector<4x32xf32 >
1118+ // / %dest = ... : vector<8x32xf32 >
11191119// / gpu.yield %src, %dest : vector<4x32xf32>, vector<8x32xf32>
11201120// / }
11211121// / %insert = vector.insert_strided_slice %0#0, %0#1,
11221122// / offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
11231123// / ```
1124- // / NOTE: Current support assume that both src and dest vectors are distributed
1124+ // / NOTE: Current support assumes that both src and dest vectors are distributed
11251125// / to lanes and sinking the insert op does not require any cross lane
11261126// / communication.
11271127struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
@@ -1159,7 +1159,8 @@ struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
11591159 destDistributedDim - (destType.getRank () - srcType.getRank ());
11601160 if (sourceDistributedDim < 0 )
11611161 return rewriter.notifyMatchFailure (
1162- insertOp, " distributed dimension must be in the last k dims" );
1162+ insertOp,
1163+ " distributed dimension must be in the last k dims of dest vector" );
11631164 // Distributed dimension must be fully inserted.
11641165 if (srcType.getDimSize (sourceDistributedDim) !=
11651166 destType.getDimSize (destDistributedDim))
@@ -1197,21 +1198,21 @@ struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
11971198// / ```
11981199// / %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
11991200// / ...
1200- // / %src = ... : vector<32x16xf32 >
1201+ // / %src = ... : vector<64x32xf32 >
12011202// / %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
1202- // / strides = [1] : vector<32x16xf32 > to vector<16x16xf32 >
1203- // / gpu.yield %extract : vector<16x16xf32 >
1203+ // / strides = [1] : vector<64x32xf32 > to vector<16x32xf32 >
1204+ // / gpu.yield %extract : vector<16x32xf32 >
12041205// / }
12051206// / ```
12061207// / To
1207- // / ````
1208- // / %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<32x1xf32 >) {
1208+ // / ```
1209+ // / %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32 >) {
12091210// / ...
1210- // / %src = ... : vector<32x16xf32 >
1211- // / gpu.yield %src : vector<32x16xf32 >
1211+ // / %src = ... : vector<64x32xf32 >
1212+ // / gpu.yield %src : vector<64x32xf32 >
12121213// / }
12131214// / %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
1214- // / strides = [1] : vector<32x1xf32 > to vector<16x1xf32>
1215+ // / strides = [1] : vector<64x1xf32 > to vector<16x1xf32>
12151216// / ```
12161217// / NOTE: Current support assumes that the extraction happens only on non
12171218// / distributed dimensions (does not require cross lane communication).
0 commit comments