Skip to content

Commit ba6721e

Browse files
authored
improve reduction logic (#840)
1 parent 4dcf6ee commit ba6721e

File tree

2 files changed

+76
-11
lines changed

2 files changed

+76
-11
lines changed

lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp

Lines changed: 35 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -169,7 +169,7 @@ llvm::SmallVector<mlir::Value> lowerInnerReductionWithIntraVectorShuffles(
169169

170170
// Stage 1: vector<ixjx1xnxf16> equals to a grid of ixj of vector<1xnxf16>
171171
// after lowering to xegpu. This stage performs j-1 reduction operations on
172-
// j dim of the grid, the result is a vector of vector<mxnxf16> with size i.
172+
// j dim of the grid, the result is a vector of vector<ixnxf16>.
173173
llvm::SmallVector<mlir::Value> intermediates(shape[0]);
174174
for (auto i = 0; i < shape[0]; i++) {
175175
auto combiningVal = sources[i * shape[1]];
@@ -191,14 +191,13 @@ llvm::SmallVector<mlir::Value> lowerInnerReductionWithIntraVectorShuffles(
191191
// v2 = [b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 ... b31]
192192
// ...
193193
// vn = [p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 ... p31]
194-
// it will repeately doing shuffle between two consecutive vectors
195-
// v1 and v2, v3 and v4, ..., vn-1 and vn with a block size. Such
196-
// that we can get two new vectors. The block size is typically
197-
// starts with half of the vector size. For example, for v1 and v2,
198-
// it is 16, and we can get:
194+
// To reduce it, we repeatedly shuffle halves of two consecutive vectors.
195+
// One can view it as: transpose halves of two partial aggregates, reduce
196+
// vertically, get 1 vector with reduced halves of two vectors. For example,
197+
// for v1 and v2, we get:
199198
// nv1 = [a0, .., a15, b0, .., b15]
200199
// nv2 = [a16, .., a31, b16, .., b31]
201-
// and we then performs nv1 + nv2 (if reduction op is add)
200+
// nv_reduced = reductionOp(nv1,nv2)
202201
// such that the left half of the vector contains the partial reduction
203202
// of v1, and the right half contains the partial reduction of v2.
204203
// and the number of vectors is reduced by half after one iteration.
@@ -207,12 +206,16 @@ llvm::SmallVector<mlir::Value> lowerInnerReductionWithIntraVectorShuffles(
207206
// The intermediate result of this stage is an array of vectors with
208207
// type, e.g., vector<nxf16>, array size is `i/n`. And these vectors
209208
// will be merged into a single vector with type vector<ixf16>.
210-
auto blkSize = shape[3] / 2;
211-
while (blkSize) {
209+
210+
// each row should not have > 1 partial aggregate at the end
211+
auto partialRowAggSize{shape[3]};
212+
auto numVecsLeft{shape[0]};
213+
while (partialRowAggSize != 1 && numVecsLeft != 1) {
214+
partialRowAggSize /= 2;
212215
auto workList = intermediates;
213216
intermediates.clear();
214217
assert(workList.size() % 2 == 0 && "The size should be divisible by 2.");
215-
auto masks = genShuffleMasks(blkSize, shape[3]);
218+
auto masks = genShuffleMasks(partialRowAggSize, shape[3]);
216219
for (size_t i = 0; i < workList.size(); i += 2) {
217220
auto v1 = workList[i];
218221
auto v2 = workList[i + 1];
@@ -224,7 +227,28 @@ llvm::SmallVector<mlir::Value> lowerInnerReductionWithIntraVectorShuffles(
224227
createBinOp(kind, shuffleOp1, shuffleOp2, elemTy, loc, rewriter);
225228
intermediates.push_back(reductionVal);
226229
}
227-
blkSize /= 2;
230+
numVecsLeft /= 2;
231+
}
232+
233+
if (partialRowAggSize > 1) {
234+
assert(intermediates.size() == 1 &&
235+
"We must have ONE row with non-finalized aggregates.");
236+
auto toFinalize = intermediates.back();
237+
intermediates.clear();
238+
uint32_t currentAggVecSize = shape[3];
239+
do {
240+
currentAggVecSize /= 2;
241+
partialRowAggSize /= 2;
242+
auto [vecUpperMask, vecLowerMask] =
243+
genShuffleMasks(partialRowAggSize, currentAggVecSize);
244+
auto shuffleOp1 = rewriter.create<mlir::vector::ShuffleOp>(
245+
loc, toFinalize, toFinalize, vecUpperMask);
246+
auto shuffleOp2 = rewriter.create<mlir::vector::ShuffleOp>(
247+
loc, toFinalize, toFinalize, vecLowerMask);
248+
toFinalize =
249+
createBinOp(kind, shuffleOp1, shuffleOp2, elemTy, loc, rewriter);
250+
} while (partialRowAggSize != 1);
251+
intermediates.push_back(toFinalize);
228252
}
229253
return intermediates;
230254
}

test/Conversion/XeTileToXeGPU/reduction.mlir

Lines changed: 41 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -186,6 +186,47 @@ module {
186186
gpu.return
187187
}
188188

189+
gpu.func @inner_reduction_1(%a: memref<8x32xf32>, %b: memref<8x1xf32>) {
190+
%c0 = arith.constant 0 : index
191+
%neg_inf = arith.constant dense<0xFF800000> : vector<8xf32> // -inf
192+
193+
%a_tile = xetile.init_tile %a[%c0, %c0] : memref<8x32xf32> -> !xetile.tile<8x32xf32>
194+
%b_tile = xetile.init_tile %b[%c0, %c0] : memref<8x1xf32> -> !xetile.tile<8x1xf32>
195+
196+
//CHECK: xegpu.load_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true>> -> vector<8x16xf32>
197+
//CHECK: xegpu.load_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true>> -> vector<8x16xf32>
198+
%a_loaded = xetile.load_tile %a_tile: !xetile.tile<8x32xf32> -> vector<8x32xf32>
199+
200+
//CHECK: %[[R1:.*]] = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
201+
//CHECK: %[[R2:.*]] = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
202+
//CHECK: %[[R3:.*]] = arith.maximumf %[[R1]], %[[R2]] : vector<16xf32>
203+
//CHECK: %[[R4:.*]] = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
204+
//CHECK: %[[R5:.*]] = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
205+
//CHECK: %[[R6:.*]] = arith.maximumf %[[R4]], %[[R5]] : vector<16xf32>
206+
//CHECK: %[[R7:.*]] = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
207+
//CHECK: %[[R8:.*]] = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
208+
//CHECK: %[[R9:.*]] = arith.maximumf %[[R7]], %[[R8]] : vector<16xf32>
209+
//CHECK: %[[R10:.*]] = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
210+
//CHECK: %[[R11:.*]] = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
211+
//CHECK: %[[R12:.*]] = arith.maximumf %[[R10]], %[[R11]] : vector<16xf32>
212+
//CHECK: %[[R13:.*]] = vector.shuffle %[[R3]], %[[R6]] [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
213+
//CHECK: %[[R14:.*]] = vector.shuffle %[[R3]], %[[R6]] [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
214+
//CHECK: %[[R15:.*]] = arith.maximumf %[[R13]], %[[R14]] : vector<16xf32>
215+
//CHECK: %[[R16:.*]] = vector.shuffle %[[R9]], %[[R12]] [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
216+
//CHECK: %[[R17:.*]] = vector.shuffle %[[R9]], %[[R12]] [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
217+
//CHECK: %[[R18:.*]] = arith.maximumf %[[R16]], %[[R17]] : vector<16xf32>
218+
//CHECK: %[[R19:.*]] = vector.shuffle %[[R15]], %[[R18]] [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29] : vector<16xf32>, vector<16xf32>
219+
//CHECK: %[[R20:.*]] = vector.shuffle %[[R15]], %[[R18]] [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 26, 27, 30, 31] : vector<16xf32>, vector<16xf32>
220+
//CHECK: %[[R21:.*]] = arith.maximumf %[[R19]], %[[R20]] : vector<16xf32>
221+
//CHECK: %[[R22:.*]] = vector.shuffle %[[R21]], %[[R21]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xf32>, vector<16xf32>
222+
//CHECK: %[[R23:.*]] = vector.shuffle %[[R21]], %[[R21]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xf32>, vector<16xf32>
223+
//CHECK: %[[R24:.*]] = arith.maximumf %[[R22]], %[[R23]] : vector<8xf32>
224+
%3 = vector.multi_reduction <maximumf>, %a_loaded, %neg_inf [1] : vector<8x32xf32> to vector<8xf32> // fastmath<nnan> is implicit here
225+
%reduced = vector.shape_cast %3 : vector<8xf32> to vector<8x1xf32>
226+
xetile.store_tile %reduced, %b_tile : vector<8x1xf32>, !xetile.tile<8x1xf32>
227+
gpu.return
228+
}
229+
189230
//CHECK: gpu.func @outter_reduction(%[[arg0:.*]]: memref<128x256xf16>, %[[arg1:.*]]: memref<128x256xf16>) {
190231
gpu.func @outter_reduction(%a: memref<128x256xf16>, %b: memref<128x256xf16>) {
191232
//CHECK: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<32xf16>

0 commit comments

Comments
 (0)