@@ -12596,16 +12596,37 @@ struct AAAddressSpaceImpl : public AAAddressSpace {
1259612596 }
1259712597
1259812598 ChangeStatus updateImpl (Attributor &A) override {
12599+ assert (A.getInfoCache ().getFlatAddressSpace ().has_value ());
12600+ unsigned FlatAS = A.getInfoCache ().getFlatAddressSpace ().value ();
1259912601 uint32_t OldAddressSpace = AssumedAddressSpace;
12600- auto *AUO = A.getOrCreateAAFor <AAUnderlyingObjects>(getIRPosition (), this ,
12601- DepClassTy::REQUIRED);
12602- auto Pred = [&](Value &Obj) {
12602+
12603+ auto CheckAddressSpace = [&](Value &Obj) {
1260312604 if (isa<UndefValue>(&Obj))
1260412605 return true ;
12606+ // If an argument in flat address space only has addrspace cast uses, and
12607+ // those casts are same, then we take the dst addrspace.
12608+ if (auto *Arg = dyn_cast<Argument>(&Obj)) {
12609+ if (Arg->getType ()->getPointerAddressSpace () == FlatAS) {
12610+ unsigned CastAddrSpace = FlatAS;
12611+ for (auto *U : Arg->users ()) {
12612+ auto *ASCI = dyn_cast<AddrSpaceCastInst>(U);
12613+ if (!ASCI)
12614+ return takeAddressSpace (Obj.getType ()->getPointerAddressSpace ());
12615+ if (CastAddrSpace != FlatAS &&
12616+ CastAddrSpace != ASCI->getDestAddressSpace ())
12617+ return false ;
12618+ CastAddrSpace = ASCI->getDestAddressSpace ();
12619+ }
12620+ if (CastAddrSpace != FlatAS)
12621+ return takeAddressSpace (CastAddrSpace);
12622+ }
12623+ }
1260512624 return takeAddressSpace (Obj.getType ()->getPointerAddressSpace ());
1260612625 };
1260712626
12608- if (!AUO->forallUnderlyingObjects (Pred))
12627+ auto *AUO = A.getOrCreateAAFor <AAUnderlyingObjects>(getIRPosition (), this ,
12628+ DepClassTy::REQUIRED);
12629+ if (!AUO->forallUnderlyingObjects (CheckAddressSpace))
1260912630 return indicatePessimisticFixpoint ();
1261012631
1261112632 return OldAddressSpace == AssumedAddressSpace ? ChangeStatus::UNCHANGED
@@ -12614,17 +12635,21 @@ struct AAAddressSpaceImpl : public AAAddressSpace {
1261412635
1261512636 // / See AbstractAttribute::manifest(...).
1261612637 ChangeStatus manifest (Attributor &A) override {
12617- if (getAddressSpace () == InvalidAddressSpace ||
12618- getAddressSpace () == getAssociatedType ()->getPointerAddressSpace ())
12638+ unsigned NewAS = getAddressSpace ();
12639+
12640+ if (NewAS == InvalidAddressSpace ||
12641+ NewAS == getAssociatedType ()->getPointerAddressSpace ())
1261912642 return ChangeStatus::UNCHANGED;
1262012643
12644+ unsigned FlatAS = A.getInfoCache ().getFlatAddressSpace ().value ();
12645+
1262112646 Value *AssociatedValue = &getAssociatedValue ();
12622- Value *OriginalValue = peelAddrspacecast (AssociatedValue);
12647+ Value *OriginalValue = peelAddrspacecast (AssociatedValue, FlatAS );
1262312648
1262412649 PointerType *NewPtrTy =
12625- PointerType::get (getAssociatedType ()->getContext (), getAddressSpace () );
12650+ PointerType::get (getAssociatedType ()->getContext (), NewAS );
1262612651 bool UseOriginalValue =
12627- OriginalValue->getType ()->getPointerAddressSpace () == getAddressSpace () ;
12652+ OriginalValue->getType ()->getPointerAddressSpace () == NewAS ;
1262812653
1262912654 bool Changed = false ;
1263012655
@@ -12684,12 +12709,19 @@ struct AAAddressSpaceImpl : public AAAddressSpace {
1268412709 return AssumedAddressSpace == AS;
1268512710 }
1268612711
12687- static Value *peelAddrspacecast (Value *V) {
12688- if (auto *I = dyn_cast<AddrSpaceCastInst>(V))
12689- return peelAddrspacecast (I->getPointerOperand ());
12712+ static Value *peelAddrspacecast (Value *V, unsigned FlatAS) {
12713+ if (auto *I = dyn_cast<AddrSpaceCastInst>(V)) {
12714+ assert (I->getSrcAddressSpace () != FlatAS &&
12715+ " there should not be flat AS -> non-flat AS" );
12716+ return I->getPointerOperand ();
12717+ }
1269012718 if (auto *C = dyn_cast<ConstantExpr>(V))
12691- if (C->getOpcode () == Instruction::AddrSpaceCast)
12692- return peelAddrspacecast (C->getOperand (0 ));
12719+ if (C->getOpcode () == Instruction::AddrSpaceCast) {
12720+ assert (C->getOperand (0 )->getType ()->getPointerAddressSpace () !=
12721+ FlatAS &&
12722+ " there should not be flat AS -> non-flat AS X" );
12723+ return C->getOperand (0 );
12724+ }
1269312725 return V;
1269412726 }
1269512727};
0 commit comments