@@ -482,7 +482,7 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
482482 void lowerDirectToLDSLoad (
483483 RewriterBase &rewriter, Location loc, RankedTensorType srcTy,
484484 MemDescType dstTy, SmallVector<Value> loadVals, Value llDst,
485- Type resElemTy, unsigned vec,
485+ Type resElemTy, unsigned vec, triton::AMD::ISAFamily isaFamily,
486486 std::function<SmallVector<Value>(RewriterBase &, Location,
487487 ArrayRef<Value>, Value, int , VectorType)>
488488 lowerInst) const {
@@ -511,7 +511,40 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
511511 LLVM::getSharedMemoryObjectFromStruct (loc, llDst, resElemTy, rewriter);
512512 auto affineOffset = smemObj.getShmemOffset (loc, rewriter, dstTy);
513513 auto maskSpanAffineOffset = SharedMemoryObject::getMaskSpanOffsets (dstTy);
514- auto [laneId, warpId] = getLaneAndWarpId (rewriter, loc);
514+
515+ Value laneId, warpId;
516+ if (ISAFamily::CDNA3 == isaFamily || ISAFamily::CDNA4 == isaFamily) {
517+ // On GFX9, there is no dedicated hardware instruction to read `wave_id`.
518+ // The value is instead computed from `workitem.id.x`. Per the GFX9 ABI,
519+ // `workitem.id.x` is initialized in a vector register, and vector
520+ // instructions are generated for IR operations that depend on `wave_id`.
521+ //
522+ // A `v_readfirstlane` instruction is inserted at the end of these vector
523+ // sequences to transfer the value from a vector register to a scalar
524+ // register, initializing `$m0`.
525+
526+ // When this sequence occurs inside a loop, the MachineLICM pass does not
527+ // hoist it because `v_readfirstlane` is convergent. Since both
528+ // `workitem.id.x` and `wave_id` are constant at runtime, their
529+ // computation can be safely hoisted to the function entry block.
530+ auto insertPt = rewriter.saveInsertionPoint ();
531+ Operation *parentOp = insertPt.getBlock ()->getParentOp ();
532+ while (!isa<LLVM::LLVMFuncOp>(parentOp)) {
533+ parentOp = parentOp->getParentOp ();
534+ }
535+
536+ auto funcOp = cast<LLVM::LLVMFuncOp>(parentOp);
537+ rewriter.setInsertionPointToStart (&funcOp.getBody ().front ());
538+
539+ std::tie (laneId, warpId) = getLaneAndWarpId (rewriter, loc);
540+ auto call = LLVM::createLLVMIntrinsicCallOp (
541+ rewriter, loc, " llvm.amdgcn.readfirstlane" , {i32_ty}, {warpId});
542+ warpId = call.getResult (0 );
543+ rewriter.restoreInsertionPoint (insertPt);
544+ } else {
545+ std::tie (laneId, warpId) = getLaneAndWarpId (rewriter, loc);
546+ }
547+
515548 auto calcPaddedOffset = [&](Value smemOffset) {
516549 TritonLLVMOpBuilder b (loc, rewriter);
517550 auto bitwidth = dstTy.getElementTypeBitWidth ();
@@ -873,7 +906,8 @@ struct BufferLoadToLocalOpConversion
873906 };
874907
875908 lowerDirectToLDSLoad (rewriter, loc, ptrType, flatDstTy, loadVals, llDst,
876- resElemTy, vec, emitBufferLoadLds);
909+ resElemTy, vec, targetInfo.getISAFamily (),
910+ emitBufferLoadLds);
877911
878912 // Drop the result token.
879913 Value zero = LLVM::ConstantOp::create (rewriter, op.getLoc (),
@@ -999,7 +1033,8 @@ struct AsyncCopyGlobalToLocalOpConversion
9991033 };
10001034
10011035 lowerDirectToLDSLoad (rewriter, loc, srcTy, flatDstTy, loadVals, llDst,
1002- resElemTy, vec, emitGlobalLoadLds);
1036+ resElemTy, vec, targetInfo.getISAFamily (),
1037+ emitGlobalLoadLds);
10031038
10041039 // Drop the result token.
10051040 Value zero = LLVM::ConstantOp::create (rewriter, op.getLoc (),
0 commit comments