@@ -139,9 +139,11 @@ inline Type* LegalizedIntVectorType(const Module& M, Type* ty)
139139}
140140
141141// Returns true for structs no larger than 'structSize' bits that contain only primitive types
142- inline bool isLegalStructType (const Module& M, StructType* sTy , unsigned structSize)
142+ inline bool isLegalStructType (const Module& M, Type* ty , unsigned structSize)
143143{
144+ IGC_ASSERT (ty->isStructTy ());
144145 const DataLayout& DL = M.getDataLayout ();
146+ StructType* sTy = dyn_cast<StructType>(ty);
145147 if (sTy && DL.getStructLayout (sTy )->getSizeInBits () <= structSize)
146148 {
147149 for (const auto * EltTy : sTy ->elements ())
@@ -161,7 +163,7 @@ inline bool isLegalSignatureType(const Module& M, Type* ty, bool isStackCall)
161163 {
162164 if (ty->isStructTy ())
163165 {
164- return isLegalStructType (M, cast<StructType>(ty) , MAX_STRUCT_SIZE_IN_BITS);
166+ return isLegalStructType (M, ty , MAX_STRUCT_SIZE_IN_BITS);
165167 }
166168 else if (ty->isArrayTy ())
167169 {
@@ -172,31 +174,16 @@ inline bool isLegalSignatureType(const Module& M, Type* ty, bool isStackCall)
172174 return true ;
173175}
174176
175- inline bool isPromotableStructType (const Module& M, Type* pointeeType, bool isStackCall)
177+ // Check if a struct pointer argument is promotable to pass-by-value
178+ inline bool isPromotableStructType (const Module& M, const Type* ty, bool isStackCall, bool isReturnValue = false )
176179{
177180 if (IGC_IS_FLAG_DISABLED (EnableByValStructArgPromotion))
178181 return false ;
179- const unsigned int maxSize = isStackCall ? MAX_STRUCT_SIZE_IN_BITS : MAX_SUBROUTINE_STRUCT_SIZE_IN_BITS;
180- if (isa<StructType>(pointeeType))
181- {
182- return isLegalStructType (M, cast<StructType>(pointeeType), maxSize);
183- }
184- return false ;
185- }
186182
187- // Check if a struct pointer argument is promotable to pass-by-value
188- inline bool isPromotableStructType (const Module& M, const Argument* arg, bool isStackCall)
189- {
190- if (arg->getType ()->isPointerTy ())
183+ const unsigned int maxSize = isStackCall ? MAX_STRUCT_SIZE_IN_BITS : MAX_SUBROUTINE_STRUCT_SIZE_IN_BITS;
184+ if (ty->isPointerTy () && IGCLLVM::getNonOpaquePtrEltTy (ty)->isStructTy ())
191185 {
192- if (arg->hasStructRetAttr ())
193- {
194- return isPromotableStructType (M, arg->getParamStructRetType (), isStackCall);
195- }
196- else if (arg->hasByValAttr ())
197- {
198- return isPromotableStructType (M, arg->getParamByValType (), isStackCall);
199- }
186+ return isLegalStructType (M, IGCLLVM::getNonOpaquePtrEltTy (ty), maxSize);
200187 }
201188 return false ;
202189}
@@ -206,33 +193,23 @@ inline bool FunctionHasPromotableSRetArg(const Module& M, const Function* F)
206193{
207194 if (F->getReturnType ()->isVoidTy () &&
208195 !F->arg_empty () &&
209- isPromotableStructType (M, F->arg_begin (), F->hasFnAttribute (" visaStackCall" )))
196+ F->arg_begin ()->hasStructRetAttr () &&
197+ isPromotableStructType (M, F->arg_begin ()->getType (), F->hasFnAttribute (" visaStackCall" ), true ))
210198 {
211199 return true ;
212200 }
213201 return false ;
214202}
215203
216204// Promotes struct pointer to struct type
217- inline StructType * PromotedStructValueType (const Module& M, const Argument* arg )
205+ inline Type * PromotedStructValueType (const Module& M, const Type* ty )
218206{
219- if (arg->getType ()->isPointerTy ())
220- {
221- if (arg->hasStructRetAttr () && arg->getParamStructRetType ()->isStructTy ())
222- {
223- return cast<StructType>(arg->getParamStructRetType ());
224- }
225- else if (arg->hasByValAttr () && arg->getParamByValType ()->isStructTy ())
226- {
227- return cast<StructType>(arg->getParamByValType ());
228- }
229- }
230- IGC_ASSERT_MESSAGE (0 , " Not implemented case" );
231- return nullptr ;
207+ IGC_ASSERT (ty->isPointerTy () && IGCLLVM::getNonOpaquePtrEltTy (ty)->isStructTy ());
208+ return cast<StructType>(IGCLLVM::getNonOpaquePtrEltTy (ty));
232209}
233210
234211// BE does not handle struct load/store, so instead store each element of the struct value to the GEP of the struct pointer
235- inline void StoreToStruct (IGCLLVM::IRBuilder<>& builder, Value* strVal, AllocaInst * strPtr)
212+ inline void StoreToStruct (IGCLLVM::IRBuilder<>& builder, Value* strVal, Value * strPtr)
236213{
237214 IGC_ASSERT (strPtr->getType ()->isPointerTy ());
238215 IGC_ASSERT (strVal->getType ()->isStructTy ());
@@ -241,45 +218,12 @@ inline void StoreToStruct(IGCLLVM::IRBuilder<>& builder, Value* strVal, AllocaIn
241218 for (unsigned i = 0 ; i < sTy ->getNumElements (); i++)
242219 {
243220 Value* indices[] = { builder.getInt32 (0 ), builder.getInt32 (i) };
244- Value* elementPtr = builder.CreateInBoundsGEP (strPtr-> getAllocatedType (), strPtr , indices);
221+ Value* elementPtr = builder.CreateInBoundsGEP (strPtr, indices);
245222 Value* element = builder.CreateExtractValue (strVal, i);
246223 builder.CreateStore (element, elementPtr);
247224 }
248225}
249226
250- // BE does not handle struct load/store, so instead store each element of the struct value to the GEP of the struct pointer
251- inline void StoreToStruct (IGCLLVM::IRBuilder<>& builder, Value* strVal, Argument* strPtr)
252- {
253- IGC_ASSERT (strPtr->getType ()->isPointerTy ());
254- IGC_ASSERT (strVal->getType ()->isStructTy ());
255- if (strPtr->hasStructRetAttr () && strPtr->getParamStructRetType ()->isStructTy ())
256- {
257- StructType* sTy = cast<StructType>(strVal->getType ());
258- for (unsigned i = 0 ; i < sTy ->getNumElements (); i++)
259- {
260- Value* indices[] = { builder.getInt32 (0 ), builder.getInt32 (i) };
261- Value* elementPtr = builder.CreateInBoundsGEP (strPtr->getParamStructRetType (), strPtr, indices);
262- Value* element = builder.CreateExtractValue (strVal, i);
263- builder.CreateStore (element, elementPtr);
264- }
265- }
266- else if (strPtr->hasByValAttr () && strPtr->getParamByValType ()->isStructTy ())
267- {
268- StructType* sTy = cast<StructType>(strVal->getType ());
269- for (unsigned i = 0 ; i < sTy ->getNumElements (); i++)
270- {
271- Value* indices[] = { builder.getInt32 (0 ), builder.getInt32 (i) };
272- Value* elementPtr = builder.CreateInBoundsGEP (strPtr->getParamByValType (), strPtr, indices);
273- Value* element = builder.CreateExtractValue (strVal, i);
274- builder.CreateStore (element, elementPtr);
275- }
276- }
277- else
278- {
279- IGC_ASSERT_MESSAGE (0 , " Unsupported case: no information about the pointee type" );
280- }
281- }
282-
283227// BE does not handle struct load/store, so instead load each element from the GEP struct pointer and insert it into the struct value
284228inline Value* LoadFromStruct (IGCLLVM::IRBuilder<>& builder, Value* strPtr, Type* ty)
285229{
@@ -291,7 +235,7 @@ inline Value* LoadFromStruct(IGCLLVM::IRBuilder<>& builder, Value* strPtr, Type*
291235 for (unsigned i = 0 ; i < sTy ->getNumElements (); i++)
292236 {
293237 Value* indices[] = { builder.getInt32 (0 ), builder.getInt32 (i) };
294- Value* elementPtr = builder.CreateInBoundsGEP (ty, strPtr, indices);
238+ Value* elementPtr = builder.CreateInBoundsGEP (strPtr, indices);
295239 Value* element = builder.CreateLoad (sTy ->getElementType (i), elementPtr);
296240 strVal = builder.CreateInsertValue (strVal, element, i);
297241 }
@@ -364,10 +308,10 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
364308 argTypes.push_back (LegalizedIntVectorType (M, ai->getType ()));
365309 }
366310 else if (ai->hasByValAttr () &&
367- isPromotableStructType (M, ai, isStackCall))
311+ isPromotableStructType (M, ai-> getType () , isStackCall))
368312 {
369313 fixArgType = true ;
370- argTypes.push_back (PromotedStructValueType (M, ai));
314+ argTypes.push_back (PromotedStructValueType (M, ai-> getType () ));
371315 }
372316 else if (!isLegalSignatureType (M, ai->getType (), isStackCall))
373317 {
@@ -385,7 +329,7 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
385329 // Clone function with new signature
386330 Type* returnType =
387331 retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy (M.getContext ()) :
388- retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, pFunc->arg_begin ()) :
332+ retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, pFunc->arg_begin ()-> getType () ) :
389333 retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType (M, pFunc->getReturnType ()) :
390334 pFunc->getReturnType ();
391335 FunctionType* signature = FunctionType::get (returnType, argTypes, false );
@@ -449,7 +393,7 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
449393 if (OldArgIt == pFunc->arg_begin () && retTypeOption == ReturnOpt::RETURN_STRUCT)
450394 {
451395 // Create a temp alloca to map the old argument. This will be removed later by SROA.
452- tempAllocaForSRetPointerTy = PromotedStructValueType (M, OldArgIt);
396+ tempAllocaForSRetPointerTy = PromotedStructValueType (M, OldArgIt-> getType () );
453397 tempAllocaForSRetPointer = builder.CreateAlloca (tempAllocaForSRetPointerTy);
454398 tempAllocaForSRetPointer = builder.CreateAddrSpaceCast (tempAllocaForSRetPointer, OldArgIt->getType ());
455399 VMap[&*OldArgIt] = tempAllocaForSRetPointer;
@@ -464,25 +408,24 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
464408 VMap[&*OldArgIt] = trunc;
465409 }
466410 else if (OldArgIt->hasByValAttr () &&
467- isPromotableStructType (M, OldArgIt, isStackCall))
411+ isPromotableStructType (M, OldArgIt-> getType () , isStackCall))
468412 {
469- AllocaInst* newArgPtr = builder.CreateAlloca (OldArgIt->getParamByValType ());
470413 // remove "byval" attrib since it is now pass-by-value
471414 NewArgIt->removeAttr (llvm::Attribute::ByVal);
415+ Value* newArgPtr = builder.CreateAlloca (NewArgIt->getType ());
472416 StoreToStruct (builder, &*NewArgIt, newArgPtr);
473417 // cast back to original addrspace
474418 IGC_ASSERT (OldArgIt->getType ()->getPointerAddressSpace () == ADDRESS_SPACE_GENERIC ||
475419 OldArgIt->getType ()->getPointerAddressSpace () == ADDRESS_SPACE_PRIVATE);
476- llvm::Value* castedNewArgPtr = builder.CreateAddrSpaceCast (newArgPtr, OldArgIt->getType ());
477- VMap[&*OldArgIt] = castedNewArgPtr ;
420+ newArgPtr = builder.CreateAddrSpaceCast (newArgPtr, OldArgIt->getType ());
421+ VMap[&*OldArgIt] = newArgPtr ;
478422 }
479423 else if (!isLegalSignatureType (M, OldArgIt->getType (), isStackCall))
480424 {
481425 // Load from pointer arg
482- Value* load = builder.CreateLoad (OldArgIt-> getType (), &*NewArgIt);
426+ Value* load = builder.CreateLoad (&*NewArgIt);
483427 VMap[&*OldArgIt] = load;
484- llvm::Attribute byValAttr = llvm::Attribute::getWithByValType (M.getContext (), OldArgIt->getType ());
485- NewArgIt->addAttr (byValAttr);
428+ ArgByVal.push_back (&*NewArgIt);
486429 }
487430 else
488431 {
@@ -501,13 +444,21 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
501444 builder.CreateBr (ClonedEntryBB);
502445 MergeBlockIntoPredecessor (ClonedEntryBB);
503446
447+ // Loop through new args and add 'byval' attributes
448+ for (auto arg : ArgByVal)
449+ {
450+ arg->addAttr (llvm::Attribute::getWithByValType (M.getContext (),
451+ IGCLLVM::getNonOpaquePtrEltTy (arg->getType ())));
452+ }
453+
504454 // Now fix the return values
505455 if (retTypeOption == ReturnOpt::RETURN_BY_REF)
506456 {
507457 // Add the 'noalias' and 'sret' attribute to arg0
508458 auto retArg = pNewFunc->arg_begin ();
509459 retArg->addAttr (llvm::Attribute::NoAlias);
510- retArg->addAttr (llvm::Attribute::getWithStructRetType (M.getContext (), pFunc->getReturnType ()));
460+ retArg->addAttr (llvm::Attribute::getWithStructRetType (
461+ M.getContext (), IGCLLVM::getNonOpaquePtrEltTy (retArg->getType ())));
511462
512463 // Loop through all return instructions and store the old return value into the arg0 pointer
513464 const auto ptrSize = DL.getPointerSize ();
@@ -626,7 +577,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
626577 if (callInst->getType ()->isVoidTy () &&
627578 IGCLLVM::getNumArgOperands (callInst) > 0 &&
628579 callInst->paramHasAttr (0 , llvm::Attribute::StructRet) &&
629- isPromotableStructType (M, callInst->getCalledFunction ( )->getArg ( 0 ), isStackCall))
580+ isPromotableStructType (M, callInst->getArgOperand ( 0 )->getType ( ), isStackCall, true /* retval */ ))
630581 {
631582 opNum++; // Skip the first call operand
632583 retTypeOption = ReturnOpt::RETURN_STRUCT;
@@ -657,17 +608,18 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
657608 {
658609 // extend the illegal int to a legal type
659610 IGCLLVM::IRBuilder<> builder (callInst);
660- Value* extend = builder.CreateZExt (callInst-> getOperand (opNum) , LegalizedIntVectorType (M, arg->getType ()));
611+ Value* extend = builder.CreateZExt (arg , LegalizedIntVectorType (M, arg->getType ()));
661612 callArgs.push_back (extend);
662613 ArgAttrVec.push_back (AttributeSet ());
663614 fixArgType = true ;
664615 }
665616 else if (callInst->paramHasAttr (opNum, llvm::Attribute::ByVal) &&
666- isPromotableStructType (M, callInst-> getParamByValType (opNum ), isStackCall))
617+ isPromotableStructType (M, arg-> getType ( ), isStackCall))
667618 {
668619 // Map the new operand to the loaded value of the struct pointer
669620 IGCLLVM::IRBuilder<> builder (callInst);
670- Value* newOp = LoadFromStruct (builder, callInst->getOperand (opNum), callInst->getParamByValType (opNum));
621+ Argument* callArg = IGCLLVM::getArg (*calledFunc, opNum);
622+ Value* newOp = LoadFromStruct (builder, arg, callArg->getParamByValType ());
671623 callArgs.push_back (newOp);
672624 ArgAttrVec.push_back (AttributeSet ());
673625 fixArgType = true ;
@@ -677,7 +629,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
677629 // Create and store operand as an alloca, then pass as argument
678630 IGCLLVM::IRBuilder<> builder (callInst);
679631 Value* allocaV = builder.CreateAlloca (arg->getType ());
680- builder.CreateStore (callInst-> getOperand (opNum) , allocaV);
632+ builder.CreateStore (arg , allocaV);
681633 callArgs.push_back (allocaV);
682634 auto byValAttr = llvm::Attribute::getWithByValType (M.getContext (), arg->getType ());
683635 auto argAttrs = AttributeSet::get (M.getContext (), { byValAttr });
@@ -707,7 +659,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
707659 }
708660 Type* retType =
709661 retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy (callInst->getContext ()) :
710- retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, callInst->getFunction ( )->getArg ( 0 )) :
662+ retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, callInst->getArgOperand ( 0 )->getType ( )) :
711663 retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType (M, callInst->getType ()) :
712664 callInst->getType ();
713665 newFnTy = FunctionType::get (retType, argTypes, false );
@@ -738,7 +690,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
738690 else if (retTypeOption == ReturnOpt::RETURN_STRUCT)
739691 {
740692 // Store the struct value into the original pointer operand
741- StoreToStruct (builder, newCallInst, callInst->getCalledFunction ()-> getArg (0 ));
693+ StoreToStruct (builder, newCallInst, callInst->getArgOperand (0 ));
742694 }
743695 else if (retTypeOption == ReturnOpt::RETURN_LEGAL_INT)
744696 {
0 commit comments