Skip to content

Commit 07d74a9

Browse files
ekochetkigcbot
authored andcommitted
Enable GAS resolution for arguments
When we can prove that pointer coming from kernel function arguments is not changed inside the kernel, we assume that it points to global address space
1 parent 78215b6 commit 07d74a9

File tree

1 file changed

+183
-0
lines changed

1 file changed

+183
-0
lines changed

IGC/Compiler/CISACodeGen/ResolveGAS.cpp

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
3535
#include "llvm/Support/Debug.h"
3636
#include "llvmWrapper/IR/Constant.h"
3737
#include <llvm/ADT/DenseSet.h>
38+
#include <llvm/ADT/SmallVector.h>
3839
#include <llvm/ADT/PostOrderIterator.h>
3940
#include <llvm/Analysis/LoopInfo.h>
4041
#include <llvm/IR/IRBuilder.h>
@@ -79,6 +80,8 @@ namespace {
7980
void getAnalysisUsage(AnalysisUsage& AU) const override {
8081
AU.setPreservesCFG();
8182
AU.addRequired<LoopInfoWrapperPass>();
83+
AU.addRequired<AAResultsWrapperPass>();
84+
AU.addRequired<MetaDataUtilsWrapper>();
8285
}
8386

8487
bool isResolvableLoopPHI(PHINode* PN) const {
@@ -89,11 +92,17 @@ namespace {
8992
bool resolveOnFunction(Function*) const;
9093
bool resolveOnBasicBlock(BasicBlock*) const;
9194

95+
bool resolveMemoryFromHost(Function&) const;
96+
9297
void populateResolvableLoopPHIs();
9398
void populateResolvableLoopPHIsForLoop(const Loop*);
9499

95100
bool isAddrSpaceResolvable(PHINode* PN, const Loop* L,
96101
BasicBlock* BackEdge) const;
102+
103+
bool checkGenericArguments(Function& F) const;
104+
void convertLoadToGlobal(LoadInst* LI) const;
105+
bool isLoadGlobalCandidate(LoadInst* LI) const;
97106
};
98107

99108
class GASPropagator : public InstVisitor<GASPropagator, bool> {
@@ -148,6 +157,8 @@ namespace IGC {
148157
IGC_INITIALIZE_PASS_BEGIN(GASResolving, PASS_FLAG, PASS_DESC, PASS_CFG_ONLY,
149158
PASS_ANALYSIS)
150159
IGC_INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
160+
IGC_INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
161+
IGC_INITIALIZE_PASS_DEPENDENCY(MetaDataUtilsWrapper)
151162
IGC_INITIALIZE_PASS_END(GASResolving, PASS_FLAG, PASS_DESC, PASS_CFG_ONLY,
152163
PASS_ANALYSIS)
153164
}
@@ -158,6 +169,8 @@ bool GASResolving::runOnFunction(Function& F) {
158169
IRB = &TheBuilder;
159170
Propagator = &ThePropagator;
160171

172+
resolveMemoryFromHost(F);
173+
161174
populateResolvableLoopPHIs();
162175

163176
bool Changed = false;
@@ -676,3 +689,173 @@ bool GASPropagator::visitCallInst(CallInst& I) {
676689

677690
return false;
678691
}
692+
693+
bool GASResolving::resolveMemoryFromHost(Function& F) const {
694+
MetaDataUtils* pMdUtils = getAnalysis<MetaDataUtilsWrapper>().getMetaDataUtils();
695+
696+
// skip all non-entry functions
697+
if (!isEntryFunc(pMdUtils, &F))
698+
return false;
699+
700+
// early check in order not to iterate whole function
701+
if (!checkGenericArguments(F))
702+
return false;
703+
704+
SmallVector<StoreInst*, 32> Stores;
705+
SmallVector<LoadInst*, 32> Loads;
706+
AliasAnalysis* AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
707+
708+
// collect load candidates and in parallel check for unsafe instructions
709+
// visitor may be a more beautiful way to do this
710+
bool HasASCast = false; // if there exists addrspace cast from non global/generic space
711+
bool HasPtoi = false; // if there exists ptrtoint with global/generic space
712+
for (BasicBlock& B : F) {
713+
for (Instruction& I : B) {
714+
if (auto LI = dyn_cast<LoadInst>(&I)) {
715+
if (isLoadGlobalCandidate(LI)) {
716+
Loads.push_back(LI);
717+
}
718+
}
719+
else if (auto CI = dyn_cast<CallInst>(&I)) {
720+
if (CI->onlyReadsMemory())
721+
continue;
722+
723+
// currently recognize only these ones
724+
// in fact intrinsics should be marked as read-only
725+
if (auto II = dyn_cast<IntrinsicInst>(CI)) {
726+
if (II->getIntrinsicID() == Intrinsic::lifetime_start ||
727+
II->getIntrinsicID() == Intrinsic::lifetime_end)
728+
continue;
729+
}
730+
731+
// if we have an unsafe call in the kernel, abort
732+
// to improve we can collect arguments of writing calls as memlocations for alias analysis
733+
return false;
734+
}
735+
else if (auto PI = dyn_cast<PtrToIntInst>(&I)) {
736+
// if we have a ptrtoint we need to check data flow which we don't want to
737+
if (PI->getPointerAddressSpace() != ADDRESS_SPACE_GLOBAL &&
738+
PI->getPointerAddressSpace() != ADDRESS_SPACE_GENERIC)
739+
return false;
740+
else {
741+
HasPtoi = true;
742+
}
743+
744+
return false;
745+
}
746+
else if (auto AI = dyn_cast<AddrSpaceCastInst>(&I)) {
747+
if (AI->getSrcAddressSpace() != ADDRESS_SPACE_GLOBAL &&
748+
AI->getSrcAddressSpace() != ADDRESS_SPACE_GENERIC) {
749+
HasASCast = true;
750+
}
751+
}
752+
else if (auto SI = dyn_cast<StoreInst>(&I)) {
753+
Value* V = SI->getValueOperand();
754+
if (isa<PointerType>(V->getType())) {
755+
// this store can potentially write non-global pointer to memory
756+
Stores.push_back(SI);
757+
}
758+
}
759+
else if (I.mayWriteToMemory()) {
760+
// unsupported instruction poisoning memory
761+
return false;
762+
}
763+
}
764+
}
765+
if (HasASCast && HasPtoi)
766+
return false;
767+
768+
if (Loads.empty())
769+
return false;
770+
771+
bool Changed = false;
772+
while (!Loads.empty())
773+
{
774+
LoadInst* LI = Loads.pop_back_val();
775+
776+
// check that we don't have aliasing stores for this load
777+
// we expect to have basic and addrspace AA available at the moment
778+
// on optimization phase
779+
bool aliases = false;
780+
for (auto SI : Stores) {
781+
if (AA->alias(MemoryLocation::get(SI), MemoryLocation::get(LI))) {
782+
aliases = true;
783+
break;
784+
}
785+
}
786+
if (aliases)
787+
continue;
788+
789+
convertLoadToGlobal(LI);
790+
Changed = true;
791+
}
792+
return Changed;
793+
}
794+
795+
bool GASResolving::isLoadGlobalCandidate(LoadInst* LI) const {
796+
// first check that loaded address has generic address space
797+
// otherwise it is not our candidate
798+
PointerType* PtrTy = dyn_cast<PointerType>(LI->getType());
799+
if (!PtrTy || PtrTy->getAddressSpace() != ADDRESS_SPACE_GENERIC)
800+
return false;
801+
802+
// next check that it is a load from function argument + offset
803+
// which is necessary to prove that this address has global addrspace
804+
Value* LoadBase = LI->getPointerOperand()->stripInBoundsOffsets();
805+
if (!isa<Argument>(LoadBase))
806+
return false;
807+
808+
// don't want to process cases when argument is from local address space
809+
auto LoadTy = cast<PointerType>(LoadBase->getType());
810+
if (LoadTy->getAddressSpace() != ADDRESS_SPACE_GLOBAL)
811+
return false;
812+
813+
// TODO: skip cases that have been fixed on previous traversals
814+
815+
return true;
816+
}
817+
818+
void GASResolving::convertLoadToGlobal(LoadInst* LI) const {
819+
// create two addressspace casts: generic -> global -> generic
820+
// the next scalar phase of this pass will propagate global to all uses of the load
821+
822+
PointerType* PtrTy = cast<PointerType>(LI->getType());
823+
IRB->SetInsertPoint(LI->getNextNode());
824+
PointerType* GlobalPtrTy = PointerType::get(PtrTy->getElementType(), ADDRESS_SPACE_GLOBAL);
825+
Value* GlobalAddr = IRB->CreateAddrSpaceCast(LI, GlobalPtrTy);
826+
Value* GenericCopyAddr = IRB->CreateAddrSpaceCast(GlobalAddr, PtrTy);
827+
828+
for (auto UI = LI->use_begin(), UE = LI->use_end(); UI != UE; /*EMPTY*/) {
829+
Use& U = *UI++;
830+
if (U.getUser() == GlobalAddr)
831+
continue;
832+
U.set(GenericCopyAddr);
833+
}
834+
}
835+
836+
bool GASResolving::checkGenericArguments(Function& F) const {
837+
// check that we have a pointer to pointer or pointer to struct that has pointer elements
838+
// and main pointer type is global while underlying pointer type is generic
839+
840+
auto* FT = F.getFunctionType();
841+
for (unsigned p = 0; p < FT->getNumParams(); ++p) {
842+
if (auto Ty = dyn_cast<PointerType>(FT->getParamType(p))) {
843+
if (Ty->getAddressSpace() != ADDRESS_SPACE_GLOBAL)
844+
continue;
845+
auto PteeTy = Ty->getElementType();
846+
if (auto PTy = dyn_cast<PointerType>(PteeTy)) {
847+
if (PTy->getAddressSpace() == ADDRESS_SPACE_GENERIC)
848+
return true;
849+
}
850+
if (auto STy = dyn_cast<StructType>(PteeTy)) {
851+
for (unsigned e = 0; e < STy->getNumElements(); ++e) {
852+
if (auto ETy = dyn_cast<PointerType>(STy->getElementType(e))) {
853+
if (ETy->getAddressSpace() == ADDRESS_SPACE_GENERIC)
854+
return true;
855+
}
856+
}
857+
}
858+
}
859+
}
860+
return false;
861+
}

0 commit comments

Comments
 (0)