@@ -90,12 +90,12 @@ bool LegalizeFunctionSignatures::runOnModule(Module& M)
9090 return true ;
9191}
9292
93- inline bool isLegalSignatureType (const Type* ty)
93+ // If the return type size is greater than the allowed size, we convert the return value to pass-by-pointer
94+ inline bool isLegalReturnType (const Type* ty)
9495{
95- // Structs are by default illegal unless they are promotable
96- if (ty->getTypeID () == Type::ArrayTyID) return false ;
97- if (ty->getTypeID () == Type::StructTyID) return false ;
98- return true ;
96+ // check return type size
97+ // return ty->getPrimitiveSizeInBits() <= MAX_RETVAL_SIZE_IN_BITS;
98+ return true ; // allow all return sizes
9999}
100100
101101// Check if an int or int-vector argument type is a power of two
@@ -257,16 +257,16 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
257257 auto ei = pFunc->arg_end ();
258258
259259 // Create the new function signature by replacing the illegal types
260- if (FunctionHasPromotableSRetArg (M, pFunc))
261- {
262- promoteSRetType = true ;
263- ai++; // Skip adding the first arg
264- }
265- else if (isStackCall && !isLegalSignatureType (pFunc->getReturnType ()))
260+ if (isStackCall && !isLegalReturnType (pFunc->getReturnType ()))
266261 {
267262 legalizeReturnType = true ;
268263 argTypes.push_back (PointerType::get (pFunc->getReturnType (), 0 ));
269264 }
265+ else if (FunctionHasPromotableSRetArg (M, pFunc))
266+ {
267+ promoteSRetType = true ;
268+ ai++; // Skip adding the first arg
269+ }
270270
271271 for (; ai != ei; ai++)
272272 {
@@ -281,11 +281,6 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
281281 fixArgType = true ;
282282 argTypes.push_back (PromotedStructValueType (M, ai->getType ()));
283283 }
284- else if (!isLegalSignatureType (ai->getType ()))
285- {
286- fixArgType = true ;
287- argTypes.push_back (PointerType::get (ai->getType (), 0 ));
288- }
289284 else
290285 {
291286 argTypes.push_back (ai->getType ());
@@ -341,13 +336,13 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
341336 bool isStackCall = pFunc->hasFnAttribute (" visaStackCall" );
342337 Value* tempAllocaForSRetPointer = nullptr ;
343338
344- if (FunctionHasPromotableSRetArg (M, pFunc)) {
345- promoteSRetType = true ;
346- }
347- else if (isStackCall && !isLegalSignatureType (pFunc->getReturnType ())) {
339+ if (isStackCall && !isLegalReturnType (pFunc->getReturnType ())) {
348340 legalizeReturnType = true ;
349341 ++NewArgIt; // Skip first argument that we added.
350342 }
343+ else if (FunctionHasPromotableSRetArg (M, pFunc)) {
344+ promoteSRetType = true ;
345+ }
351346
352347 // Fix the usages of arguments that have changed
353348 BasicBlock* EntryBB = BasicBlock::Create (M.getContext (), " " , pNewFunc);
@@ -379,12 +374,6 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
379374 StoreToStruct (builder, &*NewArgIt, newArgPtr);
380375 VMap[&*OldArgIt] = newArgPtr;
381376 }
382- else if (!isLegalSignatureType (OldArgIt->getType ()))
383- {
384- // Load from pointer arg
385- Value* load = builder.CreateLoad (&*NewArgIt);
386- VMap[&*OldArgIt] = load;
387- }
388377 else
389378 {
390379 // No change, map old arg to new arg
@@ -401,27 +390,13 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
401390 builder.CreateBr (ClonedEntryBB);
402391 MergeBlockIntoPredecessor (ClonedEntryBB);
403392
404- // Loop through new args and add attributes
405- for (auto oldIt = pFunc->arg_begin (), it = pNewFunc->arg_begin (), ie = pNewFunc->arg_end (); it != ie; it++)
406- {
407- if (legalizeReturnType && it == pNewFunc->arg_begin ())
408- {
409- it->addAttr (llvm::Attribute::NoAlias);
410- it->addAttr (llvm::Attribute::StructRet);
411- continue ;
412- }
413- else if (it->getType ()->isPointerTy () &&
414- oldIt->getType () != it->getType () &&
415- !isLegalSignatureType (oldIt->getType ()))
416- {
417- it->addAttr (llvm::Attribute::ByVal);
418- }
419- oldIt++;
420- }
421-
422393 // Now fix the return values
423394 if (legalizeReturnType)
424395 {
396+ // Add "noalias" and "sret" to the return argument
397+ auto retArg = pNewFunc->arg_begin ();
398+ retArg->addAttr (llvm::Attribute::NoAlias);
399+ retArg->addAttr (llvm::Attribute::StructRet);
425400 // Loop through all return instructions and store the old return value into the arg0 pointer
426401 const auto ptrSize = DL.getPointerSize ();
427402 for (auto RetInst : Returns)
@@ -524,15 +499,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
524499
525500 // Check return type
526501 Value* returnPtr = nullptr ;
527- if (callInst->getType ()->isVoidTy () &&
528- callInst->getNumArgOperands () > 0 &&
529- callInst->paramHasAttr (0 , llvm::Attribute::StructRet) &&
530- isPromotableStructType (M, callInst->getArgOperand (0 )->getType (), isStackCall, true /* retval */ ))
531- {
532- opNum++; // Skip the first call operand
533- promoteSRetType = true ;
534- }
535- else if (isStackCall && !isLegalSignatureType (callInst->getType ()))
502+ if (isStackCall && !isLegalReturnType (callInst->getType ()))
536503 {
537504 // Create an alloca for the return type
538505 IGCLLVM::IRBuilder<> builder (callInst);
@@ -545,6 +512,14 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
545512 ArgAttrVec.push_back (retAttrib);
546513 legalizeReturnType = true ;
547514 }
515+ else if (callInst->getType ()->isVoidTy () &&
516+ callInst->getNumArgOperands () > 0 &&
517+ callInst->paramHasAttr (0 , llvm::Attribute::StructRet) &&
518+ isPromotableStructType (M, callInst->getArgOperand (0 )->getType (), isStackCall, true /* retval */ ))
519+ {
520+ opNum++; // Skip the first call operand
521+ promoteSRetType = true ;
522+ }
548523
549524 // Check call operands if it needs to be replaced
550525 for (; opNum < callInst->getNumArgOperands (); opNum++)
@@ -569,18 +544,6 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
569544 ArgAttrVec.push_back (AttributeSet ());
570545 fixArgType = true ;
571546 }
572- else if (!isLegalSignatureType (arg->getType ()))
573- {
574- // Create and store operand as an alloca, then pass as argument
575- IGCLLVM::IRBuilder<> builder (callInst);
576- Value* alloca = builder.CreateAlloca (arg->getType ());
577- builder.CreateStore (arg, alloca);
578- callArgs.push_back (alloca);
579- AttributeSet argAttrib;
580- argAttrib = argAttrib.addAttribute (M.getContext (), llvm::Attribute::ByVal);
581- ArgAttrVec.push_back (argAttrib);
582- fixArgType = true ;
583- }
584547 else
585548 {
586549 // legal argument
0 commit comments