@@ -65,6 +65,35 @@ llvm::AMDGPU::GPUKind TargetInfo::getGPUKind() const {
6565 return llvm::AMDGPU::parseArchAMDGCN (arch);
6666}
6767
68+ bool TargetInfo::isCDNA () const {
69+ switch (getISAFamily ()) {
70+ case ISAFamily::CDNA1:
71+ case ISAFamily::CDNA2:
72+ case ISAFamily::CDNA3:
73+ case ISAFamily::CDNA4:
74+ return true ;
75+ default :
76+ break ;
77+ }
78+
79+ return false ;
80+ }
81+
82+ bool TargetInfo::isRDNA () const {
83+ switch (getISAFamily ()) {
84+ case ISAFamily::RDNA1:
85+ case ISAFamily::RDNA2:
86+ case ISAFamily::RDNA3:
87+ return true ;
88+ default :
89+ break ;
90+ }
91+
92+ return false ;
93+ }
94+
95+ int TargetInfo::getWarpSize () const { return isCDNA () ? 64 : 32 ; }
96+
6897int TargetInfo::getSharedMemorySize () const {
6998 int kbytes = getISAFamily () == ISAFamily::CDNA4 ? 160 : 64 ;
7099 return kbytes * 1024 ;
@@ -200,14 +229,13 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
200229 unsigned numLaneToReduce,
201230 unsigned interleave) const {
202231 auto b = TritonLLVMOpBuilder (loc, rewriter);
203- if (numLaneToReduce != 64 )
204- return false ;
205232
206- if (!llvm::is_contained (
207- {ISAFamily::CDNA2, ISAFamily::CDNA3, ISAFamily::CDNA4},
208- getISAFamily ())) {
233+ if (numLaneToReduce != getWarpSize ())
234+ return false ;
235+ if (isCDNA () && getISAFamily () == ISAFamily::CDNA1)
236+ return false ;
237+ if (isRDNA () && getISAFamily () != ISAFamily::RDNA3)
209238 return false ;
210- }
211239
212240 Operation *reduxOp = op.getSingleCombiner ();
213241 if (!reduxOp)
@@ -307,24 +335,43 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
307335 buf = createDppReduxOpWithBoundCtrl (valType, buf, 1 + dppCtrlRowShr,
308336 allRows, allBanks);
309337
310- // row_bcast:15 row_mask:0xa
311- buf = createDppReduxOpWithBoundCtrl (
312- valType, buf, static_cast <uint32_t >(DppCtrl::BCAST15), 0xa , allBanks);
338+ if (isCDNA ()) {
339+ // row_bcast:15 row_mask:0xa
340+ buf = createDppReduxOpWithBoundCtrl (
341+ valType, buf, static_cast <uint32_t >(DppCtrl::BCAST15), 0xa , allBanks);
313342
314- // row_bcast:31
315- buf = createDppReduxOpWithBoundCtrl (valType, buf,
316- static_cast <uint32_t >(DppCtrl::BCAST31),
317- allRows, allBanks);
343+ // row_bcast:31
344+ buf = createDppReduxOpWithBoundCtrl (
345+ valType, buf, static_cast <uint32_t >(DppCtrl::BCAST31), allRows,
346+ allBanks);
347+ } else {
348+ // RDNA doesn't have broadcast dpp mode
349+ Type actualType = castToAndSExtInt (rewriter, loc, buf, valType, 32 );
350+
351+ Value permlaneResult =
352+ LLVM::createLLVMIntrinsicCallOp (
353+ rewriter, loc, " llvm.amdgcn.permlanex16" , actualType,
354+ ValueRange{buf, buf, b.i32_val (-1 ), b.i32_val (-1 ), b.true_val (),
355+ b.false_val ()})
356+ ->getResult (0 );
357+ buf = truncAndCastFromInt (rewriter, loc, buf, valType, 32 );
358+ permlaneResult =
359+ truncAndCastFromInt (rewriter, loc, permlaneResult, valType, 32 );
360+ IRMapping mapping;
361+ mapping.map (reduxOp->getOperand (0 ), buf);
362+ mapping.map (reduxOp->getOperand (1 ), permlaneResult);
363+ buf = rewriter.clone (*reduxOp, mapping)->getResult (0 );
364+ }
318365
319366 // Similarly, we need to cast data types for readlane instruction.
320367 Type actualType = castToAndSExtInt (rewriter, loc, buf, valType, 16 );
321368
322- // Get reduction result from lane 63
369+ // Get reduction result from lane 63/31
323370 std::string intrinsic = " llvm.amdgcn.readlane" ;
324- Value result =
325- LLVM::createLLVMIntrinsicCallOp ( rewriter, loc, intrinsic, actualType,
326- ValueRange{buf, b.i32_val (63 )})
327- ->getResult (0 );
371+ Value result = LLVM::createLLVMIntrinsicCallOp (
372+ rewriter, loc, intrinsic, actualType,
373+ ValueRange{buf, b.i32_val (isCDNA () ? 63 : 31 )})
374+ ->getResult (0 );
328375
329376 result = truncAndCastFromInt (rewriter, loc, result, valType, 16 );
330377
0 commit comments