Skip to content

Commit 68f38ec

Browse files
dlei6gigcbot
authored andcommitted
Add support for illegal stackcall args/retval pass-by-pointer
Support passing arrays and matrices (which are a special struct type) as pointers.
1 parent 9454535 commit 68f38ec

File tree

1 file changed

+96
-50
lines changed

1 file changed

+96
-50
lines changed

IGC/AdaptorCommon/LegalizeFunctionSignatures.cpp

Lines changed: 96 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,6 @@ bool LegalizeFunctionSignatures::runOnModule(Module& M)
9090
return true;
9191
}
9292

93-
// If the return type size is greater than the allowed size, we convert the return value to pass-by-pointer
94-
inline bool isLegalReturnType(const Type* ty)
95-
{
96-
// check return type size
97-
//return ty->getPrimitiveSizeInBits() <= MAX_RETVAL_SIZE_IN_BITS;
98-
return true; // allow all return sizes
99-
}
100-
10193
// Check if an int or int-vector argument type is a power of two
10294
inline bool isLegalIntVectorType(const Module& M, Type* ty)
10395
{
@@ -137,34 +129,56 @@ inline Type* LegalizedIntVectorType(const Module& M, Type* ty)
137129
IGCLLVM::FixedVectorType::get(IntegerType::get(M.getContext(), newSize), (unsigned)cast<IGCLLVM::FixedVectorType>(ty)->getNumElements());
138130
}
139131

140-
// Returns true for small structures that only contain primitive types
141-
inline bool isPromotableStructType(const Module& M, const Type* ty, bool isStackCall, bool isReturnValue = false)
132+
// Returns true for structs smaller than 'structSize' and only contains primitive types
133+
inline bool isLegalStructType(const Module& M, Type* ty, unsigned structSize)
142134
{
143-
if (IGC_IS_FLAG_DISABLED(EnableByValStructArgPromotion))
144-
return false;
145-
146-
const unsigned int maxSize = isStackCall ? MAX_STRUCT_SIZE_IN_BITS
147-
: MAX_SUBROUTINE_STRUCT_SIZE_IN_BITS;
148-
135+
IGC_ASSERT(ty->isStructTy());
149136
const DataLayout& DL = M.getDataLayout();
137+
StructType* sTy = dyn_cast<StructType>(ty);
138+
if (sTy && DL.getStructLayout(sTy)->getSizeInBits() <= structSize)
139+
{
140+
for (const auto* EltTy : sTy->elements())
141+
{
142+
// Check if all elements are primitive types
143+
if (!EltTy->isSingleValueType() || EltTy->isVectorTy())
144+
return false;
145+
// Avoid int64 and fp64 because of unimplemented InstExpander::visitInsertValue
146+
// and InstExpander::visitExtractValue in the Emu64Ops pass.
147+
if (EltTy->isIntegerTy(64) || EltTy->isDoubleTy())
148+
return false;
149+
}
150+
return true;
151+
}
152+
return false;
153+
}
150154

151-
if (ty->isPointerTy())
155+
inline bool isLegalSignatureType(const Module& M, Type* ty, bool isStackCall)
156+
{
157+
if (isStackCall)
152158
{
153-
StructType* sTy = dyn_cast<StructType>(ty->getPointerElementType());
154-
if (sTy && DL.getStructLayout(sTy)->getSizeInBits() <= maxSize)
159+
if (ty->isStructTy())
155160
{
156-
for (const auto* EltTy : sTy->elements())
157-
{
158-
// Check if all elements are primitive types
159-
if (!EltTy->isSingleValueType() || EltTy->isVectorTy())
160-
return false;
161-
// Avoid int64 and fp64 because of unimplemented InstExpander::visitInsertValue
162-
// and InstExpander::visitExtractValue in the Emu64Ops pass.
163-
if (EltTy->isIntegerTy(64) || EltTy->isDoubleTy())
164-
return false;
165-
}
166-
return true;
161+
return isLegalStructType(M, ty, MAX_STRUCT_SIZE_IN_BITS);
167162
}
163+
else if (ty->isArrayTy())
164+
{
165+
return false;
166+
}
167+
}
168+
// Are all subroutine types legal?
169+
return true;
170+
}
171+
172+
// Check if a struct pointer argument is promotable to pass-by-value
173+
inline bool isPromotableStructType(const Module& M, const Type* ty, bool isStackCall, bool isReturnValue = false)
174+
{
175+
if (IGC_IS_FLAG_DISABLED(EnableByValStructArgPromotion))
176+
return false;
177+
178+
const unsigned int maxSize = isStackCall ? MAX_STRUCT_SIZE_IN_BITS : MAX_SUBROUTINE_STRUCT_SIZE_IN_BITS;
179+
if (ty->isPointerTy() && ty->getPointerElementType()->isStructTy())
180+
{
181+
return isLegalStructType(M, ty->getPointerElementType(), maxSize);
168182
}
169183
return false;
170184
}
@@ -257,16 +271,16 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
257271
auto ei = pFunc->arg_end();
258272

259273
// Create the new function signature by replacing the illegal types
260-
if (isStackCall && !isLegalReturnType(pFunc->getReturnType()))
261-
{
262-
legalizeReturnType = true;
263-
argTypes.push_back(PointerType::get(pFunc->getReturnType(), 0));
264-
}
265-
else if (FunctionHasPromotableSRetArg(M, pFunc))
274+
if (FunctionHasPromotableSRetArg(M, pFunc))
266275
{
267276
promoteSRetType = true;
268277
ai++; // Skip adding the first arg
269278
}
279+
else if (!isLegalSignatureType(M, pFunc->getReturnType(), isStackCall))
280+
{
281+
legalizeReturnType = true;
282+
argTypes.push_back(PointerType::get(pFunc->getReturnType(), 0));
283+
}
270284

271285
for (; ai != ei; ai++)
272286
{
@@ -281,6 +295,11 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
281295
fixArgType = true;
282296
argTypes.push_back(PromotedStructValueType(M, ai->getType()));
283297
}
298+
else if (!isLegalSignatureType(M, ai->getType(), isStackCall))
299+
{
300+
fixArgType = true;
301+
argTypes.push_back(PointerType::get(ai->getType(), 0));
302+
}
284303
else
285304
{
286305
argTypes.push_back(ai->getType());
@@ -335,14 +354,15 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
335354
bool promoteSRetType = false;
336355
bool isStackCall = pFunc->hasFnAttribute("visaStackCall");
337356
Value* tempAllocaForSRetPointer = nullptr;
357+
llvm::SmallVector<llvm::Argument*, 8> ArgByVal;
338358

339-
if (isStackCall && !isLegalReturnType(pFunc->getReturnType())) {
359+
if (FunctionHasPromotableSRetArg(M, pFunc)) {
360+
promoteSRetType = true;
361+
}
362+
else if (!isLegalSignatureType(M, pFunc->getReturnType(), isStackCall)) {
340363
legalizeReturnType = true;
341364
++NewArgIt; // Skip first argument that we added.
342365
}
343-
else if (FunctionHasPromotableSRetArg(M, pFunc)) {
344-
promoteSRetType = true;
345-
}
346366

347367
// Fix the usages of arguments that have changed
348368
BasicBlock* EntryBB = BasicBlock::Create(M.getContext(), "", pNewFunc);
@@ -374,6 +394,13 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
374394
StoreToStruct(builder, &*NewArgIt, newArgPtr);
375395
VMap[&*OldArgIt] = newArgPtr;
376396
}
397+
else if (!isLegalSignatureType(M, OldArgIt->getType(), isStackCall))
398+
{
399+
// Load from pointer arg
400+
Value* load = builder.CreateLoad(&*NewArgIt);
401+
VMap[&*OldArgIt] = load;
402+
ArgByVal.push_back(&*NewArgIt);
403+
}
377404
else
378405
{
379406
// No change, map old arg to new arg
@@ -390,13 +417,20 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
390417
builder.CreateBr(ClonedEntryBB);
391418
MergeBlockIntoPredecessor(ClonedEntryBB);
392419

420+
// Loop through new args and add 'byval' attributes
421+
for (auto arg : ArgByVal)
422+
{
423+
arg->addAttr(llvm::Attribute::ByVal);
424+
}
425+
393426
// Now fix the return values
394427
if (legalizeReturnType)
395428
{
396-
// Add "noalias" and "sret" to the return argument
429+
// Add the 'noalias' and 'sret' attribute to arg0
397430
auto retArg = pNewFunc->arg_begin();
398431
retArg->addAttr(llvm::Attribute::NoAlias);
399432
retArg->addAttr(llvm::Attribute::StructRet);
433+
400434
// Loop through all return instructions and store the old return value into the arg0 pointer
401435
const auto ptrSize = DL.getPointerSize();
402436
for (auto RetInst : Returns)
@@ -499,7 +533,15 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
499533

500534
// Check return type
501535
Value* returnPtr = nullptr;
502-
if (isStackCall && !isLegalReturnType(callInst->getType()))
536+
if (callInst->getType()->isVoidTy() &&
537+
callInst->getNumArgOperands() > 0 &&
538+
callInst->paramHasAttr(0, llvm::Attribute::StructRet) &&
539+
isPromotableStructType(M, callInst->getArgOperand(0)->getType(), isStackCall, true /* retval */))
540+
{
541+
opNum++; // Skip the first call operand
542+
promoteSRetType = true;
543+
}
544+
else if (!isLegalSignatureType(M, callInst->getType(), isStackCall))
503545
{
504546
// Create an alloca for the return type
505547
IGCLLVM::IRBuilder<> builder(callInst);
@@ -512,14 +554,6 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
512554
ArgAttrVec.push_back(retAttrib);
513555
legalizeReturnType = true;
514556
}
515-
else if (callInst->getType()->isVoidTy() &&
516-
callInst->getNumArgOperands() > 0 &&
517-
callInst->paramHasAttr(0, llvm::Attribute::StructRet) &&
518-
isPromotableStructType(M, callInst->getArgOperand(0)->getType(), isStackCall, true /* retval */))
519-
{
520-
opNum++; // Skip the first call operand
521-
promoteSRetType = true;
522-
}
523557

524558
// Check call operands if it needs to be replaced
525559
for (; opNum < callInst->getNumArgOperands(); opNum++)
@@ -544,6 +578,18 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
544578
ArgAttrVec.push_back(AttributeSet());
545579
fixArgType = true;
546580
}
581+
else if (!isLegalSignatureType(M, arg->getType(), isStackCall))
582+
{
583+
// Create and store operand as an alloca, then pass as argument
584+
IGCLLVM::IRBuilder<> builder(callInst);
585+
Value* allocaV = builder.CreateAlloca(arg->getType());
586+
builder.CreateStore(arg, allocaV);
587+
callArgs.push_back(allocaV);
588+
AttributeSet argAttrib;
589+
argAttrib = argAttrib.addAttribute(M.getContext(), llvm::Attribute::ByVal);
590+
ArgAttrVec.push_back(argAttrib);
591+
fixArgType = true;
592+
}
547593
else
548594
{
549595
// legal argument

0 commit comments

Comments
 (0)