@@ -139,9 +139,11 @@ inline Type* LegalizedIntVectorType(const Module& M, Type* ty)
139139}
140140
141141// Returns true for structs smaller than 'structSize' and only contains primitive types
142- inline bool isLegalStructType (const Module& M, StructType* sTy , unsigned structSize)
142+ inline bool isLegalStructType (const Module& M, Type* ty , unsigned structSize)
143143{
144+ IGC_ASSERT (ty->isStructTy ());
144145 const DataLayout& DL = M.getDataLayout ();
146+ StructType* sTy = dyn_cast<StructType>(ty);
145147 if (sTy && DL.getStructLayout (sTy )->getSizeInBits () <= structSize)
146148 {
147149 for (const auto * EltTy : sTy ->elements ())
@@ -161,7 +163,7 @@ inline bool isLegalSignatureType(const Module& M, Type* ty, bool isStackCall)
161163 {
162164 if (ty->isStructTy ())
163165 {
164- return isLegalStructType (M, cast<StructType>(ty) , MAX_STRUCT_SIZE_IN_BITS);
166+ return isLegalStructType (M, ty , MAX_STRUCT_SIZE_IN_BITS);
165167 }
166168 else if (ty->isArrayTy ())
167169 {
@@ -172,14 +174,16 @@ inline bool isLegalSignatureType(const Module& M, Type* ty, bool isStackCall)
172174 return true ;
173175}
174176
175- inline bool isPromotableStructType (const Module& M, Type* pointeeType, bool isStackCall)
177+ // Check if a struct pointer argument is promotable to pass-by-value
178+ inline bool isPromotableStructType (const Module& M, const Type* ty, bool isStackCall, bool isReturnValue = false )
176179{
177180 if (IGC_IS_FLAG_DISABLED (EnableByValStructArgPromotion))
178181 return false ;
182+
179183 const unsigned int maxSize = isStackCall ? MAX_STRUCT_SIZE_IN_BITS : MAX_SUBROUTINE_STRUCT_SIZE_IN_BITS;
180- if (isa<StructType>(pointeeType ))
184+ if (ty-> isPointerTy () && IGCLLVM::getNonOpaquePtrEltTy (ty)-> isStructTy ( ))
181185 {
182- return isLegalStructType (M, cast<StructType>(pointeeType ), maxSize);
186+ return isLegalStructType (M, IGCLLVM::getNonOpaquePtrEltTy (ty ), maxSize);
183187 }
184188 return false ;
185189}
@@ -190,29 +194,18 @@ inline bool FunctionHasPromotableSRetArg(const Module& M, const Function* F)
190194 if (F->getReturnType ()->isVoidTy () &&
191195 !F->arg_empty () &&
192196 F->arg_begin ()->hasStructRetAttr () &&
193- isPromotableStructType (M, F->arg_begin ()->getParamStructRetType (), F->hasFnAttribute (" visaStackCall" )))
197+ isPromotableStructType (M, F->arg_begin ()->getType (), F->hasFnAttribute (" visaStackCall" ), true ))
194198 {
195199 return true ;
196200 }
197201 return false ;
198202}
199203
200204// Promotes struct pointer to struct type
201- inline StructType * PromotedStructValueType (const Module& M, const Argument* arg )
205+ inline Type * PromotedStructValueType (const Module& M, const Type* ty )
202206{
203- if (arg->getType ()->isPointerTy ())
204- {
205- if (arg->hasStructRetAttr () && arg->getParamStructRetType ()->isStructTy ())
206- {
207- return cast<StructType>(arg->getParamStructRetType ());
208- }
209- else if (arg->hasByValAttr () && arg->getParamByValType ()->isStructTy ())
210- {
211- return cast<StructType>(arg->getParamByValType ());
212- }
213- }
214- IGC_ASSERT_MESSAGE (0 , " Not implemented case" );
215- return nullptr ;
207+ IGC_ASSERT (ty->isPointerTy () && IGCLLVM::getNonOpaquePtrEltTy (ty)->isStructTy ());
208+ return cast<StructType>(IGCLLVM::getNonOpaquePtrEltTy (ty));
216209}
217210
218211// BE does not handle struct load/store, so instead store each element of the struct value to the GEP of the struct pointer
@@ -225,7 +218,7 @@ inline void StoreToStruct(IGCLLVM::IRBuilder<>& builder, Value* strVal, Value* s
225218 for (unsigned i = 0 ; i < sTy ->getNumElements (); i++)
226219 {
227220 Value* indices[] = { builder.getInt32 (0 ), builder.getInt32 (i) };
228- Value* elementPtr = builder.CreateInBoundsGEP (strVal-> getType (), strPtr, indices);
221+ Value* elementPtr = builder.CreateInBoundsGEP (strPtr, indices);
229222 Value* element = builder.CreateExtractValue (strVal, i);
230223 builder.CreateStore (element, elementPtr);
231224 }
@@ -242,7 +235,7 @@ inline Value* LoadFromStruct(IGCLLVM::IRBuilder<>& builder, Value* strPtr, Type*
242235 for (unsigned i = 0 ; i < sTy ->getNumElements (); i++)
243236 {
244237 Value* indices[] = { builder.getInt32 (0 ), builder.getInt32 (i) };
245- Value* elementPtr = builder.CreateInBoundsGEP (ty, strPtr, indices);
238+ Value* elementPtr = builder.CreateInBoundsGEP (strPtr, indices);
246239 Value* element = builder.CreateLoad (sTy ->getElementType (i), elementPtr);
247240 strVal = builder.CreateInsertValue (strVal, element, i);
248241 }
@@ -315,10 +308,10 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
315308 argTypes.push_back (LegalizedIntVectorType (M, ai->getType ()));
316309 }
317310 else if (ai->hasByValAttr () &&
318- isPromotableStructType (M, ai->getParamByValType (), isStackCall))
311+ isPromotableStructType (M, ai->getType (), isStackCall))
319312 {
320313 fixArgType = true ;
321- argTypes.push_back (PromotedStructValueType (M, ai));
314+ argTypes.push_back (PromotedStructValueType (M, ai-> getType () ));
322315 }
323316 else if (!isLegalSignatureType (M, ai->getType (), isStackCall))
324317 {
@@ -336,7 +329,7 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
336329 // Clone function with new signature
337330 Type* returnType =
338331 retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy (M.getContext ()) :
339- retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, pFunc->arg_begin ()) :
332+ retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, pFunc->arg_begin ()-> getType () ) :
340333 retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType (M, pFunc->getReturnType ()) :
341334 pFunc->getReturnType ();
342335 FunctionType* signature = FunctionType::get (returnType, argTypes, false );
@@ -400,12 +393,13 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
400393 if (OldArgIt == pFunc->arg_begin () && retTypeOption == ReturnOpt::RETURN_STRUCT)
401394 {
402395 // Create a temp alloca to map the old argument. This will be removed later by SROA.
403- tempAllocaForSRetPointerTy = PromotedStructValueType (M, OldArgIt);
396+ tempAllocaForSRetPointerTy = PromotedStructValueType (M, OldArgIt-> getType () );
404397 tempAllocaForSRetPointer = builder.CreateAlloca (tempAllocaForSRetPointerTy);
405398 tempAllocaForSRetPointer = builder.CreateAddrSpaceCast (tempAllocaForSRetPointer, OldArgIt->getType ());
406399 VMap[&*OldArgIt] = tempAllocaForSRetPointer;
407400 continue ;
408401 }
402+
409403 NewArgIt->setName (OldArgIt->getName ());
410404 if (!isLegalIntVectorType (M, OldArgIt->getType ()))
411405 {
@@ -414,25 +408,24 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
414408 VMap[&*OldArgIt] = trunc;
415409 }
416410 else if (OldArgIt->hasByValAttr () &&
417- isPromotableStructType (M, OldArgIt->getParamByValType (), isStackCall))
411+ isPromotableStructType (M, OldArgIt->getType (), isStackCall))
418412 {
419- AllocaInst* newArgPtr = builder.CreateAlloca (OldArgIt->getParamByValType ());
420413 // remove "byval" attrib since it is now pass-by-value
421414 NewArgIt->removeAttr (llvm::Attribute::ByVal);
415+ Value* newArgPtr = builder.CreateAlloca (NewArgIt->getType ());
422416 StoreToStruct (builder, &*NewArgIt, newArgPtr);
423417 // cast back to original addrspace
424418 IGC_ASSERT (OldArgIt->getType ()->getPointerAddressSpace () == ADDRESS_SPACE_GENERIC ||
425- OldArgIt->getType ()->getPointerAddressSpace () == ADDRESS_SPACE_PRIVATE);
426- llvm::Value* castedNewArgPtr = builder.CreateAddrSpaceCast (newArgPtr, OldArgIt->getType ());
427- VMap[&*OldArgIt] = castedNewArgPtr ;
419+ OldArgIt->getType ()->getPointerAddressSpace () == ADDRESS_SPACE_PRIVATE);
420+ newArgPtr = builder.CreateAddrSpaceCast (newArgPtr, OldArgIt->getType ());
421+ VMap[&*OldArgIt] = newArgPtr ;
428422 }
429423 else if (!isLegalSignatureType (M, OldArgIt->getType (), isStackCall))
430424 {
431425 // Load from pointer arg
432- Value* load = builder.CreateLoad (OldArgIt-> getType (), &*NewArgIt);
426+ Value* load = builder.CreateLoad (&*NewArgIt);
433427 VMap[&*OldArgIt] = load;
434- llvm::Attribute byValAttr = llvm::Attribute::getWithByValType (M.getContext (), OldArgIt->getType ());
435- NewArgIt->addAttr (byValAttr);
428+ ArgByVal.push_back (&*NewArgIt);
436429 }
437430 else
438431 {
@@ -451,13 +444,21 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
451444 builder.CreateBr (ClonedEntryBB);
452445 MergeBlockIntoPredecessor (ClonedEntryBB);
453446
447+ // Loop through new args and add 'byval' attributes
448+ for (auto arg : ArgByVal)
449+ {
450+ arg->addAttr (llvm::Attribute::getWithByValType (M.getContext (),
451+ IGCLLVM::getNonOpaquePtrEltTy (arg->getType ())));
452+ }
453+
454454 // Now fix the return values
455455 if (retTypeOption == ReturnOpt::RETURN_BY_REF)
456456 {
457457 // Add the 'noalias' and 'sret' attribute to arg0
458458 auto retArg = pNewFunc->arg_begin ();
459459 retArg->addAttr (llvm::Attribute::NoAlias);
460- retArg->addAttr (llvm::Attribute::getWithStructRetType (M.getContext (), pFunc->getReturnType ()));
460+ retArg->addAttr (llvm::Attribute::getWithStructRetType (
461+ M.getContext (), IGCLLVM::getNonOpaquePtrEltTy (retArg->getType ())));
461462
462463 // Loop through all return instructions and store the old return value into the arg0 pointer
463464 const auto ptrSize = DL.getPointerSize ();
@@ -576,7 +577,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
576577 if (callInst->getType ()->isVoidTy () &&
577578 IGCLLVM::getNumArgOperands (callInst) > 0 &&
578579 callInst->paramHasAttr (0 , llvm::Attribute::StructRet) &&
579- isPromotableStructType (M, callInst->getParamAttr ( 0 , llvm::Attribute::StructRet). getValueAsType (), isStackCall))
580+ isPromotableStructType (M, callInst->getArgOperand ( 0 )-> getType (), isStackCall, true /* retval */ ))
580581 {
581582 opNum++; // Skip the first call operand
582583 retTypeOption = ReturnOpt::RETURN_STRUCT;
@@ -607,17 +608,18 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
607608 {
608609 // extend the illegal int to a legal type
609610 IGCLLVM::IRBuilder<> builder (callInst);
610- Value* extend = builder.CreateZExt (callInst-> getOperand (opNum) , LegalizedIntVectorType (M, arg->getType ()));
611+ Value* extend = builder.CreateZExt (arg , LegalizedIntVectorType (M, arg->getType ()));
611612 callArgs.push_back (extend);
612613 ArgAttrVec.push_back (AttributeSet ());
613614 fixArgType = true ;
614615 }
615616 else if (callInst->paramHasAttr (opNum, llvm::Attribute::ByVal) &&
616- isPromotableStructType (M, callInst-> getParamByValType (opNum ), isStackCall))
617+ isPromotableStructType (M, arg-> getType ( ), isStackCall))
617618 {
618619 // Map the new operand to the loaded value of the struct pointer
619620 IGCLLVM::IRBuilder<> builder (callInst);
620- Value* newOp = LoadFromStruct (builder, callInst->getOperand (opNum), callInst->getParamByValType (opNum));
621+ Argument* callArg = IGCLLVM::getArg (*calledFunc, opNum);
622+ Value* newOp = LoadFromStruct (builder, arg, callArg->getParamByValType ());
621623 callArgs.push_back (newOp);
622624 ArgAttrVec.push_back (AttributeSet ());
623625 fixArgType = true ;
@@ -627,7 +629,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
627629 // Create and store operand as an alloca, then pass as argument
628630 IGCLLVM::IRBuilder<> builder (callInst);
629631 Value* allocaV = builder.CreateAlloca (arg->getType ());
630- builder.CreateStore (callInst-> getOperand (opNum) , allocaV);
632+ builder.CreateStore (arg , allocaV);
631633 callArgs.push_back (allocaV);
632634 auto byValAttr = llvm::Attribute::getWithByValType (M.getContext (), arg->getType ());
633635 auto argAttrs = AttributeSet::get (M.getContext (), { byValAttr });
@@ -657,7 +659,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
657659 }
658660 Type* retType =
659661 retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy (callInst->getContext ()) :
660- retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, callInst->getFunction ( )->getArg ( 0 )) :
662+ retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, callInst->getArgOperand ( 0 )->getType ( )) :
661663 retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType (M, callInst->getType ()) :
662664 callInst->getType ();
663665 newFnTy = FunctionType::get (retType, argTypes, false );
0 commit comments