@@ -225,12 +225,96 @@ static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc,
225225 return toVal;
226226}
227227
228+ // Permute lanes of the input val and apply reduction to permuted values.
229+ static Value permuteAndReduce (RewriterBase &rewriter, Location loc,
230+ StringRef intrinsic, Value val,
231+ Operation *reduxOp) {
232+ Type valType = val.getType ();
233+ assert (valType.getIntOrFloatBitWidth () <= 32 );
234+
235+ Type actualType = valType;
236+ if (!valType.isInteger (32 ))
237+ actualType = castToAndSExtInt (rewriter, loc, val, valType, 32 );
238+
239+ auto b = TritonLLVMOpBuilder (loc, rewriter);
240+ Value falseVal = b.false_val ();
241+ MLIRContext *ctx = rewriter.getContext ();
242+ Type retType = struct_ty ({i32_ty, i32_ty});
243+ Value perm =
244+ LLVM::createLLVMIntrinsicCallOp (rewriter, loc, intrinsic, retType,
245+ ValueRange{val, val, falseVal, falseVal})
246+ ->getResult (0 );
247+ Value v0 = b.extract_val (i32_ty, perm, 0 );
248+ Value v1 = b.extract_val (i32_ty, perm, 1 );
249+
250+ if (!valType.isInteger (32 )) {
251+ v0 = truncAndCastFromInt (rewriter, loc, v0, valType, 32 );
252+ v1 = truncAndCastFromInt (rewriter, loc, v1, valType, 32 );
253+ }
254+ IRMapping mapping;
255+ mapping.map (reduxOp->getOperand (0 ), v0);
256+ mapping.map (reduxOp->getOperand (1 ), v1);
257+ Value redx = rewriter.clone (*reduxOp, mapping)->getResult (0 );
258+ return redx;
259+ }
260+
261+ // Apply warp reduction across lanes using llvm intrinsics in GFX950.
262+ // The input acc has the partial accumulated values from reduction within
263+ // threads. The output acc has the final accumulated values.
264+ //
265+ // Two special cases are supported:
266+ // When numLaneToReduce == 2 && interleave == 32:
267+ // step 1: use permlane32_swap() to swap the row 2 and 3 of acc and
268+ // the row 0 and 1 of the copy of acc
269+ // step 2: apply reduction to the result values to get final result
270+ // When numLaneToReduce == 4 && interleave == 16:
271+ // step 1: use permlane32_swap() to swap the row 2 and 3 of acc and
272+ // the row 0 and 1 of the copy of acc
273+ // step 2: apply reduction to the result values to get the partial result
274+ // step 3: use permlane16_swap() to swap the odd and even rows of
275+ // the partial results
276+ // step 4: apply reduction to get the final results
277+ static bool warpReduceSwap16or32 (RewriterBase &rewriter, Location loc,
278+ SmallVector<Value> &acc, triton::ReduceOp op,
279+ unsigned numLaneToReduce,
280+ unsigned interleave) {
281+ Operation *reduxOp = op.getSingleCombiner ();
282+ if (!reduxOp)
283+ return false ;
284+
285+ bool mfma32Case = numLaneToReduce == 2 && interleave == 32 ;
286+ bool mfma16Case = numLaneToReduce == 4 && interleave == 16 ;
287+ if (!(mfma32Case || mfma16Case))
288+ return false ;
289+
290+ Value val = acc[0 ];
291+ unsigned bits = val.getType ().getIntOrFloatBitWidth ();
292+ if (bits > 32 )
293+ return false ;
294+
295+ StringRef intrinsic = " llvm.amdgcn.permlane32.swap" ;
296+ for (auto i = 0 ; i < acc.size (); i++) {
297+ Value redx = permuteAndReduce (rewriter, loc, intrinsic, acc[i], reduxOp);
298+
299+ if (mfma16Case) {
300+ intrinsic = " llvm.amdgcn.permlane16.swap" ;
301+ redx = permuteAndReduce (rewriter, loc, intrinsic, redx, reduxOp);
302+ }
303+
304+ acc[i] = redx;
305+ }
306+ return true ;
307+ }
308+
228309bool TargetInfo::warpReduce (RewriterBase &rewriter, Location loc,
229310 SmallVector<Value> &acc, triton::ReduceOp op,
230311 unsigned numLaneToReduce,
231312 unsigned interleave) const {
232313 auto b = TritonLLVMOpBuilder (loc, rewriter);
233314
315+ if (isCDNA () && getISAFamily () == ISAFamily::CDNA4 &&
316+ warpReduceSwap16or32 (rewriter, loc, acc, op, numLaneToReduce, interleave))
317+ return true ;
234318 if (numLaneToReduce != getWarpSize ())
235319 return false ;
236320 if (isCDNA () && getISAFamily () == ISAFamily::CDNA1)
0 commit comments