Skip to content

Commit 5a74a4a

Browse files
authored
[Attributor] Take the address space from addrspacecast directly (#108258)
Currently `AAAddressSpace` relies on identifying the address spaces of all underlying objects. However, it might infer sub-optimal address space when the underlying object is a function argument. In `AMDGPUPromoteKernelArgumentsPass`, the promotion of a pointer kernel argument is by adding a series of `addrspacecast` instructions (as shown below), and hoping `InferAddressSpacePass` can pick it up and do the rewriting accordingly. Before promotion: ``` define amdgpu_kernel void @kernel(ptr %to_be_promoted) { %val = load i32, ptr %to_be_promoted ... ret void } ``` After promotion: ``` define amdgpu_kernel void @kernel(ptr %to_be_promoted) { %ptr.cast.0 = addrspace cast ptr % to_be_promoted to ptr addrspace(1) %ptr.cast.1 = addrspace cast ptr addrspace(1) %ptr.cast.0 to ptr # all the use of %to_be_promoted will use %ptr.cast.1 %val = load i32, ptr %ptr.cast.1 ... ret void } ``` When `AAAddressSpace` analyzes the code after promotion, it will take `%to_be_promoted` as the underlying object of `%ptr.cast.1`, and use its address space (which is 0) as its final address space, thus simply do nothing in `manifest`. The attributor framework will them eliminate the address space cast from 0 to 1 and back to 0, and replace `%ptr.cast.1` with `%to_be_promoted`, which basically reverts all changes by `AMDGPUPromoteKernelArgumentsPass`. IMHO I'm not sure if `AMDGPUPromoteKernelArgumentsPass` promotes the argument in a proper way. To improve the handling of this case, this PR adds an extra handling when iterating over all underlying objects. If an underlying object is a function argument, it means it reaches a terminal such that we can't futher deduce its underlying object further. In this case, we check all uses of the argument. If they are all `addrspacecast` instructions and their destination address spaces are same, we take the destination address space. Fixes: SWDEV-482640.
1 parent 03229e7 commit 5a74a4a

File tree

2 files changed

+80
-14
lines changed

2 files changed

+80
-14
lines changed

llvm/lib/Transforms/IPO/AttributorAttributes.cpp

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12583,16 +12583,36 @@ struct AAAddressSpaceImpl : public AAAddressSpace {
1258312583
}
1258412584

1258512585
ChangeStatus updateImpl(Attributor &A) override {
12586+
unsigned FlatAS = A.getInfoCache().getFlatAddressSpace().value();
1258612587
uint32_t OldAddressSpace = AssumedAddressSpace;
12587-
auto *AUO = A.getOrCreateAAFor<AAUnderlyingObjects>(getIRPosition(), this,
12588-
DepClassTy::REQUIRED);
12589-
auto Pred = [&](Value &Obj) {
12588+
12589+
auto CheckAddressSpace = [&](Value &Obj) {
1259012590
if (isa<UndefValue>(&Obj))
1259112591
return true;
12592+
// If an argument in flat address space only has addrspace cast uses, and
12593+
// those casts are same, then we take the dst addrspace.
12594+
if (auto *Arg = dyn_cast<Argument>(&Obj)) {
12595+
if (Arg->getType()->getPointerAddressSpace() == FlatAS) {
12596+
unsigned CastAddrSpace = FlatAS;
12597+
for (auto *U : Arg->users()) {
12598+
auto *ASCI = dyn_cast<AddrSpaceCastInst>(U);
12599+
if (!ASCI)
12600+
return takeAddressSpace(Obj.getType()->getPointerAddressSpace());
12601+
if (CastAddrSpace != FlatAS &&
12602+
CastAddrSpace != ASCI->getDestAddressSpace())
12603+
return false;
12604+
CastAddrSpace = ASCI->getDestAddressSpace();
12605+
}
12606+
if (CastAddrSpace != FlatAS)
12607+
return takeAddressSpace(CastAddrSpace);
12608+
}
12609+
}
1259212610
return takeAddressSpace(Obj.getType()->getPointerAddressSpace());
1259312611
};
1259412612

12595-
if (!AUO->forallUnderlyingObjects(Pred))
12613+
auto *AUO = A.getOrCreateAAFor<AAUnderlyingObjects>(getIRPosition(), this,
12614+
DepClassTy::REQUIRED);
12615+
if (!AUO->forallUnderlyingObjects(CheckAddressSpace))
1259612616
return indicatePessimisticFixpoint();
1259712617

1259812618
return OldAddressSpace == AssumedAddressSpace ? ChangeStatus::UNCHANGED
@@ -12601,17 +12621,21 @@ struct AAAddressSpaceImpl : public AAAddressSpace {
1260112621

1260212622
/// See AbstractAttribute::manifest(...).
1260312623
ChangeStatus manifest(Attributor &A) override {
12604-
if (getAddressSpace() == InvalidAddressSpace ||
12605-
getAddressSpace() == getAssociatedType()->getPointerAddressSpace())
12624+
unsigned NewAS = getAddressSpace();
12625+
12626+
if (NewAS == InvalidAddressSpace ||
12627+
NewAS == getAssociatedType()->getPointerAddressSpace())
1260612628
return ChangeStatus::UNCHANGED;
1260712629

12630+
unsigned FlatAS = A.getInfoCache().getFlatAddressSpace().value();
12631+
1260812632
Value *AssociatedValue = &getAssociatedValue();
12609-
Value *OriginalValue = peelAddrspacecast(AssociatedValue);
12633+
Value *OriginalValue = peelAddrspacecast(AssociatedValue, FlatAS);
1261012634

1261112635
PointerType *NewPtrTy =
12612-
PointerType::get(getAssociatedType()->getContext(), getAddressSpace());
12636+
PointerType::get(getAssociatedType()->getContext(), NewAS);
1261312637
bool UseOriginalValue =
12614-
OriginalValue->getType()->getPointerAddressSpace() == getAddressSpace();
12638+
OriginalValue->getType()->getPointerAddressSpace() == NewAS;
1261512639

1261612640
bool Changed = false;
1261712641

@@ -12671,12 +12695,19 @@ struct AAAddressSpaceImpl : public AAAddressSpace {
1267112695
return AssumedAddressSpace == AS;
1267212696
}
1267312697

12674-
static Value *peelAddrspacecast(Value *V) {
12675-
if (auto *I = dyn_cast<AddrSpaceCastInst>(V))
12676-
return peelAddrspacecast(I->getPointerOperand());
12698+
static Value *peelAddrspacecast(Value *V, unsigned FlatAS) {
12699+
if (auto *I = dyn_cast<AddrSpaceCastInst>(V)) {
12700+
assert(I->getSrcAddressSpace() != FlatAS &&
12701+
"there should not be flat AS -> non-flat AS");
12702+
return I->getPointerOperand();
12703+
}
1267712704
if (auto *C = dyn_cast<ConstantExpr>(V))
12678-
if (C->getOpcode() == Instruction::AddrSpaceCast)
12679-
return peelAddrspacecast(C->getOperand(0));
12705+
if (C->getOpcode() == Instruction::AddrSpaceCast) {
12706+
assert(C->getOperand(0)->getType()->getPointerAddressSpace() !=
12707+
FlatAS &&
12708+
"there should not be flat AS -> non-flat AS X");
12709+
return C->getOperand(0);
12710+
}
1268012711
return V;
1268112712
}
1268212713
};

llvm/test/CodeGen/AMDGPU/aa-as-infer.ll

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,38 @@ define void @foo(ptr addrspace(3) %val) {
243243
ret void
244244
}
245245

246+
define void @kernel_argument_promotion_pattern_intra_procedure(ptr %p, i32 %val) {
247+
; CHECK-LABEL: define void @kernel_argument_promotion_pattern_intra_procedure(
248+
; CHECK-SAME: ptr [[P:%.*]], i32 [[VAL:%.*]]) #[[ATTR0]] {
249+
; CHECK-NEXT: [[P_CAST_0:%.*]] = addrspacecast ptr [[P]] to ptr addrspace(1)
250+
; CHECK-NEXT: store i32 [[VAL]], ptr addrspace(1) [[P_CAST_0]], align 4
251+
; CHECK-NEXT: ret void
252+
;
253+
%p.cast.0 = addrspacecast ptr %p to ptr addrspace(1)
254+
%p.cast.1 = addrspacecast ptr addrspace(1) %p.cast.0 to ptr
255+
store i32 %val, ptr %p.cast.1
256+
ret void
257+
}
258+
259+
define internal void @use_argument_after_promotion(ptr %p, i32 %val) {
260+
; CHECK-LABEL: define internal void @use_argument_after_promotion(
261+
; CHECK-SAME: ptr [[P:%.*]], i32 [[VAL:%.*]]) #[[ATTR0]] {
262+
; CHECK-NEXT: [[TMP1:%.*]] = addrspacecast ptr [[P]] to ptr addrspace(1)
263+
; CHECK-NEXT: store i32 [[VAL]], ptr addrspace(1) [[TMP1]], align 4
264+
; CHECK-NEXT: ret void
265+
;
266+
store i32 %val, ptr %p
267+
ret void
268+
}
269+
270+
define void @kernel_argument_promotion_pattern_inter_procedure(ptr %p, i32 %val) {
271+
; CHECK-LABEL: define void @kernel_argument_promotion_pattern_inter_procedure(
272+
; CHECK-SAME: ptr [[P:%.*]], i32 [[VAL:%.*]]) #[[ATTR0]] {
273+
; CHECK-NEXT: call void @use_argument_after_promotion(ptr [[P]], i32 [[VAL]])
274+
; CHECK-NEXT: ret void
275+
;
276+
%p.cast.0 = addrspacecast ptr %p to ptr addrspace(1)
277+
%p.cast.1 = addrspacecast ptr addrspace(1) %p.cast.0 to ptr
278+
call void @use_argument_after_promotion(ptr %p.cast.1, i32 %val)
279+
ret void
280+
}

0 commit comments

Comments
 (0)