@@ -225,12 +225,96 @@ static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc,
225
225
return toVal;
226
226
}
227
227
228
+ // Permute lanes of the input val and apply reduction to permuted values.
229
+ static Value permuteAndReduce (RewriterBase &rewriter, Location loc,
230
+ StringRef intrinsic, Value val,
231
+ Operation *reduxOp) {
232
+ Type valType = val.getType ();
233
+ assert (valType.getIntOrFloatBitWidth () <= 32 );
234
+
235
+ Type actualType = valType;
236
+ if (!valType.isInteger (32 ))
237
+ actualType = castToAndSExtInt (rewriter, loc, val, valType, 32 );
238
+
239
+ auto b = TritonLLVMOpBuilder (loc, rewriter);
240
+ Value falseVal = b.false_val ();
241
+ MLIRContext *ctx = rewriter.getContext ();
242
+ Type retType = struct_ty ({i32_ty, i32_ty});
243
+ Value perm =
244
+ LLVM::createLLVMIntrinsicCallOp (rewriter, loc, intrinsic, retType,
245
+ ValueRange{val, val, falseVal, falseVal})
246
+ ->getResult (0 );
247
+ Value v0 = b.extract_val (i32_ty, perm, 0 );
248
+ Value v1 = b.extract_val (i32_ty, perm, 1 );
249
+
250
+ if (!valType.isInteger (32 )) {
251
+ v0 = truncAndCastFromInt (rewriter, loc, v0, valType, 32 );
252
+ v1 = truncAndCastFromInt (rewriter, loc, v1, valType, 32 );
253
+ }
254
+ IRMapping mapping;
255
+ mapping.map (reduxOp->getOperand (0 ), v0);
256
+ mapping.map (reduxOp->getOperand (1 ), v1);
257
+ Value redx = rewriter.clone (*reduxOp, mapping)->getResult (0 );
258
+ return redx;
259
+ }
260
+
261
+ // Apply warp reduction across lanes using llvm intrinsics in GFX950.
262
+ // The input acc has the partial accumulated values from reduction within
263
+ // threads. The output acc has the final accumulated values.
264
+ //
265
+ // Two special cases are supported:
266
+ // When numLaneToReduce == 2 && interleave == 32:
267
+ // step 1: use permlane32_swap() to swap the row 2 and 3 of acc and
268
+ // the row 0 and 1 of the copy of acc
269
+ // step 2: apply reduction to the result values to get final result
270
+ // When numLaneToReduce == 4 && interleave == 16:
271
+ // step 1: use permlane32_swap() to swap the row 2 and 3 of acc and
272
+ // the row 0 and 1 of the copy of acc
273
+ // step 2: apply reduction to the result values to get the partial result
274
+ // step 3: use permlane16_swap() to swap the odd and even rows of
275
+ // the partial results
276
+ // step 4: apply reduction to get the final results
277
+ static bool warpReduceSwap16or32 (RewriterBase &rewriter, Location loc,
278
+ SmallVector<Value> &acc, triton::ReduceOp op,
279
+ unsigned numLaneToReduce,
280
+ unsigned interleave) {
281
+ Operation *reduxOp = op.getSingleCombiner ();
282
+ if (!reduxOp)
283
+ return false ;
284
+
285
+ bool mfma32Case = numLaneToReduce == 2 && interleave == 32 ;
286
+ bool mfma16Case = numLaneToReduce == 4 && interleave == 16 ;
287
+ if (!(mfma32Case || mfma16Case))
288
+ return false ;
289
+
290
+ Value val = acc[0 ];
291
+ unsigned bits = val.getType ().getIntOrFloatBitWidth ();
292
+ if (bits > 32 )
293
+ return false ;
294
+
295
+ StringRef intrinsic = " llvm.amdgcn.permlane32.swap" ;
296
+ for (auto i = 0 ; i < acc.size (); i++) {
297
+ Value redx = permuteAndReduce (rewriter, loc, intrinsic, acc[i], reduxOp);
298
+
299
+ if (mfma16Case) {
300
+ intrinsic = " llvm.amdgcn.permlane16.swap" ;
301
+ redx = permuteAndReduce (rewriter, loc, intrinsic, redx, reduxOp);
302
+ }
303
+
304
+ acc[i] = redx;
305
+ }
306
+ return true ;
307
+ }
308
+
228
309
bool TargetInfo::warpReduce (RewriterBase &rewriter, Location loc,
229
310
SmallVector<Value> &acc, triton::ReduceOp op,
230
311
unsigned numLaneToReduce,
231
312
unsigned interleave) const {
232
313
auto b = TritonLLVMOpBuilder (loc, rewriter);
233
314
315
+ if (isCDNA () && getISAFamily () == ISAFamily::CDNA4 &&
316
+ warpReduceSwap16or32 (rewriter, loc, acc, op, numLaneToReduce, interleave))
317
+ return true ;
234
318
if (numLaneToReduce != getWarpSize ())
235
319
return false ;
236
320
if (isCDNA () && getISAFamily () == ISAFamily::CDNA1)
0 commit comments