@@ -224,12 +224,96 @@ static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc,
224224 return toVal;
225225}
226226
227+ // Permute lanes of the input val and apply reduction to permuted values.
228+ static Value permuteAndReduce (RewriterBase &rewriter, Location loc,
229+ StringRef intrinsic, Value val,
230+ Operation *reduxOp) {
231+ Type valType = val.getType ();
232+ assert (valType.getIntOrFloatBitWidth () <= 32 );
233+
234+ Type actualType = valType;
235+ if (!valType.isInteger (32 ))
236+ actualType = castToAndSExtInt (rewriter, loc, val, valType, 32 );
237+
238+ auto b = TritonLLVMOpBuilder (loc, rewriter);
239+ Value falseVal = b.false_val ();
240+ MLIRContext *ctx = rewriter.getContext ();
241+ Type retType = struct_ty ({i32_ty, i32_ty});
242+ Value perm =
243+ LLVM::createLLVMIntrinsicCallOp (rewriter, loc, intrinsic, retType,
244+ ValueRange{val, val, falseVal, falseVal})
245+ ->getResult (0 );
246+ Value v0 = b.extract_val (i32_ty, perm, 0 );
247+ Value v1 = b.extract_val (i32_ty, perm, 1 );
248+
249+ if (!valType.isInteger (32 )) {
250+ v0 = truncAndCastFromInt (rewriter, loc, v0, valType, 32 );
251+ v1 = truncAndCastFromInt (rewriter, loc, v1, valType, 32 );
252+ }
253+ IRMapping mapping;
254+ mapping.map (reduxOp->getOperand (0 ), v0);
255+ mapping.map (reduxOp->getOperand (1 ), v1);
256+ Value redx = rewriter.clone (*reduxOp, mapping)->getResult (0 );
257+ return redx;
258+ }
259+
260+ // Apply warp reduction across lanes using llvm intrinsics in GFX950.
261+ // The input acc has the partial accumulated values from reduction within
262+ // threads. The output acc has the final accumulated values.
263+ //
264+ // Two special cases are supported:
265+ // When numLaneToReduce == 2 && interleave == 32:
266+ // step 1: use permlane32_swap() to swap the row 2 and 3 of acc and
267+ // the row 0 and 1 of the copy of acc
268+ // step 2: apply reduction to the result values to get final result
269+ // When numLaneToReduce == 4 && interleave == 16:
270+ // step 1: use permlane32_swap() to swap the row 2 and 3 of acc and
271+ // the row 0 and 1 of the copy of acc
272+ // step 2: apply reduction to the result values to get the partial result
273+ // step 3: use permlane16_swap() to swap the odd and even rows of
274+ // the partial results
275+ // step 4: apply reduction to get the final results
276+ static bool warpReduceSwap16or32 (RewriterBase &rewriter, Location loc,
277+ SmallVector<Value> &acc, triton::ReduceOp op,
278+ unsigned numLaneToReduce,
279+ unsigned interleave) {
280+ Operation *reduxOp = op.getSingleCombiner ();
281+ if (!reduxOp)
282+ return false ;
283+
284+ bool mfma32Case = numLaneToReduce == 2 && interleave == 32 ;
285+ bool mfma16Case = numLaneToReduce == 4 && interleave == 16 ;
286+ if (!(mfma32Case || mfma16Case))
287+ return false ;
288+
289+ Value val = acc[0 ];
290+ unsigned bits = val.getType ().getIntOrFloatBitWidth ();
291+ if (bits > 32 )
292+ return false ;
293+
294+ StringRef intrinsic = " llvm.amdgcn.permlane32.swap" ;
295+ for (auto i = 0 ; i < acc.size (); i++) {
296+ Value redx = permuteAndReduce (rewriter, loc, intrinsic, acc[i], reduxOp);
297+
298+ if (mfma16Case) {
299+ intrinsic = " llvm.amdgcn.permlane16.swap" ;
300+ redx = permuteAndReduce (rewriter, loc, intrinsic, redx, reduxOp);
301+ }
302+
303+ acc[i] = redx;
304+ }
305+ return true ;
306+ }
307+
227308bool TargetInfo::warpReduce (RewriterBase &rewriter, Location loc,
228309 SmallVector<Value> &acc, triton::ReduceOp op,
229310 unsigned numLaneToReduce,
230311 unsigned interleave) const {
231312 auto b = TritonLLVMOpBuilder (loc, rewriter);
232313
314+ if (isCDNA () && getISAFamily () == ISAFamily::CDNA4 &&
315+ warpReduceSwap16or32 (rewriter, loc, acc, op, numLaneToReduce, interleave))
316+ return true ;
233317 if (numLaneToReduce != getWarpSize ())
234318 return false ;
235319 if (isCDNA () && getISAFamily () == ISAFamily::CDNA1)
0 commit comments