@@ -139,11 +139,9 @@ inline Type* LegalizedIntVectorType(const Module& M, Type* ty)
139139}
140140
141141// Returns true for structs smaller than 'structSize' and only contains primitive types
142- inline bool isLegalStructType (const Module& M, Type* ty , unsigned structSize)
142+ inline bool isLegalStructType (const Module& M, StructType* sTy , unsigned structSize)
143143{
144- IGC_ASSERT (ty->isStructTy ());
145144 const DataLayout& DL = M.getDataLayout ();
146- StructType* sTy = dyn_cast<StructType>(ty);
147145 if (sTy && DL.getStructLayout (sTy )->getSizeInBits () <= structSize)
148146 {
149147 for (const auto * EltTy : sTy ->elements ())
@@ -163,7 +161,7 @@ inline bool isLegalSignatureType(const Module& M, Type* ty, bool isStackCall)
163161 {
164162 if (ty->isStructTy ())
165163 {
166- return isLegalStructType (M, ty , MAX_STRUCT_SIZE_IN_BITS);
164+ return isLegalStructType (M, cast<StructType>(ty) , MAX_STRUCT_SIZE_IN_BITS);
167165 }
168166 else if (ty->isArrayTy ())
169167 {
@@ -174,16 +172,14 @@ inline bool isLegalSignatureType(const Module& M, Type* ty, bool isStackCall)
174172 return true ;
175173}
176174
177- // Check if a struct pointer argument is promotable to pass-by-value
178- inline bool isPromotableStructType (const Module& M, const Type* ty, bool isStackCall, bool isReturnValue = false )
175+ inline bool isPromotableStructType (const Module& M, Type* pointeeType, bool isStackCall)
179176{
180177 if (IGC_IS_FLAG_DISABLED (EnableByValStructArgPromotion))
181178 return false ;
182-
183179 const unsigned int maxSize = isStackCall ? MAX_STRUCT_SIZE_IN_BITS : MAX_SUBROUTINE_STRUCT_SIZE_IN_BITS;
184- if (ty-> isPointerTy () && IGCLLVM::getNonOpaquePtrEltTy (ty)-> isStructTy ( ))
180+ if (isa<StructType>(pointeeType ))
185181 {
186- return isLegalStructType (M, IGCLLVM::getNonOpaquePtrEltTy (ty ), maxSize);
182+ return isLegalStructType (M, cast<StructType>(pointeeType ), maxSize);
187183 }
188184 return false ;
189185}
@@ -194,18 +190,29 @@ inline bool FunctionHasPromotableSRetArg(const Module& M, const Function* F)
194190 if (F->getReturnType ()->isVoidTy () &&
195191 !F->arg_empty () &&
196192 F->arg_begin ()->hasStructRetAttr () &&
197- isPromotableStructType (M, F->arg_begin ()->getType (), F->hasFnAttribute (" visaStackCall" ), true ))
193+ isPromotableStructType (M, F->arg_begin ()->getParamStructRetType (), F->hasFnAttribute (" visaStackCall" )))
198194 {
199195 return true ;
200196 }
201197 return false ;
202198}
203199
204200// Promotes struct pointer to struct type
205- inline Type * PromotedStructValueType (const Module& M, const Type* ty )
201+ inline StructType * PromotedStructValueType (const Module& M, const Argument* arg )
206202{
207- IGC_ASSERT (ty->isPointerTy () && IGCLLVM::getNonOpaquePtrEltTy (ty)->isStructTy ());
208- return cast<StructType>(IGCLLVM::getNonOpaquePtrEltTy (ty));
203+ if (arg->getType ()->isPointerTy ())
204+ {
205+ if (arg->hasStructRetAttr () && arg->getParamStructRetType ()->isStructTy ())
206+ {
207+ return cast<StructType>(arg->getParamStructRetType ());
208+ }
209+ else if (arg->hasByValAttr () && arg->getParamByValType ()->isStructTy ())
210+ {
211+ return cast<StructType>(arg->getParamByValType ());
212+ }
213+ }
214+ IGC_ASSERT_MESSAGE (0 , " Not implemented case" );
215+ return nullptr ;
209216}
210217
211218// BE does not handle struct load/store, so instead store each element of the struct value to the GEP of the struct pointer
@@ -218,7 +225,7 @@ inline void StoreToStruct(IGCLLVM::IRBuilder<>& builder, Value* strVal, Value* s
218225 for (unsigned i = 0 ; i < sTy ->getNumElements (); i++)
219226 {
220227 Value* indices[] = { builder.getInt32 (0 ), builder.getInt32 (i) };
221- Value* elementPtr = builder.CreateInBoundsGEP (strPtr, indices);
228+ Value* elementPtr = builder.CreateInBoundsGEP (strVal-> getType (), strPtr, indices);
222229 Value* element = builder.CreateExtractValue (strVal, i);
223230 builder.CreateStore (element, elementPtr);
224231 }
@@ -235,7 +242,7 @@ inline Value* LoadFromStruct(IGCLLVM::IRBuilder<>& builder, Value* strPtr, Type*
235242 for (unsigned i = 0 ; i < sTy ->getNumElements (); i++)
236243 {
237244 Value* indices[] = { builder.getInt32 (0 ), builder.getInt32 (i) };
238- Value* elementPtr = builder.CreateInBoundsGEP (strPtr, indices);
245+ Value* elementPtr = builder.CreateInBoundsGEP (ty, strPtr, indices);
239246 Value* element = builder.CreateLoad (sTy ->getElementType (i), elementPtr);
240247 strVal = builder.CreateInsertValue (strVal, element, i);
241248 }
@@ -308,10 +315,10 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
308315 argTypes.push_back (LegalizedIntVectorType (M, ai->getType ()));
309316 }
310317 else if (ai->hasByValAttr () &&
311- isPromotableStructType (M, ai->getType (), isStackCall))
318+ isPromotableStructType (M, ai->getParamByValType (), isStackCall))
312319 {
313320 fixArgType = true ;
314- argTypes.push_back (PromotedStructValueType (M, ai-> getType () ));
321+ argTypes.push_back (PromotedStructValueType (M, ai));
315322 }
316323 else if (!isLegalSignatureType (M, ai->getType (), isStackCall))
317324 {
@@ -329,7 +336,7 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
329336 // Clone function with new signature
330337 Type* returnType =
331338 retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy (M.getContext ()) :
332- retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, pFunc->arg_begin ()-> getType () ) :
339+ retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, pFunc->arg_begin ()) :
333340 retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType (M, pFunc->getReturnType ()) :
334341 pFunc->getReturnType ();
335342 FunctionType* signature = FunctionType::get (returnType, argTypes, false );
@@ -393,13 +400,12 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
393400 if (OldArgIt == pFunc->arg_begin () && retTypeOption == ReturnOpt::RETURN_STRUCT)
394401 {
395402 // Create a temp alloca to map the old argument. This will be removed later by SROA.
396- tempAllocaForSRetPointerTy = PromotedStructValueType (M, OldArgIt-> getType () );
403+ tempAllocaForSRetPointerTy = PromotedStructValueType (M, OldArgIt);
397404 tempAllocaForSRetPointer = builder.CreateAlloca (tempAllocaForSRetPointerTy);
398405 tempAllocaForSRetPointer = builder.CreateAddrSpaceCast (tempAllocaForSRetPointer, OldArgIt->getType ());
399406 VMap[&*OldArgIt] = tempAllocaForSRetPointer;
400407 continue ;
401408 }
402-
403409 NewArgIt->setName (OldArgIt->getName ());
404410 if (!isLegalIntVectorType (M, OldArgIt->getType ()))
405411 {
@@ -408,24 +414,25 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
408414 VMap[&*OldArgIt] = trunc;
409415 }
410416 else if (OldArgIt->hasByValAttr () &&
411- isPromotableStructType (M, OldArgIt->getType (), isStackCall))
417+ isPromotableStructType (M, OldArgIt->getParamByValType (), isStackCall))
412418 {
419+ AllocaInst* newArgPtr = builder.CreateAlloca (OldArgIt->getParamByValType ());
413420 // remove "byval" attrib since it is now pass-by-value
414421 NewArgIt->removeAttr (llvm::Attribute::ByVal);
415- Value* newArgPtr = builder.CreateAlloca (NewArgIt->getType ());
416422 StoreToStruct (builder, &*NewArgIt, newArgPtr);
417423 // cast back to original addrspace
418424 IGC_ASSERT (OldArgIt->getType ()->getPointerAddressSpace () == ADDRESS_SPACE_GENERIC ||
419- OldArgIt->getType ()->getPointerAddressSpace () == ADDRESS_SPACE_PRIVATE);
420- newArgPtr = builder.CreateAddrSpaceCast (newArgPtr, OldArgIt->getType ());
421- VMap[&*OldArgIt] = newArgPtr ;
425+ OldArgIt->getType ()->getPointerAddressSpace () == ADDRESS_SPACE_PRIVATE);
426+ llvm::Value* castedNewArgPtr = builder.CreateAddrSpaceCast (newArgPtr, OldArgIt->getType ());
427+ VMap[&*OldArgIt] = castedNewArgPtr ;
422428 }
423429 else if (!isLegalSignatureType (M, OldArgIt->getType (), isStackCall))
424430 {
425431 // Load from pointer arg
426- Value* load = builder.CreateLoad (&*NewArgIt);
432+ Value* load = builder.CreateLoad (OldArgIt-> getType (), &*NewArgIt);
427433 VMap[&*OldArgIt] = load;
428- ArgByVal.push_back (&*NewArgIt);
434+ llvm::Attribute byValAttr = llvm::Attribute::getWithByValType (M.getContext (), OldArgIt->getType ());
435+ NewArgIt->addAttr (byValAttr);
429436 }
430437 else
431438 {
@@ -444,21 +451,13 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
444451 builder.CreateBr (ClonedEntryBB);
445452 MergeBlockIntoPredecessor (ClonedEntryBB);
446453
447- // Loop through new args and add 'byval' attributes
448- for (auto arg : ArgByVal)
449- {
450- arg->addAttr (llvm::Attribute::getWithByValType (M.getContext (),
451- IGCLLVM::getNonOpaquePtrEltTy (arg->getType ())));
452- }
453-
454454 // Now fix the return values
455455 if (retTypeOption == ReturnOpt::RETURN_BY_REF)
456456 {
457457 // Add the 'noalias' and 'sret' attribute to arg0
458458 auto retArg = pNewFunc->arg_begin ();
459459 retArg->addAttr (llvm::Attribute::NoAlias);
460- retArg->addAttr (llvm::Attribute::getWithStructRetType (
461- M.getContext (), IGCLLVM::getNonOpaquePtrEltTy (retArg->getType ())));
460+ retArg->addAttr (llvm::Attribute::getWithStructRetType (M.getContext (), pFunc->getReturnType ()));
462461
463462 // Loop through all return instructions and store the old return value into the arg0 pointer
464463 const auto ptrSize = DL.getPointerSize ();
@@ -577,7 +576,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
577576 if (callInst->getType ()->isVoidTy () &&
578577 IGCLLVM::getNumArgOperands (callInst) > 0 &&
579578 callInst->paramHasAttr (0 , llvm::Attribute::StructRet) &&
580- isPromotableStructType (M, callInst->getArgOperand ( 0 )-> getType (), isStackCall, true /* retval */ ))
579+ isPromotableStructType (M, callInst->getParamAttr ( 0 , llvm::Attribute::StructRet). getValueAsType (), isStackCall))
581580 {
582581 opNum++; // Skip the first call operand
583582 retTypeOption = ReturnOpt::RETURN_STRUCT;
@@ -608,18 +607,17 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
608607 {
609608 // extend the illegal int to a legal type
610609 IGCLLVM::IRBuilder<> builder (callInst);
611- Value* extend = builder.CreateZExt (arg , LegalizedIntVectorType (M, arg->getType ()));
610+ Value* extend = builder.CreateZExt (callInst-> getOperand (opNum) , LegalizedIntVectorType (M, arg->getType ()));
612611 callArgs.push_back (extend);
613612 ArgAttrVec.push_back (AttributeSet ());
614613 fixArgType = true ;
615614 }
616615 else if (callInst->paramHasAttr (opNum, llvm::Attribute::ByVal) &&
617- isPromotableStructType (M, arg-> getType ( ), isStackCall))
616+ isPromotableStructType (M, callInst-> getParamByValType (opNum ), isStackCall))
618617 {
619618 // Map the new operand to the loaded value of the struct pointer
620619 IGCLLVM::IRBuilder<> builder (callInst);
621- Argument* callArg = IGCLLVM::getArg (*calledFunc, opNum);
622- Value* newOp = LoadFromStruct (builder, arg, callArg->getParamByValType ());
620+ Value* newOp = LoadFromStruct (builder, callInst->getOperand (opNum), callInst->getParamByValType (opNum));
623621 callArgs.push_back (newOp);
624622 ArgAttrVec.push_back (AttributeSet ());
625623 fixArgType = true ;
@@ -629,7 +627,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
629627 // Create and store operand as an alloca, then pass as argument
630628 IGCLLVM::IRBuilder<> builder (callInst);
631629 Value* allocaV = builder.CreateAlloca (arg->getType ());
632- builder.CreateStore (arg , allocaV);
630+ builder.CreateStore (callInst-> getOperand (opNum) , allocaV);
633631 callArgs.push_back (allocaV);
634632 auto byValAttr = llvm::Attribute::getWithByValType (M.getContext (), arg->getType ());
635633 auto argAttrs = AttributeSet::get (M.getContext (), { byValAttr });
@@ -659,7 +657,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
659657 }
660658 Type* retType =
661659 retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy (callInst->getContext ()) :
662- retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, callInst->getArgOperand ( 0 )->getType ( )) :
660+ retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, callInst->getFunction ( )->getArg ( 0 )) :
663661 retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType (M, callInst->getType ()) :
664662 callInst->getType ();
665663 newFnTy = FunctionType::get (retType, argTypes, false );
0 commit comments