@@ -12589,15 +12589,37 @@ struct AAAddressSpaceImpl : public AAAddressSpace {
1258912589
1259012590 ChangeStatus updateImpl (Attributor &A) override {
1259112591 uint32_t OldAddressSpace = AssumedAddressSpace;
12592- auto *AUO = A.getOrCreateAAFor <AAUnderlyingObjects>(getIRPosition (), this ,
12593- DepClassTy::REQUIRED);
12594- auto Pred = [&](Value &Obj) {
12592+
12593+ auto CheckAddressSpace = [&](Value &Obj) {
1259512594 if (isa<UndefValue>(&Obj))
1259612595 return true ;
12596+ // If an argument in flat address space has addrspace cast uses, and those
12597+ // casts are same, then we take the dst addrspace.
12598+ if (auto *Arg = dyn_cast<Argument>(&Obj)) {
12599+ unsigned FlatAS =
12600+ A.getInfoCache ().getFlatAddressSpace (Arg->getParent ());
12601+ if (FlatAS != InvalidAddressSpace &&
12602+ Arg->getType ()->getPointerAddressSpace () == FlatAS) {
12603+ unsigned CastAddrSpace = FlatAS;
12604+ for (auto *U : Arg->users ()) {
12605+ auto *ASCI = dyn_cast<AddrSpaceCastInst>(U);
12606+ if (!ASCI)
12607+ continue ;
12608+ if (CastAddrSpace != FlatAS &&
12609+ CastAddrSpace != ASCI->getDestAddressSpace ())
12610+ return false ;
12611+ CastAddrSpace = ASCI->getDestAddressSpace ();
12612+ }
12613+ if (CastAddrSpace != FlatAS)
12614+ return takeAddressSpace (CastAddrSpace);
12615+ }
12616+ }
1259712617 return takeAddressSpace (Obj.getType ()->getPointerAddressSpace ());
1259812618 };
1259912619
12600- if (!AUO->forallUnderlyingObjects (Pred))
12620+ auto *AUO = A.getOrCreateAAFor <AAUnderlyingObjects>(getIRPosition (), this ,
12621+ DepClassTy::REQUIRED);
12622+ if (!AUO->forallUnderlyingObjects (CheckAddressSpace))
1260112623 return indicatePessimisticFixpoint ();
1260212624
1260312625 return OldAddressSpace == AssumedAddressSpace ? ChangeStatus::UNCHANGED
@@ -12606,17 +12628,23 @@ struct AAAddressSpaceImpl : public AAAddressSpace {
1260612628
1260712629 // / See AbstractAttribute::manifest(...).
1260812630 ChangeStatus manifest (Attributor &A) override {
12609- if (getAddressSpace () == InvalidAddressSpace ||
12610- getAddressSpace () == getAssociatedType ()->getPointerAddressSpace ())
12631+ unsigned NewAS = getAddressSpace ();
12632+
12633+ if (NewAS == InvalidAddressSpace ||
12634+ NewAS == getAssociatedType ()->getPointerAddressSpace ())
1261112635 return ChangeStatus::UNCHANGED;
1261212636
12637+ unsigned FlatAS =
12638+ A.getInfoCache ().getFlatAddressSpace (getAssociatedFunction ());
12639+ assert (FlatAS != InvalidAddressSpace);
12640+
1261312641 Value *AssociatedValue = &getAssociatedValue ();
12614- Value *OriginalValue = peelAddrspacecast (AssociatedValue);
12642+ Value *OriginalValue = peelAddrspacecast (AssociatedValue, FlatAS );
1261512643
1261612644 PointerType *NewPtrTy =
12617- PointerType::get (getAssociatedType ()->getContext (), getAddressSpace () );
12645+ PointerType::get (getAssociatedType ()->getContext (), NewAS );
1261812646 bool UseOriginalValue =
12619- OriginalValue->getType ()->getPointerAddressSpace () == getAddressSpace () ;
12647+ OriginalValue->getType ()->getPointerAddressSpace () == NewAS ;
1262012648
1262112649 bool Changed = false ;
1262212650
@@ -12676,12 +12704,19 @@ struct AAAddressSpaceImpl : public AAAddressSpace {
1267612704 return AssumedAddressSpace == AS;
1267712705 }
1267812706
12679- static Value *peelAddrspacecast (Value *V) {
12680- if (auto *I = dyn_cast<AddrSpaceCastInst>(V))
12681- return peelAddrspacecast (I->getPointerOperand ());
12707+ static Value *peelAddrspacecast (Value *V, unsigned FlatAS) {
12708+ if (auto *I = dyn_cast<AddrSpaceCastInst>(V)) {
12709+ assert (I->getSrcAddressSpace () != FlatAS &&
12710+ " there should not be flat AS -> non-flat AS" );
12711+ return I->getPointerOperand ();
12712+ }
1268212713 if (auto *C = dyn_cast<ConstantExpr>(V))
12683- if (C->getOpcode () == Instruction::AddrSpaceCast)
12684- return peelAddrspacecast (C->getOperand (0 ));
12714+ if (C->getOpcode () == Instruction::AddrSpaceCast) {
12715+ assert (C->getOperand (0 )->getType ()->getPointerAddressSpace () !=
12716+ FlatAS &&
12717+ " there should not be flat AS -> non-flat AS X" );
12718+ return C->getOperand (0 );
12719+ }
1268512720 return V;
1268612721 }
1268712722};
0 commit comments