Skip to content

Commit 55f615e

Browse files
authored
[InstCombine] Iterative replacement in PtrReplacer (llvm#145410) (llvm#2803)
2 parents 2ba33cb + 42894d3 commit 55f615e

File tree

2 files changed

+175
-69
lines changed

2 files changed

+175
-69
lines changed

llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp

Lines changed: 96 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,10 @@ class PointerReplacer {
244244
void replacePointer(Value *V);
245245

246246
private:
247-
bool collectUsersRecursive(Instruction &I);
248247
void replace(Instruction *I);
249-
Value *getReplacement(Value *I);
248+
Value *getReplacement(Value *V) const { return WorkMap.lookup(V); }
250249
bool isAvailable(Instruction *I) const {
251-
return I == &Root || Worklist.contains(I);
250+
return I == &Root || UsersToReplace.contains(I);
252251
}
253252

254253
bool isEqualOrValidAddrSpaceCast(const Instruction *I,
@@ -260,8 +259,7 @@ class PointerReplacer {
260259
return (FromAS == ToAS) || IC.isValidAddrSpaceCast(FromAS, ToAS);
261260
}
262261

263-
SmallPtrSet<Instruction *, 32> ValuesToRevisit;
264-
SmallSetVector<Instruction *, 4> Worklist;
262+
SmallSetVector<Instruction *, 32> UsersToReplace;
265263
MapVector<Value *, Value *> WorkMap;
266264
InstCombinerImpl &IC;
267265
Instruction &Root;
@@ -270,80 +268,119 @@ class PointerReplacer {
270268
} // end anonymous namespace
271269

272270
bool PointerReplacer::collectUsers() {
273-
if (!collectUsersRecursive(Root))
274-
return false;
275-
276-
// Ensure that all outstanding (indirect) users of I
277-
// are inserted into the Worklist. Return false
278-
// otherwise.
279-
return llvm::set_is_subset(ValuesToRevisit, Worklist);
280-
}
271+
SmallVector<Instruction *> Worklist;
272+
SmallSetVector<Instruction *, 32> ValuesToRevisit;
273+
274+
auto PushUsersToWorklist = [&](Instruction *Inst) {
275+
for (auto *U : Inst->users())
276+
if (auto *I = dyn_cast<Instruction>(U))
277+
if (!isAvailable(I) && !ValuesToRevisit.contains(I))
278+
Worklist.emplace_back(I);
279+
};
281280

282-
bool PointerReplacer::collectUsersRecursive(Instruction &I) {
283-
for (auto *U : I.users()) {
284-
auto *Inst = cast<Instruction>(&*U);
281+
PushUsersToWorklist(&Root);
282+
while (!Worklist.empty()) {
283+
Instruction *Inst = Worklist.pop_back_val();
285284
if (auto *Load = dyn_cast<LoadInst>(Inst)) {
286285
if (Load->isVolatile())
287286
return false;
288-
Worklist.insert(Load);
287+
UsersToReplace.insert(Load);
289288
} else if (auto *PHI = dyn_cast<PHINode>(Inst)) {
290-
// All incoming values must be instructions for replacability
291-
if (any_of(PHI->incoming_values(),
292-
[](Value *V) { return !isa<Instruction>(V); }))
293-
return false;
294-
295-
// If at least one incoming value of the PHI is not in Worklist,
296-
// store the PHI for revisiting and skip this iteration of the
297-
// loop.
298-
if (any_of(PHI->incoming_values(), [this](Value *V) {
299-
return !isAvailable(cast<Instruction>(V));
289+
/// TODO: Handle poison and null pointers for PHI and select.
290+
// If all incoming values are available, mark this PHI as
291+
// replaceable and push its users into the worklist.
292+
bool IsReplaceable = true;
293+
if (all_of(PHI->incoming_values(), [&](Value *V) {
294+
if (!isa<Instruction>(V))
295+
return IsReplaceable = false;
296+
return isAvailable(cast<Instruction>(V));
300297
})) {
301-
ValuesToRevisit.insert(Inst);
298+
UsersToReplace.insert(PHI);
299+
PushUsersToWorklist(PHI);
302300
continue;
303301
}
304302

305-
Worklist.insert(PHI);
306-
if (!collectUsersRecursive(*PHI))
307-
return false;
308-
} else if (auto *SI = dyn_cast<SelectInst>(Inst)) {
309-
if (!isa<Instruction>(SI->getTrueValue()) ||
310-
!isa<Instruction>(SI->getFalseValue()))
303+
// Either an incoming value is not an instruction or not all
304+
// incoming values are available. If this PHI was already
305+
// visited prior to this iteration, return false.
306+
if (!IsReplaceable || !ValuesToRevisit.insert(PHI))
311307
return false;
312308

313-
if (!isAvailable(cast<Instruction>(SI->getTrueValue())) ||
314-
!isAvailable(cast<Instruction>(SI->getFalseValue()))) {
315-
ValuesToRevisit.insert(Inst);
316-
continue;
309+
// Push PHI back into the stack, followed by unavailable
310+
// incoming values.
311+
Worklist.emplace_back(PHI);
312+
for (unsigned Idx = 0; Idx < PHI->getNumIncomingValues(); ++Idx) {
313+
auto *IncomingValue = cast<Instruction>(PHI->getIncomingValue(Idx));
314+
if (UsersToReplace.contains(IncomingValue))
315+
continue;
316+
if (!ValuesToRevisit.insert(IncomingValue))
317+
return false;
318+
Worklist.emplace_back(IncomingValue);
317319
}
318-
Worklist.insert(SI);
319-
if (!collectUsersRecursive(*SI))
320-
return false;
321-
} else if (isa<GetElementPtrInst>(Inst)) {
322-
Worklist.insert(Inst);
323-
if (!collectUsersRecursive(*Inst))
320+
} else if (auto *SI = dyn_cast<SelectInst>(Inst)) {
321+
auto *TrueInst = dyn_cast<Instruction>(SI->getTrueValue());
322+
auto *FalseInst = dyn_cast<Instruction>(SI->getFalseValue());
323+
if (!TrueInst || !FalseInst)
324324
return false;
325+
326+
UsersToReplace.insert(SI);
327+
PushUsersToWorklist(SI);
328+
} else if (auto *GEP = dyn_cast<GetElementPtrInst>(Inst)) {
329+
UsersToReplace.insert(GEP);
330+
PushUsersToWorklist(GEP);
325331
} else if (auto *MI = dyn_cast<MemTransferInst>(Inst)) {
326332
if (MI->isVolatile())
327333
return false;
328-
Worklist.insert(Inst);
334+
UsersToReplace.insert(Inst);
329335
} else if (isEqualOrValidAddrSpaceCast(Inst, FromAS)) {
330-
Worklist.insert(Inst);
331-
if (!collectUsersRecursive(*Inst))
332-
return false;
336+
UsersToReplace.insert(Inst);
337+
PushUsersToWorklist(Inst);
333338
} else if (Inst->isLifetimeStartOrEnd()) {
334339
continue;
335340
} else {
336341
// TODO: For arbitrary uses with address space mismatches, should we check
337342
// if we can introduce a valid addrspacecast?
338-
LLVM_DEBUG(dbgs() << "Cannot handle pointer user: " << *U << '\n');
343+
LLVM_DEBUG(dbgs() << "Cannot handle pointer user: " << *Inst << '\n');
339344
return false;
340345
}
341346
}
342347

343348
return true;
344349
}
345350

346-
Value *PointerReplacer::getReplacement(Value *V) { return WorkMap.lookup(V); }
351+
void PointerReplacer::replacePointer(Value *V) {
352+
assert(cast<PointerType>(Root.getType()) != cast<PointerType>(V->getType()) &&
353+
"Invalid usage");
354+
WorkMap[&Root] = V;
355+
SmallVector<Instruction *> Worklist;
356+
SetVector<Instruction *> PostOrderWorklist;
357+
SmallPtrSet<Instruction *, 32> Visited;
358+
359+
// Perform a postorder traversal of the users of Root.
360+
Worklist.push_back(&Root);
361+
while (!Worklist.empty()) {
362+
Instruction *I = Worklist.back();
363+
364+
// If I has not been processed before, push each of its
365+
// replaceable users into the worklist.
366+
if (Visited.insert(I).second) {
367+
for (auto *U : I->users()) {
368+
auto *UserInst = cast<Instruction>(U);
369+
if (UsersToReplace.contains(UserInst) && !Visited.contains(UserInst))
370+
Worklist.push_back(UserInst);
371+
}
372+
// Otherwise, users of I have already been pushed into
373+
// the PostOrderWorklist. Push I as well.
374+
} else {
375+
PostOrderWorklist.insert(I);
376+
Worklist.pop_back();
377+
}
378+
}
379+
380+
// Replace pointers in reverse-postorder.
381+
for (Instruction *I : reverse(PostOrderWorklist))
382+
replace(I);
383+
}
347384

348385
void PointerReplacer::replace(Instruction *I) {
349386
if (getReplacement(I))
@@ -365,13 +402,15 @@ void PointerReplacer::replace(Instruction *I) {
365402
// replacement (new value).
366403
WorkMap[NewI] = NewI;
367404
} else if (auto *PHI = dyn_cast<PHINode>(I)) {
368-
Type *NewTy = getReplacement(PHI->getIncomingValue(0))->getType();
369-
auto *NewPHI = PHINode::Create(NewTy, PHI->getNumIncomingValues(),
370-
PHI->getName(), PHI->getIterator());
371-
for (unsigned int I = 0; I < PHI->getNumIncomingValues(); ++I)
372-
NewPHI->addIncoming(getReplacement(PHI->getIncomingValue(I)),
373-
PHI->getIncomingBlock(I));
374-
WorkMap[PHI] = NewPHI;
405+
// Update the PHI in place by replacing any incoming value that is a user of the
406+
// root pointer and has a replacement.
407+
Value *V = WorkMap.lookup(PHI->getIncomingValue(0));
408+
PHI->mutateType(V ? V->getType() : PHI->getIncomingValue(0)->getType());
409+
for (unsigned int I = 0; I < PHI->getNumIncomingValues(); ++I) {
410+
Value *V = WorkMap.lookup(PHI->getIncomingValue(I));
411+
PHI->setIncomingValue(I, V ? V : PHI->getIncomingValue(I));
412+
}
413+
WorkMap[PHI] = PHI;
375414
} else if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) {
376415
auto *V = getReplacement(GEP->getPointerOperand());
377416
assert(V && "Operand not replaced");
@@ -435,18 +474,6 @@ void PointerReplacer::replace(Instruction *I) {
435474
}
436475
}
437476

438-
void PointerReplacer::replacePointer(Value *V) {
439-
#ifndef NDEBUG
440-
auto *PT = cast<PointerType>(Root.getType());
441-
auto *NT = cast<PointerType>(V->getType());
442-
assert(PT != NT && "Invalid usage");
443-
#endif
444-
WorkMap[&Root] = V;
445-
446-
for (Instruction *Workitem : Worklist)
447-
replace(Workitem);
448-
}
449-
450477
Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) {
451478
if (auto *I = simplifyAllocaArraySize(*this, AI, DT))
452479
return I;
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -mtriple=amdgcn-amd-amdhsa -passes=instcombine -S < %s | FileCheck %s
3+
4+
%struct.type = type { [256 x <2 x i64>] }
5+
@g1 = external hidden addrspace(3) global %struct.type, align 16
6+
7+
; This test requires the PtrReplacer to replace users in an RPO traversal.
8+
; Furthermore, %ptr.else need not be replaced, so it must be retained in
9+
; %ptr.sink.
10+
define <2 x i64> @func(ptr addrspace(4) byref(%struct.type) align 16 %0, i1 %cmp.0) {
11+
; CHECK-LABEL: define <2 x i64> @func(
12+
; CHECK-SAME: ptr addrspace(4) byref([[STRUCT_TYPE:%.*]]) align 16 [[TMP0:%.*]], i1 [[CMP_0:%.*]]) {
13+
; CHECK-NEXT: [[ENTRY:.*:]]
14+
; CHECK-NEXT: br i1 [[CMP_0]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]]
15+
; CHECK: [[IF_THEN]]:
16+
; CHECK-NEXT: [[VAL_THEN:%.*]] = addrspacecast ptr addrspace(4) [[TMP0]] to ptr
17+
; CHECK-NEXT: br label %[[SINK:.*]]
18+
; CHECK: [[IF_ELSE]]:
19+
; CHECK-NEXT: [[PTR_ELSE:%.*]] = load ptr, ptr addrspace(3) getelementptr inbounds nuw (i8, ptr addrspace(3) @g1, i32 32), align 16
20+
; CHECK-NEXT: br label %[[SINK]]
21+
; CHECK: [[SINK]]:
22+
; CHECK-NEXT: [[PTR_SINK:%.*]] = phi ptr [ [[PTR_ELSE]], %[[IF_ELSE]] ], [ [[VAL_THEN]], %[[IF_THEN]] ]
23+
; CHECK-NEXT: [[VAL_SINK:%.*]] = load <2 x i64>, ptr [[PTR_SINK]], align 16
24+
; CHECK-NEXT: ret <2 x i64> [[VAL_SINK]]
25+
;
26+
entry:
27+
%coerce = alloca %struct.type, align 16, addrspace(5)
28+
call void @llvm.memcpy.p5.p4.i64(ptr addrspace(5) align 16 %coerce, ptr addrspace(4) align 16 %0, i64 4096, i1 false)
29+
br i1 %cmp.0, label %if.then, label %if.else
30+
31+
if.then: ; preds = %entry
32+
%ptr.then = getelementptr inbounds i8, ptr addrspace(5) %coerce, i64 0
33+
%val.then = addrspacecast ptr addrspace(5) %ptr.then to ptr
34+
br label %sink
35+
36+
if.else: ; preds = %entry
37+
%ptr.else = load ptr, ptr addrspace(3) getelementptr inbounds nuw (i8, ptr addrspace(3) @g1, i32 32), align 16
38+
%val.else = getelementptr inbounds nuw i8, ptr %ptr.else, i64 0
39+
br label %sink
40+
41+
sink:
42+
%ptr.sink = phi ptr [ %val.else, %if.else ], [ %val.then, %if.then ]
43+
%val.sink = load <2 x i64>, ptr %ptr.sink, align 16
44+
ret <2 x i64> %val.sink
45+
}
46+
47+
define <2 x i64> @func_phi_loop(ptr addrspace(4) byref(%struct.type) align 16 %0, i1 %cmp.0) {
48+
; CHECK-LABEL: define <2 x i64> @func_phi_loop(
49+
; CHECK-SAME: ptr addrspace(4) byref([[STRUCT_TYPE:%.*]]) align 16 [[TMP0:%.*]], i1 [[CMP_0:%.*]]) {
50+
; CHECK-NEXT: [[ENTRY:.*]]:
51+
; CHECK-NEXT: [[VAL_0:%.*]] = addrspacecast ptr addrspace(4) [[TMP0]] to ptr
52+
; CHECK-NEXT: br label %[[LOOP:.*]]
53+
; CHECK: [[LOOP]]:
54+
; CHECK-NEXT: [[PTR_PHI_R:%.*]] = phi ptr [ [[PTR_1:%.*]], %[[LOOP]] ], [ [[VAL_0]], %[[ENTRY]] ]
55+
; CHECK-NEXT: [[PTR_1]] = load ptr, ptr addrspace(3) getelementptr inbounds nuw (i8, ptr addrspace(3) @g1, i32 32), align 16
56+
; CHECK-NEXT: br i1 [[CMP_0]], label %[[LOOP]], label %[[SINK:.*]]
57+
; CHECK: [[SINK]]:
58+
; CHECK-NEXT: [[VAL_SINK:%.*]] = load <2 x i64>, ptr [[PTR_PHI_R]], align 16
59+
; CHECK-NEXT: ret <2 x i64> [[VAL_SINK]]
60+
;
61+
entry:
62+
%coerce = alloca %struct.type, align 16, addrspace(5)
63+
call void @llvm.memcpy.p5.p4.i64(ptr addrspace(5) align 16 %coerce, ptr addrspace(4) align 16 %0, i64 4096, i1 false)
64+
%ptr.0 = getelementptr inbounds i8, ptr addrspace(5) %coerce, i64 0
65+
%val.0 = addrspacecast ptr addrspace(5) %ptr.0 to ptr
66+
br label %loop
67+
68+
loop:
69+
%ptr.phi = phi ptr [ %val.1, %loop ], [ %val.0, %entry ]
70+
%ptr.1 = load ptr, ptr addrspace(3) getelementptr inbounds nuw (i8, ptr addrspace(3) @g1, i32 32), align 16
71+
%val.1 = getelementptr inbounds nuw i8, ptr %ptr.1, i64 0
72+
br i1 %cmp.0, label %loop, label %sink
73+
74+
sink:
75+
%val.sink = load <2 x i64>, ptr %ptr.phi, align 16
76+
ret <2 x i64> %val.sink
77+
}
78+
79+
declare void @llvm.memcpy.p5.p4.i64(ptr addrspace(5) noalias writeonly captures(none), ptr addrspace(4) noalias readonly captures(none), i64, i1 immarg) #0

0 commit comments

Comments
 (0)