Skip to content

Commit 375018b

Browse files
krystian-andrzejewskiigcbot
authored andcommitted
Replacing usages of getNonOpaquePtrEltTy in AdaptorCommon - part 1
This change set is to prepare for removing dependencies on references to non-opaque pointers.
1 parent 9fd3398 commit 375018b

File tree

1 file changed

+86
-39
lines changed

1 file changed

+86
-39
lines changed

IGC/AdaptorCommon/LegalizeFunctionSignatures.cpp

Lines changed: 86 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -175,15 +175,27 @@ inline bool isLegalSignatureType(const Module& M, Type* ty, bool isStackCall)
175175
}
176176

177177
// Check if a struct pointer argument is promotable to pass-by-value
178-
inline bool isPromotableStructType(const Module& M, const Type* ty, bool isStackCall, bool isReturnValue = false)
178+
inline bool isPromotableStructType(const Module& M, const Argument* arg, bool isStackCall, bool isReturnValue = false)
179179
{
180180
if (IGC_IS_FLAG_DISABLED(EnableByValStructArgPromotion))
181181
return false;
182182

183183
const unsigned int maxSize = isStackCall ? MAX_STRUCT_SIZE_IN_BITS : MAX_SUBROUTINE_STRUCT_SIZE_IN_BITS;
184-
if (ty->isPointerTy() && IGCLLVM::getNonOpaquePtrEltTy(ty)->isStructTy())
184+
llvm::Type* structType = nullptr;
185+
if (arg->getType()->isPointerTy())
185186
{
186-
return isLegalStructType(M, IGCLLVM::getNonOpaquePtrEltTy(ty), maxSize);
187+
if (arg->hasStructRetAttr() && arg->getParamStructRetType()->isStructTy())
188+
{
189+
structType = arg->getParamStructRetType();
190+
}
191+
else if (arg->hasByValAttr() && arg->getParamByValType()->isStructTy())
192+
{
193+
structType = arg->getParamByValType();
194+
}
195+
}
196+
if (structType)
197+
{
198+
return isLegalStructType(M, structType, maxSize);
187199
}
188200
return false;
189201
}
@@ -193,23 +205,33 @@ inline bool FunctionHasPromotableSRetArg(const Module& M, const Function* F)
193205
{
194206
if (F->getReturnType()->isVoidTy() &&
195207
!F->arg_empty() &&
196-
F->arg_begin()->hasStructRetAttr() &&
197-
isPromotableStructType(M, F->arg_begin()->getType(), F->hasFnAttribute("visaStackCall"), true))
208+
isPromotableStructType(M, F->arg_begin(), F->hasFnAttribute("visaStackCall"), true))
198209
{
199210
return true;
200211
}
201212
return false;
202213
}
203214

204215
// Promotes struct pointer to struct type
205-
inline Type* PromotedStructValueType(const Module& M, const Type* ty)
216+
inline StructType* PromotedStructValueType(const Module& M, const Argument* arg)
206217
{
207-
IGC_ASSERT(ty->isPointerTy() && IGCLLVM::getNonOpaquePtrEltTy(ty)->isStructTy());
208-
return cast<StructType>(IGCLLVM::getNonOpaquePtrEltTy(ty));
218+
if (arg->getType()->isPointerTy())
219+
{
220+
if (arg->hasStructRetAttr() && arg->getParamStructRetType()->isStructTy())
221+
{
222+
return cast<StructType>(arg->getParamStructRetType());
223+
}
224+
else if (arg->hasByValAttr() && arg->getParamByValType()->isStructTy())
225+
{
226+
return cast<StructType>(arg->getParamByValType());
227+
}
228+
}
229+
IGC_ASSERT_MESSAGE(0, "Not implemented case");
230+
return nullptr;
209231
}
210232

211233
// BE does not handle struct load/store, so instead store each element of the struct value to the GEP of the struct pointer
212-
inline void StoreToStruct(IGCLLVM::IRBuilder<>& builder, Value* strVal, Value* strPtr)
234+
inline void StoreToStruct(IGCLLVM::IRBuilder<>& builder, Value* strVal, AllocaInst* strPtr)
213235
{
214236
IGC_ASSERT(strPtr->getType()->isPointerTy());
215237
IGC_ASSERT(strVal->getType()->isStructTy());
@@ -218,12 +240,45 @@ inline void StoreToStruct(IGCLLVM::IRBuilder<>& builder, Value* strVal, Value* s
218240
for (unsigned i = 0; i < sTy->getNumElements(); i++)
219241
{
220242
Value* indices[] = { builder.getInt32(0), builder.getInt32(i) };
221-
Value* elementPtr = builder.CreateInBoundsGEP(strPtr, indices);
243+
Value* elementPtr = builder.CreateInBoundsGEP(strPtr->getAllocatedType(), strPtr, indices);
222244
Value* element = builder.CreateExtractValue(strVal, i);
223245
builder.CreateStore(element, elementPtr);
224246
}
225247
}
226248

249+
// BE does not handle struct load/store, so instead store each element of the struct value to the GEP of the struct pointer
250+
inline void StoreToStruct(IGCLLVM::IRBuilder<>& builder, Value* strVal, Argument* strPtr)
251+
{
252+
IGC_ASSERT(strPtr->getType()->isPointerTy());
253+
IGC_ASSERT(strVal->getType()->isStructTy());
254+
if (strPtr->hasStructRetAttr() && strPtr->getParamStructRetType()->isStructTy())
255+
{
256+
StructType* sTy = cast<StructType>(strVal->getType());
257+
for (unsigned i = 0; i < sTy->getNumElements(); i++)
258+
{
259+
Value* indices[] = { builder.getInt32(0), builder.getInt32(i) };
260+
Value* elementPtr = builder.CreateInBoundsGEP(strPtr->getParamStructRetType(), strPtr, indices);
261+
Value* element = builder.CreateExtractValue(strVal, i);
262+
builder.CreateStore(element, elementPtr);
263+
}
264+
}
265+
else if (strPtr->hasByValAttr() && strPtr->getParamByValType()->isStructTy())
266+
{
267+
StructType* sTy = cast<StructType>(strVal->getType());
268+
for (unsigned i = 0; i < sTy->getNumElements(); i++)
269+
{
270+
Value* indices[] = { builder.getInt32(0), builder.getInt32(i) };
271+
Value* elementPtr = builder.CreateInBoundsGEP(strPtr->getParamByValType(), strPtr, indices);
272+
Value* element = builder.CreateExtractValue(strVal, i);
273+
builder.CreateStore(element, elementPtr);
274+
}
275+
}
276+
else
277+
{
278+
IGC_ASSERT_MESSAGE(0, "Unsupported case: no information about the pointee type");
279+
}
280+
}
281+
227282
// BE does not handle struct load/store, so instead load each element from the GEP struct pointer and insert it into the struct value
228283
inline Value* LoadFromStruct(IGCLLVM::IRBuilder<>& builder, Value* strPtr, Type* ty)
229284
{
@@ -235,7 +290,7 @@ inline Value* LoadFromStruct(IGCLLVM::IRBuilder<>& builder, Value* strPtr, Type*
235290
for (unsigned i = 0; i < sTy->getNumElements(); i++)
236291
{
237292
Value* indices[] = { builder.getInt32(0), builder.getInt32(i) };
238-
Value* elementPtr = builder.CreateInBoundsGEP(strPtr, indices);
293+
Value* elementPtr = builder.CreateInBoundsGEP(ty, strPtr, indices);
239294
Value* element = builder.CreateLoad(sTy->getElementType(i), elementPtr);
240295
strVal = builder.CreateInsertValue(strVal, element, i);
241296
}
@@ -308,10 +363,10 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
308363
argTypes.push_back(LegalizedIntVectorType(M, ai->getType()));
309364
}
310365
else if (ai->hasByValAttr() &&
311-
isPromotableStructType(M, ai->getType(), isStackCall))
366+
isPromotableStructType(M, ai, isStackCall))
312367
{
313368
fixArgType = true;
314-
argTypes.push_back(PromotedStructValueType(M, ai->getType()));
369+
argTypes.push_back(PromotedStructValueType(M, ai));
315370
}
316371
else if (!isLegalSignatureType(M, ai->getType(), isStackCall))
317372
{
@@ -329,7 +384,7 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
329384
// Clone function with new signature
330385
Type* returnType =
331386
retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy(M.getContext()) :
332-
retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType(M, pFunc->arg_begin()->getType()) :
387+
retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType(M, pFunc->arg_begin()) :
333388
retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType(M, pFunc->getReturnType()) :
334389
pFunc->getReturnType();
335390
FunctionType* signature = FunctionType::get(returnType, argTypes, false);
@@ -393,7 +448,7 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
393448
if (OldArgIt == pFunc->arg_begin() && retTypeOption == ReturnOpt::RETURN_STRUCT)
394449
{
395450
// Create a temp alloca to map the old argument. This will be removed later by SROA.
396-
tempAllocaForSRetPointerTy = PromotedStructValueType(M, OldArgIt->getType());
451+
tempAllocaForSRetPointerTy = PromotedStructValueType(M, OldArgIt);
397452
tempAllocaForSRetPointer = builder.CreateAlloca(tempAllocaForSRetPointerTy);
398453
tempAllocaForSRetPointer = builder.CreateAddrSpaceCast(tempAllocaForSRetPointer, OldArgIt->getType());
399454
VMap[&*OldArgIt] = tempAllocaForSRetPointer;
@@ -408,24 +463,25 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
408463
VMap[&*OldArgIt] = trunc;
409464
}
410465
else if (OldArgIt->hasByValAttr() &&
411-
isPromotableStructType(M, OldArgIt->getType(), isStackCall))
466+
isPromotableStructType(M, OldArgIt, isStackCall))
412467
{
468+
AllocaInst* newArgPtr = builder.CreateAlloca(OldArgIt->getParamByValType());
413469
// remove "byval" attrib since it is now pass-by-value
414470
NewArgIt->removeAttr(llvm::Attribute::ByVal);
415-
Value* newArgPtr = builder.CreateAlloca(NewArgIt->getType());
416471
StoreToStruct(builder, &*NewArgIt, newArgPtr);
417472
// cast back to original addrspace
418473
IGC_ASSERT(OldArgIt->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GENERIC ||
419474
OldArgIt->getType()->getPointerAddressSpace() == ADDRESS_SPACE_PRIVATE);
420-
newArgPtr = builder.CreateAddrSpaceCast(newArgPtr, OldArgIt->getType());
421-
VMap[&*OldArgIt] = newArgPtr;
475+
llvm::Value* castedNewArgPtr = builder.CreateAddrSpaceCast(newArgPtr, OldArgIt->getType());
476+
VMap[&*OldArgIt] = castedNewArgPtr;
422477
}
423478
else if (!isLegalSignatureType(M, OldArgIt->getType(), isStackCall))
424479
{
425480
// Load from pointer arg
426-
Value* load = builder.CreateLoad(&*NewArgIt);
481+
Value* load = builder.CreateLoad(OldArgIt->getType(), &*NewArgIt);
427482
VMap[&*OldArgIt] = load;
428-
ArgByVal.push_back(&*NewArgIt);
483+
llvm::Attribute byValAttr = llvm::Attribute::getWithByValType(M.getContext(), OldArgIt->getType());
484+
NewArgIt->addAttr(byValAttr);
429485
}
430486
else
431487
{
@@ -444,21 +500,13 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
444500
builder.CreateBr(ClonedEntryBB);
445501
MergeBlockIntoPredecessor(ClonedEntryBB);
446502

447-
// Loop through new args and add 'byval' attributes
448-
for (auto arg : ArgByVal)
449-
{
450-
arg->addAttr(llvm::Attribute::getWithByValType(M.getContext(),
451-
IGCLLVM::getNonOpaquePtrEltTy(arg->getType())));
452-
}
453-
454503
// Now fix the return values
455504
if (retTypeOption == ReturnOpt::RETURN_BY_REF)
456505
{
457506
// Add the 'noalias' and 'sret' attribute to arg0
458507
auto retArg = pNewFunc->arg_begin();
459508
retArg->addAttr(llvm::Attribute::NoAlias);
460-
retArg->addAttr(llvm::Attribute::getWithStructRetType(
461-
M.getContext(), IGCLLVM::getNonOpaquePtrEltTy(retArg->getType())));
509+
retArg->addAttr(llvm::Attribute::getWithStructRetType(M.getContext(), pFunc->getReturnType()));
462510

463511
// Loop through all return instructions and store the old return value into the arg0 pointer
464512
const auto ptrSize = DL.getPointerSize();
@@ -577,7 +625,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
577625
if (callInst->getType()->isVoidTy() &&
578626
IGCLLVM::getNumArgOperands(callInst) > 0 &&
579627
callInst->paramHasAttr(0, llvm::Attribute::StructRet) &&
580-
isPromotableStructType(M, callInst->getArgOperand(0)->getType(), isStackCall, true /* retval */))
628+
isPromotableStructType(M, callInst->getCalledFunction()->getArg(0), isStackCall, true /* retval */))
581629
{
582630
opNum++; // Skip the first call operand
583631
retTypeOption = ReturnOpt::RETURN_STRUCT;
@@ -603,23 +651,22 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
603651
// Check call operands if it needs to be replaced
604652
for (; opNum < IGCLLVM::getNumArgOperands(callInst); opNum++)
605653
{
606-
Value* arg = callInst->getArgOperand(opNum);
654+
Argument* arg = IGCLLVM::getArg(*calledFunc, opNum);
607655
if (!isLegalIntVectorType(M, arg->getType()))
608656
{
609657
// extend the illegal int to a legal type
610658
IGCLLVM::IRBuilder<> builder(callInst);
611-
Value* extend = builder.CreateZExt(arg, LegalizedIntVectorType(M, arg->getType()));
659+
Value* extend = builder.CreateZExt(callInst->getOperand(opNum), LegalizedIntVectorType(M, arg->getType()));
612660
callArgs.push_back(extend);
613661
ArgAttrVec.push_back(AttributeSet());
614662
fixArgType = true;
615663
}
616664
else if (callInst->paramHasAttr(opNum, llvm::Attribute::ByVal) &&
617-
isPromotableStructType(M, arg->getType(), isStackCall))
665+
isPromotableStructType(M, arg, isStackCall))
618666
{
619667
// Map the new operand to the loaded value of the struct pointer
620668
IGCLLVM::IRBuilder<> builder(callInst);
621-
Argument* callArg = IGCLLVM::getArg(*calledFunc, opNum);
622-
Value* newOp = LoadFromStruct(builder, arg, callArg->getParamByValType());
669+
Value* newOp = LoadFromStruct(builder, callInst->getOperand(opNum), arg->getParamByValType());
623670
callArgs.push_back(newOp);
624671
ArgAttrVec.push_back(AttributeSet());
625672
fixArgType = true;
@@ -629,7 +676,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
629676
// Create and store operand as an alloca, then pass as argument
630677
IGCLLVM::IRBuilder<> builder(callInst);
631678
Value* allocaV = builder.CreateAlloca(arg->getType());
632-
builder.CreateStore(arg, allocaV);
679+
builder.CreateStore(callInst->getOperand(opNum), allocaV);
633680
callArgs.push_back(allocaV);
634681
auto byValAttr = llvm::Attribute::getWithByValType(M.getContext(), arg->getType());
635682
auto argAttrs = AttributeSet::get(M.getContext(), { byValAttr });
@@ -659,7 +706,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
659706
}
660707
Type* retType =
661708
retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy(callInst->getContext()) :
662-
retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType(M, callInst->getArgOperand(0)->getType()) :
709+
retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType(M, callInst->getFunction()->getArg(0)) :
663710
retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType(M, callInst->getType()) :
664711
callInst->getType();
665712
newFnTy = FunctionType::get(retType, argTypes, false);
@@ -690,7 +737,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
690737
else if (retTypeOption == ReturnOpt::RETURN_STRUCT)
691738
{
692739
// Store the struct value into the orginal pointer operand
693-
StoreToStruct(builder, newCallInst, callInst->getArgOperand(0));
740+
StoreToStruct(builder, newCallInst, callInst->getCalledFunction()->getArg(0));
694741
}
695742
else if (retTypeOption == ReturnOpt::RETURN_LEGAL_INT)
696743
{

0 commit comments

Comments
 (0)