 #include "AMDGPU.h"

+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/LoopInfo.h"

+#include "llvm/IR/ValueHandle.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include <atomic>
 #include <cstdlib>

@@ -215,6 +217,127 @@ struct AMDGPUVectorIdiomImpl { |

   AMDGPUVectorIdiomImpl(unsigned MaxBytes) : MaxBytes(MaxBytes) {}

+  // Returns true if I is a lifetime_start/lifetime_end intrinsic, the only
+  // intrinsic uses tolerated in a simple memcpy chain.
+  static bool isAllowedLifetimeIntrinsic(Instruction *I) {
+    if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+      return II->getIntrinsicID() == Intrinsic::lifetime_start ||
+             II->getIntrinsicID() == Intrinsic::lifetime_end;
+    }
+    return false;
+  }
+
+  // Explore pointer casts/GEPs reachable from BasePtr, collecting every
+  // derived pointer. The walk only follows casts and in-bounds GEPs and
+  // tracks visited values, so it terminates, though its size is not
+  // otherwise capped.
+  static void collectDerivedPointers(Value *BasePtr,
+                                     SmallVectorImpl<Value *> &Derived) {
+    SmallVector<Value *, 16> Worklist;
+    SmallPtrSet<Value *, 32> Visited;
+    Worklist.push_back(BasePtr);
+
+    while (!Worklist.empty()) {
+      Value *Cur = Worklist.pop_back_val();
+      if (!Visited.insert(Cur).second)
+        continue;
+      Derived.push_back(Cur);
+
+      for (User *U : Cur->users()) {
+        if (auto *BC = dyn_cast<BitCastInst>(U)) {
+          Worklist.push_back(BC);
+          continue;
+        }
+        if (auto *ASC = dyn_cast<AddrSpaceCastInst>(U)) {
+          Worklist.push_back(ASC);
+          continue;
+        }
+        if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
+          // Only follow in-bounds GEPs; non-inbounds GEPs are rejected as
+          // unknown uses in the scan in findMemcpyChainConsumer.
+          if (GEP->isInBounds())
+            Worklist.push_back(GEP);
+          continue;
+        }
+      }
+    }
+  }
+
+  // Attempts to find a simple memcpy chain in which this memcpy (Producer)
+  // writes a temporary (TmpPtr) and a later memcpy (Consumer) copies that
+  // same temporary to a final destination. The chain is "simple" if all uses
+  // of TmpPtr and its derived casts/GEPs are exactly these two memcpys
+  // (Producer writing the temporary, a single Consumer reading it), plus
+  // optional lifetime intrinsics and further pointer casts/GEPs. Returns the
+  // Consumer if it is found and Producer dominates it; otherwise returns
+  // null.
+  static MemCpyInst *findMemcpyChainConsumer(MemCpyInst &Producer,
+                                             Value *TmpPtr, uint64_t N,
+                                             const DominatorTree *DT) {
+    Value *Base = TmpPtr->stripPointerCasts();
+    SmallVector<Value *, 16> Ptrs;
+    collectDerivedPointers(Base, Ptrs);
+
+    SmallVector<MemCpyInst *, 2> MemCpyUsers;
+    for (Value *P : Ptrs) {
+      for (User *U : P->users()) {
+        if (auto *I = dyn_cast<Instruction>(U)) {
+          if (isAllowedLifetimeIntrinsic(I))
+            continue;
+        }
+
+        if (auto *MC = dyn_cast<MemCpyInst>(U)) {
+          MemCpyUsers.push_back(MC);
+          continue;
+        }
+
+        if (isa<BitCastInst>(U) || isa<AddrSpaceCastInst>(U))
+          continue; // Already covered by the derived-pointer walk.
+        if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
+          // Non-inbounds GEPs were not collected as derived pointers, so
+          // their uses are invisible to this scan; treat them as unknown.
+          if (GEP->isInBounds())
+            continue;
+          return nullptr;
+        }
+
+        // Any other use (load/store/call/etc.) makes the chain non-simple.
+        return nullptr;
+      }
+    }
+
+    // Expect exactly two memcpys in a simple chain: Producer, which writes
+    // the temporary, and a single consumer that reads it.
+    MemCpyInst *Consumer = nullptr;
+    for (MemCpyInst *MC : MemCpyUsers) {
+      // Volatile copies are not candidates, and the length must be a
+      // constant equal to N for a simple forward copy.
+      if (MC->isVolatile())
+        return nullptr;
+      auto *LenCI = dyn_cast<ConstantInt>(MC->getLength());
+      if (!LenCI || LenCI->getLimitedValue() != N)
+        return nullptr;
+
+      Value *SrcMC = MC->getRawSource();
+      Value *DstMC = MC->getRawDest();
+
+      bool SrcFromTmp = SrcMC->stripPointerCasts() == Base;
+      bool DstIsTmp = DstMC->stripPointerCasts() == Base;
+
+      if (MC == &Producer) {
+        // Producer must be the memcpy writing the temporary.
+        if (!DstIsTmp)
+          return nullptr;
+        continue;
+      }
+
+      // Any other memcpy must be the consumer reading from the temporary.
+      if (!SrcFromTmp)
+        return nullptr;
+
+      if (Consumer)
+        return nullptr; // More than one consumer.
+      Consumer = MC;
+    }
+
+    if (!Consumer)
+      return nullptr;
+
+    // Producer must dominate Consumer so that values computed at Producer
+    // (the loads and select) are available at the Consumer insertion point.
+    // Without a DominatorTree we cannot prove this, so bail out.
+    if (!DT || !DT->dominates(&Producer, Consumer))
+      return nullptr;
+
+    return Consumer;
+  }
+
   // Rewrites memcpy when the source is a select of pointers. Prefers a
   // value-level select (two loads + select + one store) if speculative loads
   // are safe. Otherwise, falls back to a guarded CFG split with two memcpy
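
To make the matched shape concrete: the chain helper above looks for the producer/consumer pattern that front ends commonly emit for conditional aggregate copies. Here is a hypothetical C++ source pattern (names are illustrative, not from this patch) that can lower to exactly this producer/consumer pair:

```cpp
// Illustrative only: for a trivially copyable aggregate, a front end may
// lower both assignments below to llvm.memcpy calls, one writing a
// temporary alloca from a selected source (the Producer), and one copying
// that temporary to the final destination (the Consumer).
struct Vec4 {
  float X, Y, Z, W;
};

void pick(bool Cond, const Vec4 &A, const Vec4 &B, Vec4 &Out) {
  Vec4 Tmp = Cond ? A : B; // Producer: memcpy(Tmp <- select of &A / &B)
  Out = Tmp;               // Consumer: memcpy(Out <- Tmp)
}
```

Whether this exact lowering appears depends on the front end and optimization level; the helper only fires when the temporary has no uses besides the two memcpys, lifetime markers, and pointer casts/GEPs.
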
@@ -306,6 +429,47 @@ struct AMDGPUVectorIdiomImpl { |
         bothArmsSafeToSpeculateLoads(A, Bv, N, AlignAB, DL, AC, DT, &MT);

     if (CanSpeculate) {
+      // First, check whether this memcpy writes a temporary that is
+      // immediately copied again by a following memcpy to the final
+      // destination. If so, form the selected value here (preserving load
+      // timing) and emit a single store at the consumer, erasing both
+      // memcpys.
+      if (hasAllocaUnderlyingObject(Dst)) {
+        if (MemCpyInst *Consumer = findMemcpyChainConsumer(MT, Dst, N, DT)) {
+          Align ConsumerDstAlign = Consumer->getDestAlign().valueOrOne();
+          Align MinAlign = std::min(AlignAB, ConsumerDstAlign);
+
+          LLVM_DEBUG(dbgs() << "[AMDGPUVectorIdiom] Folding memcpy chain: "
+                            << "memcpy(tmp <- select), memcpy(dst <- tmp). "
+                            << "Emitting value-select and single store to "
+                            << "final dst; N=" << N
+                            << " minAlign=" << MinAlign.value() << '\n');
+
+          // Compute the selected value at the original memcpy to preserve
+          // the timing of the loads from A/B.
+          Type *Ty = getIntOrVecTypeForSize(N, B.getContext(), MinAlign);
+          LoadInst *LA = B.CreateAlignedLoad(Ty, A, MinAlign);
+          LoadInst *LB = B.CreateAlignedLoad(Ty, Bv, MinAlign);
+          Value *V = B.CreateSelect(Sel.getCondition(), LA, LB);
+
+          // Insert the final store right before the consumer memcpy.
+          IRBuilder<> BC(Consumer);
+          (void)BC.CreateAlignedStore(V, Consumer->getRawDest(),
+                                      ConsumerDstAlign);
+
+          LLVM_DEBUG(dbgs() << "[AMDGPUVectorIdiom] Erasing memcpy chain:\n - "
+                            << MT << "\n - " << *Consumer << '\n');
+
+          incrementTransformationCounter();
+          Consumer->eraseFromParent();
+          MT.eraseFromParent();
+          return true;
+        }
+      }
+
+      // No chain detected: do the normal value-level select and store
+      // directly to the memcpy destination.
       Align MinAlign = std::min(AlignAB, DstAlign);
       LLVM_DEBUG(dbgs() << "[AMDGPUVectorIdiom] Rewriting memcpy(select-src) "
                         << "with value-level select; N=" << N
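
Semantically, the fold above collapses the two copies into one widened load/select/store sequence. A minimal C++ analog of the result, assuming N matches the size of the chosen type (here Vec4 stands in for the integer or vector type that getIntOrVecTypeForSize returns):

```cpp
// Sketch of the folded semantics, not the pass's literal output. Both arms
// are loaded unconditionally at the producer's position (hence the
// bothArmsSafeToSpeculateLoads gate), the select is performed on the loaded
// values, and a single store lands at the consumer's position.
struct Vec4 {
  float X, Y, Z, W;
};

void foldedPick(bool Cond, const Vec4 *A, const Vec4 *B, Vec4 *Dst) {
  Vec4 LA = *A;          // speculative load of arm A
  Vec4 LB = *B;          // speculative load of arm B
  *Dst = Cond ? LA : LB; // value-level select + single store, no temporary
}
```
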
@@ -440,18 +604,29 @@ AMDGPUVectorIdiomCombinePass::run(Function &F, FunctionAnalysisManager &FAM) { |
     }
   });

-  SmallVector<MemCpyInst *, 8> Worklist;
+  SmallVector<WeakTrackingVH, 8> Worklist;
   for (Instruction &I : instructions(F)) {
     if (auto *MC = dyn_cast<MemCpyInst>(&I))
-      Worklist.push_back(MC);
+      Worklist.emplace_back(MC);
   }

   bool Changed = false;
   AMDGPUVectorIdiomImpl Impl(MaxBytes);

-  for (MemCpyInst *MT : Worklist) {
+  for (WeakTrackingVH &WH : Worklist) {
+    auto *MT = dyn_cast_or_null<MemCpyInst>(WH);
+    if (!MT)
+      continue; // Erased by an earlier transform in this loop.
+
     Value *Dst = MT->getRawDest();
     Value *Src = MT->getRawSource();
     if (!isa<SelectInst>(Src) && !isa<SelectInst>(Dst))
       continue;

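The worklist now holds WeakTrackingVH instead of raw MemCpyInst pointers because the chain fold erases a second memcpy (the consumer) that may still be queued. A WeakTrackingVH nulls itself when its instruction is deleted, which is what makes the dyn_cast_or_null re-check above safe. A self-contained sketch of that behavior (eraseAndObserve is a hypothetical helper, not part of the patch):

```cpp
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/ValueHandle.h"
#include <cassert>

using namespace llvm;

// After eraseFromParent(), the weak handle compares null, so the cast
// returns nullptr instead of touching a dangling pointer.
void eraseAndObserve(MemCpyInst *MC) {
  WeakTrackingVH VH(MC); // tracks MC while it is alive
  MC->eraseFromParent(); // the deletion callback nulls VH
  assert(dyn_cast_or_null<MemCpyInst>(VH) == nullptr &&
         "stale worklist entry must not resolve");
}
```
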
@@ -532,8 +707,15 @@ AMDGPUVectorIdiomCombinePass::run(Function &F, FunctionAnalysisManager &FAM) { |
       continue;
     }

-    unsigned DstAS = cast<PointerType>(Dst->getType())->getAddressSpace();
-    unsigned SrcAS = cast<PointerType>(Src->getType())->getAddressSpace();
+    // Memcpy operands are always pointer-typed, so query the address space
+    // directly.
+    unsigned DstAS = Dst->getType()->getPointerAddressSpace();
+    unsigned SrcAS = Src->getType()->getPointerAddressSpace();
     if (DstAS != SrcAS) {
       LLVM_DEBUG(dbgs() << "[AMDGPUVectorIdiom] Skip: address space mismatch "
                         << "(dstAS=" << DstAS << ", srcAS=" << SrcAS << ")\n");
