@@ -175,15 +175,27 @@ inline bool isLegalSignatureType(const Module& M, Type* ty, bool isStackCall)
175175}
176176
177177// Check if a struct pointer argument is promotable to pass-by-value
178- inline bool isPromotableStructType (const Module& M, const Type* ty , bool isStackCall, bool isReturnValue = false )
178+ inline bool isPromotableStructType (const Module& M, const Argument* arg , bool isStackCall, bool isReturnValue = false )
179179{
180180 if (IGC_IS_FLAG_DISABLED (EnableByValStructArgPromotion))
181181 return false ;
182182
183183 const unsigned int maxSize = isStackCall ? MAX_STRUCT_SIZE_IN_BITS : MAX_SUBROUTINE_STRUCT_SIZE_IN_BITS;
184- if (ty->isPointerTy () && IGCLLVM::getNonOpaquePtrEltTy (ty)->isStructTy ())
184+ llvm::Type* structType = nullptr ;
185+ if (arg->getType ()->isPointerTy ())
185186 {
186- return isLegalStructType (M, IGCLLVM::getNonOpaquePtrEltTy (ty), maxSize);
187+ if (arg->hasStructRetAttr () && arg->getParamStructRetType ()->isStructTy ())
188+ {
189+ structType = arg->getParamStructRetType ();
190+ }
191+ else if (arg->hasByValAttr () && arg->getParamByValType ()->isStructTy ())
192+ {
193+ structType = arg->getParamByValType ();
194+ }
195+ }
196+ if (structType)
197+ {
198+ return isLegalStructType (M, structType, maxSize);
187199 }
188200 return false ;
189201}
@@ -193,23 +205,33 @@ inline bool FunctionHasPromotableSRetArg(const Module& M, const Function* F)
193205{
194206 if (F->getReturnType ()->isVoidTy () &&
195207 !F->arg_empty () &&
196- F->arg_begin ()->hasStructRetAttr () &&
197- isPromotableStructType (M, F->arg_begin ()->getType (), F->hasFnAttribute (" visaStackCall" ), true ))
208+ isPromotableStructType (M, F->arg_begin (), F->hasFnAttribute (" visaStackCall" ), true ))
198209 {
199210 return true ;
200211 }
201212 return false ;
202213}
203214
204215// Promotes struct pointer to struct type
205- inline Type * PromotedStructValueType (const Module& M, const Type* ty )
216+ inline StructType * PromotedStructValueType (const Module& M, const Argument* arg )
206217{
207- IGC_ASSERT (ty->isPointerTy () && IGCLLVM::getNonOpaquePtrEltTy (ty)->isStructTy ());
208- return cast<StructType>(IGCLLVM::getNonOpaquePtrEltTy (ty));
218+ if (arg->getType ()->isPointerTy ())
219+ {
220+ if (arg->hasStructRetAttr () && arg->getParamStructRetType ()->isStructTy ())
221+ {
222+ return cast<StructType>(arg->getParamStructRetType ());
223+ }
224+ else if (arg->hasByValAttr () && arg->getParamByValType ()->isStructTy ())
225+ {
226+ return cast<StructType>(arg->getParamByValType ());
227+ }
228+ }
229+ IGC_ASSERT_MESSAGE (0 , " Not implemented case" );
230+ return nullptr ;
209231}
210232
211233// BE does not handle struct load/store, so instead store each element of the struct value to the GEP of the struct pointer
212- inline void StoreToStruct (IGCLLVM::IRBuilder<>& builder, Value* strVal, Value * strPtr)
234+ inline void StoreToStruct (IGCLLVM::IRBuilder<>& builder, Value* strVal, AllocaInst * strPtr)
213235{
214236 IGC_ASSERT (strPtr->getType ()->isPointerTy ());
215237 IGC_ASSERT (strVal->getType ()->isStructTy ());
@@ -218,12 +240,45 @@ inline void StoreToStruct(IGCLLVM::IRBuilder<>& builder, Value* strVal, Value* s
218240 for (unsigned i = 0 ; i < sTy ->getNumElements (); i++)
219241 {
220242 Value* indices[] = { builder.getInt32 (0 ), builder.getInt32 (i) };
221- Value* elementPtr = builder.CreateInBoundsGEP (strPtr, indices);
243+ Value* elementPtr = builder.CreateInBoundsGEP (strPtr-> getAllocatedType (), strPtr , indices);
222244 Value* element = builder.CreateExtractValue (strVal, i);
223245 builder.CreateStore (element, elementPtr);
224246 }
225247}
226248
249+ // BE does not handle struct load/store, so instead store each element of the struct value to the GEP of the struct pointer
250+ inline void StoreToStruct (IGCLLVM::IRBuilder<>& builder, Value* strVal, Argument* strPtr)
251+ {
252+ IGC_ASSERT (strPtr->getType ()->isPointerTy ());
253+ IGC_ASSERT (strVal->getType ()->isStructTy ());
254+ if (strPtr->hasStructRetAttr () && strPtr->getParamStructRetType ()->isStructTy ())
255+ {
256+ StructType* sTy = cast<StructType>(strVal->getType ());
257+ for (unsigned i = 0 ; i < sTy ->getNumElements (); i++)
258+ {
259+ Value* indices[] = { builder.getInt32 (0 ), builder.getInt32 (i) };
260+ Value* elementPtr = builder.CreateInBoundsGEP (strPtr->getParamStructRetType (), strPtr, indices);
261+ Value* element = builder.CreateExtractValue (strVal, i);
262+ builder.CreateStore (element, elementPtr);
263+ }
264+ }
265+ else if (strPtr->hasByValAttr () && strPtr->getParamByValType ()->isStructTy ())
266+ {
267+ StructType* sTy = cast<StructType>(strVal->getType ());
268+ for (unsigned i = 0 ; i < sTy ->getNumElements (); i++)
269+ {
270+ Value* indices[] = { builder.getInt32 (0 ), builder.getInt32 (i) };
271+ Value* elementPtr = builder.CreateInBoundsGEP (strPtr->getParamByValType (), strPtr, indices);
272+ Value* element = builder.CreateExtractValue (strVal, i);
273+ builder.CreateStore (element, elementPtr);
274+ }
275+ }
276+ else
277+ {
278+ IGC_ASSERT_MESSAGE (0 , " Unsupported case: no information about the pointee type" );
279+ }
280+ }
281+
227282// BE does not handle struct load/store, so instead load each element from the GEP struct pointer and insert it into the struct value
228283inline Value* LoadFromStruct (IGCLLVM::IRBuilder<>& builder, Value* strPtr, Type* ty)
229284{
@@ -235,7 +290,7 @@ inline Value* LoadFromStruct(IGCLLVM::IRBuilder<>& builder, Value* strPtr, Type*
235290 for (unsigned i = 0 ; i < sTy ->getNumElements (); i++)
236291 {
237292 Value* indices[] = { builder.getInt32 (0 ), builder.getInt32 (i) };
238- Value* elementPtr = builder.CreateInBoundsGEP (strPtr, indices);
293+ Value* elementPtr = builder.CreateInBoundsGEP (ty, strPtr, indices);
239294 Value* element = builder.CreateLoad (sTy ->getElementType (i), elementPtr);
240295 strVal = builder.CreateInsertValue (strVal, element, i);
241296 }
@@ -308,10 +363,10 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
308363 argTypes.push_back (LegalizedIntVectorType (M, ai->getType ()));
309364 }
310365 else if (ai->hasByValAttr () &&
311- isPromotableStructType (M, ai-> getType () , isStackCall))
366+ isPromotableStructType (M, ai, isStackCall))
312367 {
313368 fixArgType = true ;
314- argTypes.push_back (PromotedStructValueType (M, ai-> getType () ));
369+ argTypes.push_back (PromotedStructValueType (M, ai));
315370 }
316371 else if (!isLegalSignatureType (M, ai->getType (), isStackCall))
317372 {
@@ -329,7 +384,7 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
329384 // Clone function with new signature
330385 Type* returnType =
331386 retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy (M.getContext ()) :
332- retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, pFunc->arg_begin ()-> getType () ) :
387+ retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, pFunc->arg_begin ()) :
333388 retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType (M, pFunc->getReturnType ()) :
334389 pFunc->getReturnType ();
335390 FunctionType* signature = FunctionType::get (returnType, argTypes, false );
@@ -393,7 +448,7 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
393448 if (OldArgIt == pFunc->arg_begin () && retTypeOption == ReturnOpt::RETURN_STRUCT)
394449 {
395450 // Create a temp alloca to map the old argument. This will be removed later by SROA.
396- tempAllocaForSRetPointerTy = PromotedStructValueType (M, OldArgIt-> getType () );
451+ tempAllocaForSRetPointerTy = PromotedStructValueType (M, OldArgIt);
397452 tempAllocaForSRetPointer = builder.CreateAlloca (tempAllocaForSRetPointerTy);
398453 tempAllocaForSRetPointer = builder.CreateAddrSpaceCast (tempAllocaForSRetPointer, OldArgIt->getType ());
399454 VMap[&*OldArgIt] = tempAllocaForSRetPointer;
@@ -408,24 +463,25 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
408463 VMap[&*OldArgIt] = trunc;
409464 }
410465 else if (OldArgIt->hasByValAttr () &&
411- isPromotableStructType (M, OldArgIt-> getType () , isStackCall))
466+ isPromotableStructType (M, OldArgIt, isStackCall))
412467 {
468+ AllocaInst* newArgPtr = builder.CreateAlloca (OldArgIt->getParamByValType ());
413469 // remove "byval" attrib since it is now pass-by-value
414470 NewArgIt->removeAttr (llvm::Attribute::ByVal);
415- Value* newArgPtr = builder.CreateAlloca (NewArgIt->getType ());
416471 StoreToStruct (builder, &*NewArgIt, newArgPtr);
417472 // cast back to original addrspace
418473 IGC_ASSERT (OldArgIt->getType ()->getPointerAddressSpace () == ADDRESS_SPACE_GENERIC ||
419474 OldArgIt->getType ()->getPointerAddressSpace () == ADDRESS_SPACE_PRIVATE);
420- newArgPtr = builder.CreateAddrSpaceCast (newArgPtr, OldArgIt->getType ());
421- VMap[&*OldArgIt] = newArgPtr ;
475+ llvm::Value* castedNewArgPtr = builder.CreateAddrSpaceCast (newArgPtr, OldArgIt->getType ());
476+ VMap[&*OldArgIt] = castedNewArgPtr ;
422477 }
423478 else if (!isLegalSignatureType (M, OldArgIt->getType (), isStackCall))
424479 {
425480 // Load from pointer arg
426- Value* load = builder.CreateLoad (&*NewArgIt);
481+ Value* load = builder.CreateLoad (OldArgIt-> getType (), &*NewArgIt);
427482 VMap[&*OldArgIt] = load;
428- ArgByVal.push_back (&*NewArgIt);
483+ llvm::Attribute byValAttr = llvm::Attribute::getWithByValType (M.getContext (), OldArgIt->getType ());
484+ NewArgIt->addAttr (byValAttr);
429485 }
430486 else
431487 {
@@ -444,21 +500,13 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
444500 builder.CreateBr (ClonedEntryBB);
445501 MergeBlockIntoPredecessor (ClonedEntryBB);
446502
447- // Loop through new args and add 'byval' attributes
448- for (auto arg : ArgByVal)
449- {
450- arg->addAttr (llvm::Attribute::getWithByValType (M.getContext (),
451- IGCLLVM::getNonOpaquePtrEltTy (arg->getType ())));
452- }
453-
454503 // Now fix the return values
455504 if (retTypeOption == ReturnOpt::RETURN_BY_REF)
456505 {
457506 // Add the 'noalias' and 'sret' attribute to arg0
458507 auto retArg = pNewFunc->arg_begin ();
459508 retArg->addAttr (llvm::Attribute::NoAlias);
460- retArg->addAttr (llvm::Attribute::getWithStructRetType (
461- M.getContext (), IGCLLVM::getNonOpaquePtrEltTy (retArg->getType ())));
509+ retArg->addAttr (llvm::Attribute::getWithStructRetType (M.getContext (), pFunc->getReturnType ()));
462510
463511 // Loop through all return instructions and store the old return value into the arg0 pointer
464512 const auto ptrSize = DL.getPointerSize ();
@@ -577,7 +625,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
577625 if (callInst->getType ()->isVoidTy () &&
578626 IGCLLVM::getNumArgOperands (callInst) > 0 &&
579627 callInst->paramHasAttr (0 , llvm::Attribute::StructRet) &&
580- isPromotableStructType (M, callInst->getArgOperand ( 0 )->getType ( ), isStackCall, true /* retval */ ))
628+ isPromotableStructType (M, callInst->getCalledFunction ( )->getArg ( 0 ), isStackCall, true /* retval */ ))
581629 {
582630 opNum++; // Skip the first call operand
583631 retTypeOption = ReturnOpt::RETURN_STRUCT;
@@ -603,23 +651,22 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
603651 // Check call operands if it needs to be replaced
604652 for (; opNum < IGCLLVM::getNumArgOperands (callInst); opNum++)
605653 {
606- Value * arg = callInst-> getArgOperand ( opNum);
654+ Argument * arg = IGCLLVM::getArg (*calledFunc, opNum);
607655 if (!isLegalIntVectorType (M, arg->getType ()))
608656 {
609657 // extend the illegal int to a legal type
610658 IGCLLVM::IRBuilder<> builder (callInst);
611- Value* extend = builder.CreateZExt (arg , LegalizedIntVectorType (M, arg->getType ()));
659+ Value* extend = builder.CreateZExt (callInst-> getOperand (opNum) , LegalizedIntVectorType (M, arg->getType ()));
612660 callArgs.push_back (extend);
613661 ArgAttrVec.push_back (AttributeSet ());
614662 fixArgType = true ;
615663 }
616664 else if (callInst->paramHasAttr (opNum, llvm::Attribute::ByVal) &&
617- isPromotableStructType (M, arg-> getType () , isStackCall))
665+ isPromotableStructType (M, arg, isStackCall))
618666 {
619667 // Map the new operand to the loaded value of the struct pointer
620668 IGCLLVM::IRBuilder<> builder (callInst);
621- Argument* callArg = IGCLLVM::getArg (*calledFunc, opNum);
622- Value* newOp = LoadFromStruct (builder, arg, callArg->getParamByValType ());
669+ Value* newOp = LoadFromStruct (builder, callInst->getOperand (opNum), arg->getParamByValType ());
623670 callArgs.push_back (newOp);
624671 ArgAttrVec.push_back (AttributeSet ());
625672 fixArgType = true ;
@@ -629,7 +676,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
629676 // Create and store operand as an alloca, then pass as argument
630677 IGCLLVM::IRBuilder<> builder (callInst);
631678 Value* allocaV = builder.CreateAlloca (arg->getType ());
632- builder.CreateStore (arg , allocaV);
679+ builder.CreateStore (callInst-> getOperand (opNum) , allocaV);
633680 callArgs.push_back (allocaV);
634681 auto byValAttr = llvm::Attribute::getWithByValType (M.getContext (), arg->getType ());
635682 auto argAttrs = AttributeSet::get (M.getContext (), { byValAttr });
@@ -659,7 +706,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
659706 }
660707 Type* retType =
661708 retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy (callInst->getContext ()) :
662- retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, callInst->getArgOperand ( 0 )->getType ( )) :
709+ retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, callInst->getFunction ( )->getArg ( 0 )) :
663710 retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType (M, callInst->getType ()) :
664711 callInst->getType ();
665712 newFnTy = FunctionType::get (retType, argTypes, false );
@@ -690,7 +737,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
690737 else if (retTypeOption == ReturnOpt::RETURN_STRUCT)
691738 {
692739 // Store the struct value into the orginal pointer operand
693- StoreToStruct (builder, newCallInst, callInst->getArgOperand (0 ));
740+ StoreToStruct (builder, newCallInst, callInst->getCalledFunction ()-> getArg (0 ));
694741 }
695742 else if (retTypeOption == ReturnOpt::RETURN_LEGAL_INT)
696743 {
0 commit comments