Commit 3f721ad

handle memcpy chain
1 parent 4ced462 commit 3f721ad

File tree

2 files changed: +228 -5 lines changed


llvm/lib/Target/AMDGPU/AMDGPUVectorIdiom.cpp

Lines changed: 187 additions & 5 deletions
@@ -46,6 +46,7 @@
 #include "AMDGPU.h"

 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/LoopInfo.h"
@@ -65,6 +66,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
+#include "llvm/IR/ValueHandle.h"
 #include <atomic>
 #include <cstdlib>

@@ -215,6 +217,127 @@ struct AMDGPUVectorIdiomImpl {

   AMDGPUVectorIdiomImpl(unsigned MaxBytes) : MaxBytes(MaxBytes) {}

+  // Returns true if the given intrinsic is an allowed lifetime marker.
+  static bool isAllowedLifetimeIntrinsic(Instruction *I) {
+    if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+      return II->getIntrinsicID() == Intrinsic::lifetime_start ||
+             II->getIntrinsicID() == Intrinsic::lifetime_end;
+    }
+    return false;
+  }
+
+  // Explore pointer casts/GEPs reachable from BasePtr, collecting all
+  // derived pointers. This is a small, bounded exploration since we only
+  // follow casts/GEPs.
+  static void collectDerivedPointers(Value *BasePtr,
+                                     SmallVectorImpl<Value *> &Derived) {
+    SmallVector<Value *, 16> Worklist;
+    SmallPtrSet<Value *, 32> Visited;
+    Worklist.push_back(BasePtr);
+
+    while (!Worklist.empty()) {
+      Value *Cur = Worklist.pop_back_val();
+      if (!Visited.insert(Cur).second)
+        continue;
+      Derived.push_back(Cur);
+
+      for (User *U : Cur->users()) {
+        if (auto *BC = dyn_cast<BitCastInst>(U)) {
+          Worklist.push_back(BC);
+          continue;
+        }
+        if (auto *ASC = dyn_cast<AddrSpaceCastInst>(U)) {
+          Worklist.push_back(ASC);
+          continue;
+        }
+        if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
+          // Only consider in-bounds GEPs to be conservative.
+          if (GEP->isInBounds())
+            Worklist.push_back(GEP);
+          continue;
+        }
+      }
+    }
+  }
+
+  // Attempts to find a simple memcpy chain where this memcpy (Producer) writes
+  // to a temporary (TmpPtr), and a later memcpy (Consumer) immediately copies
+  // from the same temporary to some final destination. The chain is considered
+  // simple if all uses of TmpPtr (and its derived bitcasts/GEPs) are limited to
+  // exactly these two memcpy operations (the Producer writing TmpPtr, and a
+  // single Consumer reading TmpPtr), plus optional lifetime intrinsics and
+  // further pointer casts/GEPs. If found and Producer dominates Consumer, the
+  // Consumer is returned. Otherwise returns null.
+  MemCpyInst *findMemcpyChainConsumer(MemCpyInst &Producer, Value *TmpPtr,
+                                      uint64_t N,
+                                      const DominatorTree *DT) {
+    Value *Base = TmpPtr->stripPointerCasts();
+    SmallVector<Value *, 16> Ptrs;
+    collectDerivedPointers(Base, Ptrs);
+
+    SmallVector<MemCpyInst *, 2> MemCpyUsers;
+    for (Value *P : Ptrs) {
+      for (User *U : P->users()) {
+        if (auto *I = dyn_cast<Instruction>(U)) {
+          if (isAllowedLifetimeIntrinsic(I))
+            continue;
+        }
+
+        if (auto *MC = dyn_cast<MemCpyInst>(U)) {
+          MemCpyUsers.push_back(MC);
+          continue;
+        }
+
+        if (isa<BitCastInst>(U) || isa<AddrSpaceCastInst>(U) || isa<GetElementPtrInst>(U))
+          continue; // Covered by derived pointers.
+
+        // Any other use (loads/stores/calls/etc.) makes the chain non-simple.
+        return nullptr;
+      }
+    }
+
+    // We expect exactly two memcpys in the simple chain: the producer
+    // 'Producer' that writes TmpPtr, and a consumer that reads TmpPtr.
+    MemCpyInst *Consumer = nullptr;
+    for (MemCpyInst *MC : MemCpyUsers) {
+      // Length must be constant and match N to be a simple forward copy.
+      auto *LenCI = dyn_cast<ConstantInt>(MC->getLength());
+      if (!LenCI || LenCI->getLimitedValue() != N)
+        return nullptr;
+
+      Value *SrcMC = MC->getRawSource();
+      Value *DstMC = MC->getRawDest();
+
+      bool SrcFromTmp = SrcMC->stripPointerCasts() == Base;
+      bool DstIsTmp = DstMC->stripPointerCasts() == Base;
+
+      if (MC == &Producer) {
+        // Producer must be the one writing to TmpPtr.
+        if (!DstIsTmp)
+          return nullptr;
+        continue;
+      }
+
+      // Any other memcpy must be the consumer reading from TmpPtr.
+      if (!SrcFromTmp)
+        return nullptr;
+
+      if (Consumer)
+        return nullptr; // More than one consumer.
+      Consumer = MC;
+    }
+
+    if (!Consumer)
+      return nullptr;
+
+    // Producer must dominate Consumer so that values computed at Producer
+    // (loads/select) can be used at the Consumer insertion point.
+    if (DT && !DT->dominates(&Producer, Consumer))
+      return nullptr;
+
+    return Consumer;
+  }
+
   // Rewrites memcpy when the source is a select of pointers. Prefers a
   // value-level select (two loads + select + one store) if speculative loads
   // are safe. Otherwise, falls back to a guarded CFG split with two memcpy
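
For illustration, the shape of a simple chain that findMemcpyChainConsumer accepts, mirroring the regression test added in this commit (names illustrative only):

  %tmp = alloca [4 x i32], align 16, addrspace(5)
  %src = select i1 %cond, ptr addrspace(5) %pa, ptr addrspace(5) %pb
  ; Producer: the only write of %tmp.
  call void @llvm.memcpy.p5.p5.i64(ptr addrspace(5) align 16 %tmp, ptr addrspace(5) align 16 %src, i64 16, i1 false)
  ; Consumer: the only read of %tmp, with the same constant length.
  call void @llvm.memcpy.p5.p5.i64(ptr addrspace(5) align 16 %dst, ptr addrspace(5) align 16 %tmp, i64 16, i1 false)

Any other use of %tmp or a pointer derived from it, e.g. a direct load such as "%v = load i32, ptr addrspace(5) %tmp", makes the chain non-simple and the helper returns null.
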
@@ -306,6 +429,47 @@ struct AMDGPUVectorIdiomImpl {
         bothArmsSafeToSpeculateLoads(A, Bv, N, AlignAB, DL, AC, DT, &MT);

     if (CanSpeculate) {
+      // First, check if this memcpy writes to a temporary that is immediately
+      // copied again by a following memcpy to the final destination. If so,
+      // form the value at the current location (to preserve load timing) and
+      // emit a single store at the consumer location, erasing both memcpys.
+      if (hasAllocaUnderlyingObject(Dst)) {
+        if (MemCpyInst *Consumer =
+                findMemcpyChainConsumer(MT, Dst, N, DT)) {
+          Align ConsumerDstAlign =
+              MaybeAlign(Consumer->getDestAlign()).valueOrOne();
+          Align MinAlign = std::min(AlignAB, ConsumerDstAlign);
+
+          LLVM_DEBUG(dbgs() << "[AMDGPUVectorIdiom] Folding memcpy chain: "
+                            << "memcpy(tmp <- select), memcpy(dst <- tmp). "
+                            << "Emitting value-select and single store to final dst; "
+                            << "N=" << N << " minAlign=" << MinAlign.value()
+                            << '\n');
+
+          // Compute the selected value at the original memcpy to preserve
+          // the timing of the loads from A/B.
+          Type *Ty = getIntOrVecTypeForSize(N, B.getContext(), MinAlign);
+          LoadInst *LA = B.CreateAlignedLoad(Ty, A, MinAlign);
+          LoadInst *LB = B.CreateAlignedLoad(Ty, Bv, MinAlign);
+          Value *V = B.CreateSelect(Sel.getCondition(), LA, LB);
+
+          // Insert the final store right before the consumer memcpy.
+          IRBuilder<> BC(Consumer);
+          (void)BC.CreateAlignedStore(V, Consumer->getRawDest(),
+                                      ConsumerDstAlign);
+
+          LLVM_DEBUG(dbgs() << "[AMDGPUVectorIdiom] Erasing memcpy chain:\n - "
+                            << MT << "\n - " << *Consumer << '\n');
+
+          incrementTransformationCounter();
+          Consumer->eraseFromParent();
+          MT.eraseFromParent();
+          return true;
+        }
+      }
+
+      // No chain detected. Do the normal value-level select and store directly
+      // to the memcpy destination.
       Align MinAlign = std::min(AlignAB, DstAlign);
       LLVM_DEBUG(dbgs() << "[AMDGPUVectorIdiom] Rewriting memcpy(select-src) "
                         << "with value-level select; N=" << N
@@ -440,18 +604,29 @@ AMDGPUVectorIdiomCombinePass::run(Function &F, FunctionAnalysisManager &FAM) {
     }
   });

-  SmallVector<MemCpyInst *, 8> Worklist;
+  SmallVector<WeakTrackingVH, 8> Worklist;
   for (Instruction &I : instructions(F)) {
     if (auto *MC = dyn_cast<MemCpyInst>(&I))
-      Worklist.push_back(MC);
+      Worklist.emplace_back(MC);
   }

   bool Changed = false;
   AMDGPUVectorIdiomImpl Impl(MaxBytes);

-  for (MemCpyInst *MT : Worklist) {
+  for (WeakTrackingVH &WH : Worklist) {
+    auto *MT = dyn_cast_or_null<MemCpyInst>(WH);
+    if (!MT)
+      continue; // Was deleted by a previous transform.
+
     Value *Dst = MT->getRawDest();
     Value *Src = MT->getRawSource();
+
+    // Add null checks for safety.
+    if (!Dst || !Src) {
+      LLVM_DEBUG(dbgs() << "[AMDGPUVectorIdiom] Skip: null dst or src\n");
+      continue;
+    }
+
     if (!isa<SelectInst>(Src) && !isa<SelectInst>(Dst))
       continue;

@@ -532,8 +707,15 @@ AMDGPUVectorIdiomCombinePass::run(Function &F, FunctionAnalysisManager &FAM) {
       continue;
    }

-    unsigned DstAS = cast<PointerType>(Dst->getType())->getAddressSpace();
-    unsigned SrcAS = cast<PointerType>(Src->getType())->getAddressSpace();
+    auto *DstPTy = dyn_cast<PointerType>(Dst->getType());
+    auto *SrcPTy = dyn_cast<PointerType>(Src->getType());
+    if (!DstPTy || !SrcPTy) {
+      LLVM_DEBUG(dbgs() << "[AMDGPUVectorIdiom] Skip: non-pointer dst or src\n");
+      continue;
+    }
+
+    unsigned DstAS = DstPTy->getAddressSpace();
+    unsigned SrcAS = SrcPTy->getAddressSpace();
     if (DstAS != SrcAS) {
       LLVM_DEBUG(dbgs() << "[AMDGPUVectorIdiom] Skip: address space mismatch "
                         << "(dstAS=" << DstAS << ", srcAS=" << SrcAS << ")\n");
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+; RUN: opt -amdgpu-vector-idiom-enable -mtriple=amdgcn-amd-amdhsa -passes=amdgpu-vector-idiom -S %s | FileCheck %s
+;
+; Reduced testcase for planned enhancement:
+; Fold a memcpy chain where the source memcpy is fed by a select-of-pointers
+; and the second memcpy copies from the tmp alloca to the final destination.
+; Expected: eliminate both memcpys by emitting a value-level select of two
+; vector loads followed by a single store to the final destination.
+;
+; Enhancement implemented: memcpy chain folding now works.
+
+declare void @llvm.memcpy.p5.p5.i64(ptr addrspace(5) nocapture writeonly, ptr addrspace(5) nocapture readonly, i64, i1 immarg)
+
+; -----------------------------------------------------------------------------
+; memcpy(tmp <- select(pa, pb)); memcpy(dst <- tmp)
+; Expect: load <4 x i32> from pa/pb, select, store to dst. No memcpy.
+;
+define amdgpu_kernel void @memcpy_chain_src_select_elide_tmp(i1 %cond) {
+; CHECK-LABEL: define amdgpu_kernel void @memcpy_chain_src_select_elide_tmp(
+; CHECK-SAME: i1 [[COND:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[PA:%.*]] = alloca [4 x i32], align 16, addrspace(5)
+; CHECK-NEXT: [[PB:%.*]] = alloca [4 x i32], align 16, addrspace(5)
+; CHECK: [[DST:%.*]] = alloca [4 x i32], align 16, addrspace(5)
+; CHECK: [[SRC:%.*]] = select i1 [[COND]], ptr addrspace(5) [[PA]], ptr addrspace(5) [[PB]]
+; CHECK: [[LA:%.*]] = load <4 x i32>, ptr addrspace(5) [[PA]], align 16
+; CHECK: [[LB:%.*]] = load <4 x i32>, ptr addrspace(5) [[PB]], align 16
+; CHECK: [[SEL:%.*]] = select i1 [[COND]], <4 x i32> [[LA]], <4 x i32> [[LB]]
+; CHECK: store <4 x i32> [[SEL]], ptr addrspace(5) [[DST]], align 16
+; CHECK-NOT: call void @llvm.memcpy
+; CHECK: ret void
+;
+entry:
+  %pa = alloca [4 x i32], align 16, addrspace(5)
+  %pb = alloca [4 x i32], align 16, addrspace(5)
+  %dst = alloca [4 x i32], align 16, addrspace(5)
+  %tmp = alloca [4 x i32], align 16, addrspace(5)
+  %src = select i1 %cond, ptr addrspace(5) %pa, ptr addrspace(5) %pb
+  call void @llvm.memcpy.p5.p5.i64(ptr addrspace(5) align 16 %tmp, ptr addrspace(5) align 16 %src, i64 16, i1 false)
+  call void @llvm.memcpy.p5.p5.i64(ptr addrspace(5) align 16 %dst, ptr addrspace(5) align 16 %tmp, i64 16, i1 false)
+  ret void
+}
