@@ -12583,16 +12583,36 @@ struct AAAddressSpaceImpl : public AAAddressSpace {
1258312583 }
1258412584
1258512585 ChangeStatus updateImpl (Attributor &A) override {
12586+ unsigned FlatAS = A.getInfoCache ().getFlatAddressSpace ().value ();
1258612587 uint32_t OldAddressSpace = AssumedAddressSpace;
12587- auto *AUO = A.getOrCreateAAFor <AAUnderlyingObjects>(getIRPosition (), this ,
12588- DepClassTy::REQUIRED);
12589- auto Pred = [&](Value &Obj) {
12588+
12589+ auto CheckAddressSpace = [&](Value &Obj) {
1259012590 if (isa<UndefValue>(&Obj))
1259112591 return true ;
12592+ // If an argument in flat address space only has addrspace cast uses, and
12593+ // those casts are same, then we take the dst addrspace.
12594+ if (auto *Arg = dyn_cast<Argument>(&Obj)) {
12595+ if (Arg->getType ()->getPointerAddressSpace () == FlatAS) {
12596+ unsigned CastAddrSpace = FlatAS;
12597+ for (auto *U : Arg->users ()) {
12598+ auto *ASCI = dyn_cast<AddrSpaceCastInst>(U);
12599+ if (!ASCI)
12600+ return takeAddressSpace (Obj.getType ()->getPointerAddressSpace ());
12601+ if (CastAddrSpace != FlatAS &&
12602+ CastAddrSpace != ASCI->getDestAddressSpace ())
12603+ return false ;
12604+ CastAddrSpace = ASCI->getDestAddressSpace ();
12605+ }
12606+ if (CastAddrSpace != FlatAS)
12607+ return takeAddressSpace (CastAddrSpace);
12608+ }
12609+ }
1259212610 return takeAddressSpace (Obj.getType ()->getPointerAddressSpace ());
1259312611 };
1259412612
12595- if (!AUO->forallUnderlyingObjects (Pred))
12613+ auto *AUO = A.getOrCreateAAFor <AAUnderlyingObjects>(getIRPosition (), this ,
12614+ DepClassTy::REQUIRED);
12615+ if (!AUO->forallUnderlyingObjects (CheckAddressSpace))
1259612616 return indicatePessimisticFixpoint ();
1259712617
1259812618 return OldAddressSpace == AssumedAddressSpace ? ChangeStatus::UNCHANGED
@@ -12601,17 +12621,21 @@ struct AAAddressSpaceImpl : public AAAddressSpace {
1260112621
1260212622 // / See AbstractAttribute::manifest(...).
1260312623 ChangeStatus manifest (Attributor &A) override {
12604- if (getAddressSpace () == InvalidAddressSpace ||
12605- getAddressSpace () == getAssociatedType ()->getPointerAddressSpace ())
12624+ unsigned NewAS = getAddressSpace ();
12625+
12626+ if (NewAS == InvalidAddressSpace ||
12627+ NewAS == getAssociatedType ()->getPointerAddressSpace ())
1260612628 return ChangeStatus::UNCHANGED;
1260712629
12630+ unsigned FlatAS = A.getInfoCache ().getFlatAddressSpace ().value ();
12631+
1260812632 Value *AssociatedValue = &getAssociatedValue ();
12609- Value *OriginalValue = peelAddrspacecast (AssociatedValue);
12633+ Value *OriginalValue = peelAddrspacecast (AssociatedValue, FlatAS );
1261012634
1261112635 PointerType *NewPtrTy =
12612- PointerType::get (getAssociatedType ()->getContext (), getAddressSpace () );
12636+ PointerType::get (getAssociatedType ()->getContext (), NewAS );
1261312637 bool UseOriginalValue =
12614- OriginalValue->getType ()->getPointerAddressSpace () == getAddressSpace () ;
12638+ OriginalValue->getType ()->getPointerAddressSpace () == NewAS ;
1261512639
1261612640 bool Changed = false ;
1261712641
@@ -12671,12 +12695,19 @@ struct AAAddressSpaceImpl : public AAAddressSpace {
1267112695 return AssumedAddressSpace == AS;
1267212696 }
1267312697
12674- static Value *peelAddrspacecast (Value *V) {
12675- if (auto *I = dyn_cast<AddrSpaceCastInst>(V))
12676- return peelAddrspacecast (I->getPointerOperand ());
12698+ static Value *peelAddrspacecast (Value *V, unsigned FlatAS) {
12699+ if (auto *I = dyn_cast<AddrSpaceCastInst>(V)) {
12700+ assert (I->getSrcAddressSpace () != FlatAS &&
12701+ " there should not be flat AS -> non-flat AS" );
12702+ return I->getPointerOperand ();
12703+ }
1267712704 if (auto *C = dyn_cast<ConstantExpr>(V))
12678- if (C->getOpcode () == Instruction::AddrSpaceCast)
12679- return peelAddrspacecast (C->getOperand (0 ));
12705+ if (C->getOpcode () == Instruction::AddrSpaceCast) {
12706+ assert (C->getOperand (0 )->getType ()->getPointerAddressSpace () !=
12707+ FlatAS &&
12708+ " there should not be flat AS -> non-flat AS X" );
12709+ return C->getOperand (0 );
12710+ }
1268012711 return V;
1268112712 }
1268212713};
0 commit comments