@@ -112,10 +112,7 @@ static bool isPtrArgModified(const Value &Arg) {
112112}
113113
114114// Check if it is safe to pass structure by value.
115- static bool structSafeToPassByVal (const Argument &Arg) {
116- StructType *StrTy =
117- cast<StructType>(IGCLLVM::getNonOpaquePtrEltTy (Arg.getType ()));
118-
115+ static bool structSafeToPassByVal (const Argument &Arg, StructType *StrTy) {
119116 if (!containsOnlySuitableTypes (*StrTy))
120117 return false ;
121118
@@ -149,13 +146,99 @@ static bool structSafeToPassByVal(const Argument &Arg) {
149146 return llvm::all_of (Arg.users (), UserChecker) && !isPtrArgModified (Arg);
150147}
151148
149+ static Type *getPtrArgElementType (const Argument &PtrArg) {
150+ auto *PtrArgTy = cast<PointerType>(PtrArg.getType ());
151+ if (!PtrArgTy->isOpaque ())
152+ return IGCLLVM::getNonOpaquePtrEltTy (PtrArgTy);
153+ if (auto *ByValTy = PtrArg.getParamByValType ())
154+ return ByValTy;
155+ if (auto *StructRetTy = PtrArg.getParamStructRetType ())
156+ return StructRetTy;
157+ SmallPtrSet<Type *, 2 > ElemTys;
158+ for (auto *U : PtrArg.users ()) {
159+ if (ElemTys.size () > 1 )
160+ return nullptr ;
161+ auto *I = dyn_cast<Instruction>(U);
162+ if (!I)
163+ continue ;
164+ if (auto *LI = dyn_cast<LoadInst>(I)) {
165+ if (&PtrArg == LI->getPointerOperand ())
166+ ElemTys.insert (LI->getType ());
167+ } else if (auto *SI = dyn_cast<StoreInst>(I)) {
168+ if (&PtrArg == SI->getPointerOperand ())
169+ ElemTys.insert (SI->getValueOperand ()->getType ());
170+ } else if (auto *GEPI = dyn_cast<GetElementPtrInst>(I)) {
171+ if (&PtrArg == GEPI->getPointerOperand ())
172+ ElemTys.insert (GEPI->getSourceElementType ());
173+ } else if (auto *PTII = dyn_cast<PtrToIntInst>(I)) {
174+ const Value *Addr = PTII;
175+ for (auto *AddrUser : Addr->users ()) {
176+ switch (GenXIntrinsic::getAnyIntrinsicID (AddrUser)) {
177+ case GenXIntrinsic::genx_gather_scaled:
178+ if (Addr != AddrUser->getOperand (4 ))
179+ continue ;
180+ ElemTys.insert (AddrUser->getType ());
181+ break ;
182+ case GenXIntrinsic::genx_scatter_scaled:
183+ if (Addr != AddrUser->getOperand (4 ))
184+ continue ;
185+ ElemTys.insert (AddrUser->getOperand (6 )->getType ());
186+ break ;
187+ case GenXIntrinsic::genx_svm_block_ld:
188+ case GenXIntrinsic::genx_svm_block_ld_unaligned:
189+ if (Addr != AddrUser->getOperand (0 ))
190+ continue ;
191+ ElemTys.insert (AddrUser->getType ());
192+ break ;
193+ case GenXIntrinsic::genx_svm_block_st:
194+ if (Addr != AddrUser->getOperand (0 ))
195+ continue ;
196+ ElemTys.insert (AddrUser->getOperand (1 )->getType ());
197+ break ;
198+ default :
199+ break ;
200+ }
201+ }
202+ if (!PTII->hasOneUse ())
203+ continue ;
204+ auto *IEI = dyn_cast<InsertElementInst>(PTII->user_back ());
205+ if (!IEI)
206+ continue ;
207+ const Value *AddrVec = IEI;
208+ if (IEI->hasOneUse ())
209+ if (auto *SVI = dyn_cast<ShuffleVectorInst>(IEI->user_back ()))
210+ if (SVI->hasOneUse ())
211+ if (auto *BO = dyn_cast<BinaryOperator>(SVI->user_back ()))
212+ AddrVec = BO;
213+ for (auto *AddrVecUser : AddrVec->users ()) {
214+ switch (GenXIntrinsic::getAnyIntrinsicID (AddrVecUser)) {
215+ case GenXIntrinsic::genx_svm_gather:
216+ if (AddrVec != AddrVecUser->getOperand (2 ))
217+ continue ;
218+ ElemTys.insert (AddrVecUser->getType ());
219+ break ;
220+ case GenXIntrinsic::genx_svm_scatter:
221+ if (AddrVec != AddrVecUser->getOperand (2 ))
222+ continue ;
223+ ElemTys.insert (AddrVecUser->getOperand (3 )->getType ());
224+ break ;
225+ default :
226+ break ;
227+ }
228+ }
229+ }
230+ }
231+ return ElemTys.empty () ? nullptr : *ElemTys.begin ();
232+ }
233+
152234// Check if argument should be transformed.
153- static bool argToTransform (const Argument &Arg,
154- vc::TypeSizeWrapper MaxStructSize) {
155- auto *PtrTy = dyn_cast<PointerType>(Arg.getType ());
156- if (!PtrTy)
157- return false ;
158- Type *ElemTy = IGCLLVM::getNonOpaquePtrEltTy (PtrTy);
235+ static Type *argToTransform (const Argument &Arg,
236+ vc::TypeSizeWrapper MaxStructSize) {
237+ if (!isa<PointerType>(Arg.getType ()))
238+ return nullptr ;
239+ auto *ElemTy = getPtrArgElementType (Arg);
240+ if (!ElemTy)
241+ return nullptr ;
159242 if (ElemTy->isIntOrIntVectorTy () || ElemTy->isFPOrFPVectorTy ()) {
160243 if (ElemTy->isVectorTy ()) {
161244 for (auto *U : Arg.users ()) {
@@ -166,28 +249,28 @@ static bool argToTransform(const Argument &Arg,
166249 continue ;
167250 auto *ConstIdx = dyn_cast<ConstantInt>(*GEP->idx_begin ());
168251 if (!ConstIdx || ConstIdx->getZExtValue () != 0 )
169- return false ;
252+ return nullptr ;
170253 }
171- return true ;
172- }
173- return onlyUsedBySimpleValueLoadStore (Arg) ;
254+ } else if (! onlyUsedBySimpleValueLoadStore (Arg))
255+ return nullptr ;
256+ return ElemTy ;
174257 }
175258 if (auto *StrTy = dyn_cast<StructType>(ElemTy)) {
176259 const DataLayout &DL = Arg.getParent ()->getParent ()->getDataLayout ();
177- if (structSafeToPassByVal (Arg) &&
260+ if (structSafeToPassByVal (Arg, StrTy ) &&
178261 vc::getTypeSize (StrTy, &DL) <= MaxStructSize)
179- return true ;
262+ return ElemTy ;
180263 }
181- return false ;
264+ return nullptr ;
182265}
183266
184267// Collect arguments that should be transformed.
185- SmallPtrSet <Argument *, 8 >
268+ SmallDenseMap <Argument *, Type * >
186269vc::collectArgsToTransform (Function &F, vc::TypeSizeWrapper MaxStructSize) {
187- SmallPtrSet <Argument *, 8 > ArgsToTransform;
270+ SmallDenseMap <Argument *, Type * > ArgsToTransform;
188271 for (auto &Arg : F.args ())
189- if (argToTransform (Arg, MaxStructSize))
190- ArgsToTransform.insert (&Arg);
272+ if (auto *ArgElemTy = argToTransform (Arg, MaxStructSize))
273+ ArgsToTransform.insert (std::make_pair ( &Arg, ArgElemTy) );
191274 return ArgsToTransform;
192275}
193276
@@ -286,7 +369,7 @@ int vc::OrigArgInfo::getNewIdx() const {
286369}
287370
288371vc::TransformedFuncInfo::TransformedFuncInfo (
289- Function &OrigFunc, SmallPtrSetImpl <Argument *> &ArgsToTransform) {
372+ Function &OrigFunc, SmallDenseMap <Argument *, Type *> &ArgsToTransform) {
290373 fillOrigArgInfo (OrigFunc, ArgsToTransform);
291374 inheritAttributes (OrigFunc);
292375
@@ -335,7 +418,7 @@ void vc::TransformedFuncInfo::appendGlobals(
335418}
336419
337420void vc::TransformedFuncInfo::fillOrigArgInfo (
338- Function &OrigFunc, SmallPtrSetImpl <Argument *> &ArgsToTransform) {
421+ Function &OrigFunc, SmallDenseMap <Argument *, Type *> &ArgsToTransform) {
339422 IGC_ASSERT_MESSAGE (OrigArgs.empty (),
340423 " shouldn't be filled before this method" );
341424
@@ -358,7 +441,9 @@ void vc::TransformedFuncInfo::fillOrigArgInfo(
358441
359442 // Update type for transformed arguments.
360443 if (Kind != ArgKind::General) {
361- Ty = IGCLLVM::getNonOpaquePtrEltTy (Ty);
444+ auto It = ArgsToTransform.find (&Arg);
445+ IGC_ASSERT_EXIT (It != ArgsToTransform.end ());
446+ Ty = It->second ;
362447 }
363448
364449 if (Kind == ArgKind::CopyOut) {
@@ -489,8 +574,8 @@ getTransformedFuncCallArgs(CallInst &OrigCall,
489574 IGC_ASSERT_MESSAGE (Kind == ArgKind::CopyIn || Kind == ArgKind::CopyInOut,
490575 " unexpected arg kind" );
491576 LoadInst *Load =
492- new LoadInst (IGCLLVM::getNonOpaquePtrEltTy ( OrigArg.get ()-> getType () ),
493- OrigArg.get (), OrigArg. get () ->getName () + " .val" ,
577+ new LoadInst (OrigArgData. getTransformedOrigType (), OrigArg.get (),
578+ OrigArg.get ()->getName () + " .val" ,
494579 /* isVolatile */ false , &OrigCall);
495580 NewCallOps.push_back (Load);
496581 break ;
0 commit comments