@@ -90,14 +90,6 @@ bool LegalizeFunctionSignatures::runOnModule(Module& M)
9090 return true ;
9191}
9292
93- // If the return type size is greater than the allowed size, we convert the return value to pass-by-pointer
94- inline bool isLegalReturnType (const Type* ty)
95- {
96- // check return type size
97- // return ty->getPrimitiveSizeInBits() <= MAX_RETVAL_SIZE_IN_BITS;
98- return true ; // allow all return sizes
99- }
100-
10193// Check if an int or int-vector argument type is a power of two
10294inline bool isLegalIntVectorType (const Module& M, Type* ty)
10395{
@@ -137,34 +129,56 @@ inline Type* LegalizedIntVectorType(const Module& M, Type* ty)
137129 IGCLLVM::FixedVectorType::get (IntegerType::get (M.getContext (), newSize), (unsigned )cast<IGCLLVM::FixedVectorType>(ty)->getNumElements ());
138130}
139131
140- // Returns true for small structures that only contain primitive types
141- inline bool isPromotableStructType (const Module& M, const Type* ty, bool isStackCall, bool isReturnValue = false )
132+ // Returns true for structs smaller than 'structSize' and only contains primitive types
133+ inline bool isLegalStructType (const Module& M, Type* ty, unsigned structSize )
142134{
143- if (IGC_IS_FLAG_DISABLED (EnableByValStructArgPromotion))
144- return false ;
145-
146- const unsigned int maxSize = isStackCall ? MAX_STRUCT_SIZE_IN_BITS
147- : MAX_SUBROUTINE_STRUCT_SIZE_IN_BITS;
148-
135+ IGC_ASSERT (ty->isStructTy ());
149136 const DataLayout& DL = M.getDataLayout ();
137+ StructType* sTy = dyn_cast<StructType>(ty);
138+ if (sTy && DL.getStructLayout (sTy )->getSizeInBits () <= structSize)
139+ {
140+ for (const auto * EltTy : sTy ->elements ())
141+ {
142+ // Check if all elements are primitive types
143+ if (!EltTy->isSingleValueType () || EltTy->isVectorTy ())
144+ return false ;
145+ // Avoid int64 and fp64 because of unimplemented InstExpander::visitInsertValue
146+ // and InstExpander::visitExtractValue in the Emu64Ops pass.
147+ if (EltTy->isIntegerTy (64 ) || EltTy->isDoubleTy ())
148+ return false ;
149+ }
150+ return true ;
151+ }
152+ return false ;
153+ }
150154
151- if (ty->isPointerTy ())
155+ inline bool isLegalSignatureType (const Module& M, Type* ty, bool isStackCall)
156+ {
157+ if (isStackCall)
152158 {
153- StructType* sTy = dyn_cast<StructType>(ty->getPointerElementType ());
154- if (sTy && DL.getStructLayout (sTy )->getSizeInBits () <= maxSize)
159+ if (ty->isStructTy ())
155160 {
156- for (const auto * EltTy : sTy ->elements ())
157- {
158- // Check if all elements are primitive types
159- if (!EltTy->isSingleValueType () || EltTy->isVectorTy ())
160- return false ;
161- // Avoid int64 and fp64 because of unimplemented InstExpander::visitInsertValue
162- // and InstExpander::visitExtractValue in the Emu64Ops pass.
163- if (EltTy->isIntegerTy (64 ) || EltTy->isDoubleTy ())
164- return false ;
165- }
166- return true ;
161+ return isLegalStructType (M, ty, MAX_STRUCT_SIZE_IN_BITS);
167162 }
163+ else if (ty->isArrayTy ())
164+ {
165+ return false ;
166+ }
167+ }
168+ // Are all subroutine types legal?
169+ return true ;
170+ }
171+
172+ // Check if a struct pointer argument is promotable to pass-by-value
173+ inline bool isPromotableStructType (const Module& M, const Type* ty, bool isStackCall, bool isReturnValue = false )
174+ {
175+ if (IGC_IS_FLAG_DISABLED (EnableByValStructArgPromotion))
176+ return false ;
177+
178+ const unsigned int maxSize = isStackCall ? MAX_STRUCT_SIZE_IN_BITS : MAX_SUBROUTINE_STRUCT_SIZE_IN_BITS;
179+ if (ty->isPointerTy () && ty->getPointerElementType ()->isStructTy ())
180+ {
181+ return isLegalStructType (M, ty->getPointerElementType (), maxSize);
168182 }
169183 return false ;
170184}
@@ -257,16 +271,16 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
257271 auto ei = pFunc->arg_end ();
258272
259273 // Create the new function signature by replacing the illegal types
260- if (isStackCall && !isLegalReturnType (pFunc->getReturnType ()))
261- {
262- legalizeReturnType = true ;
263- argTypes.push_back (PointerType::get (pFunc->getReturnType (), 0 ));
264- }
265- else if (FunctionHasPromotableSRetArg (M, pFunc))
274+ if (FunctionHasPromotableSRetArg (M, pFunc))
266275 {
267276 promoteSRetType = true ;
268277 ai++; // Skip adding the first arg
269278 }
279+ else if (!isLegalSignatureType (M, pFunc->getReturnType (), isStackCall))
280+ {
281+ legalizeReturnType = true ;
282+ argTypes.push_back (PointerType::get (pFunc->getReturnType (), 0 ));
283+ }
270284
271285 for (; ai != ei; ai++)
272286 {
@@ -281,6 +295,11 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
281295 fixArgType = true ;
282296 argTypes.push_back (PromotedStructValueType (M, ai->getType ()));
283297 }
298+ else if (!isLegalSignatureType (M, ai->getType (), isStackCall))
299+ {
300+ fixArgType = true ;
301+ argTypes.push_back (PointerType::get (ai->getType (), 0 ));
302+ }
284303 else
285304 {
286305 argTypes.push_back (ai->getType ());
@@ -335,14 +354,15 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
335354 bool promoteSRetType = false ;
336355 bool isStackCall = pFunc->hasFnAttribute (" visaStackCall" );
337356 Value* tempAllocaForSRetPointer = nullptr ;
357+ llvm::SmallVector<llvm::Argument*, 8 > ArgByVal;
338358
339- if (isStackCall && !isLegalReturnType (pFunc->getReturnType ())) {
359+ if (FunctionHasPromotableSRetArg (M, pFunc)) {
360+ promoteSRetType = true ;
361+ }
362+ else if (!isLegalSignatureType (M, pFunc->getReturnType (), isStackCall)) {
340363 legalizeReturnType = true ;
341364 ++NewArgIt; // Skip first argument that we added.
342365 }
343- else if (FunctionHasPromotableSRetArg (M, pFunc)) {
344- promoteSRetType = true ;
345- }
346366
347367 // Fix the usages of arguments that have changed
348368 BasicBlock* EntryBB = BasicBlock::Create (M.getContext (), " " , pNewFunc);
@@ -374,6 +394,13 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
374394 StoreToStruct (builder, &*NewArgIt, newArgPtr);
375395 VMap[&*OldArgIt] = newArgPtr;
376396 }
397+ else if (!isLegalSignatureType (M, OldArgIt->getType (), isStackCall))
398+ {
399+ // Load from pointer arg
400+ Value* load = builder.CreateLoad (&*NewArgIt);
401+ VMap[&*OldArgIt] = load;
402+ ArgByVal.push_back (&*NewArgIt);
403+ }
377404 else
378405 {
379406 // No change, map old arg to new arg
@@ -390,13 +417,20 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
390417 builder.CreateBr (ClonedEntryBB);
391418 MergeBlockIntoPredecessor (ClonedEntryBB);
392419
420+ // Loop through new args and add 'byval' attributes
421+ for (auto arg : ArgByVal)
422+ {
423+ arg->addAttr (llvm::Attribute::ByVal);
424+ }
425+
393426 // Now fix the return values
394427 if (legalizeReturnType)
395428 {
396- // Add " noalias" and " sret" to the return argument
429+ // Add the ' noalias' and ' sret' attribute to arg0
397430 auto retArg = pNewFunc->arg_begin ();
398431 retArg->addAttr (llvm::Attribute::NoAlias);
399432 retArg->addAttr (llvm::Attribute::StructRet);
433+
400434 // Loop through all return instructions and store the old return value into the arg0 pointer
401435 const auto ptrSize = DL.getPointerSize ();
402436 for (auto RetInst : Returns)
@@ -499,7 +533,15 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
499533
500534 // Check return type
501535 Value* returnPtr = nullptr ;
502- if (isStackCall && !isLegalReturnType (callInst->getType ()))
536+ if (callInst->getType ()->isVoidTy () &&
537+ callInst->getNumArgOperands () > 0 &&
538+ callInst->paramHasAttr (0 , llvm::Attribute::StructRet) &&
539+ isPromotableStructType (M, callInst->getArgOperand (0 )->getType (), isStackCall, true /* retval */ ))
540+ {
541+ opNum++; // Skip the first call operand
542+ promoteSRetType = true ;
543+ }
544+ else if (!isLegalSignatureType (M, callInst->getType (), isStackCall))
503545 {
504546 // Create an alloca for the return type
505547 IGCLLVM::IRBuilder<> builder (callInst);
@@ -512,14 +554,6 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
512554 ArgAttrVec.push_back (retAttrib);
513555 legalizeReturnType = true ;
514556 }
515- else if (callInst->getType ()->isVoidTy () &&
516- callInst->getNumArgOperands () > 0 &&
517- callInst->paramHasAttr (0 , llvm::Attribute::StructRet) &&
518- isPromotableStructType (M, callInst->getArgOperand (0 )->getType (), isStackCall, true /* retval */ ))
519- {
520- opNum++; // Skip the first call operand
521- promoteSRetType = true ;
522- }
523557
524558 // Check call operands if it needs to be replaced
525559 for (; opNum < callInst->getNumArgOperands (); opNum++)
@@ -544,6 +578,18 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
544578 ArgAttrVec.push_back (AttributeSet ());
545579 fixArgType = true ;
546580 }
581+ else if (!isLegalSignatureType (M, arg->getType (), isStackCall))
582+ {
583+ // Create and store operand as an alloca, then pass as argument
584+ IGCLLVM::IRBuilder<> builder (callInst);
585+ Value* allocaV = builder.CreateAlloca (arg->getType ());
586+ builder.CreateStore (arg, allocaV);
587+ callArgs.push_back (allocaV);
588+ AttributeSet argAttrib;
589+ argAttrib = argAttrib.addAttribute (M.getContext (), llvm::Attribute::ByVal);
590+ ArgAttrVec.push_back (argAttrib);
591+ fixArgType = true ;
592+ }
547593 else
548594 {
549595 // legal argument
0 commit comments