Skip to content

Commit ea0af73

Browse files
krystian-andrzejewskiigcbot
authored andcommitted
Replacing usages of getNonOpaquePtrEltTy in AdaptorCommon - part 1
This change set is to prepare for removing dependencies on references to non-opaque pointers.
1 parent ed91167 commit ea0af73

File tree

5 files changed

+500
-43
lines changed

5 files changed

+500
-43
lines changed

IGC/AdaptorCommon/LegalizeFunctionSignatures.cpp

Lines changed: 41 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,9 @@ inline Type* LegalizedIntVectorType(const Module& M, Type* ty)
139139
}
140140

141141
// Returns true for structs smaller than 'structSize' and only contains primitive types
142-
inline bool isLegalStructType(const Module& M, Type* ty, unsigned structSize)
142+
inline bool isLegalStructType(const Module& M, StructType* sTy, unsigned structSize)
143143
{
144-
IGC_ASSERT(ty->isStructTy());
145144
const DataLayout& DL = M.getDataLayout();
146-
StructType* sTy = dyn_cast<StructType>(ty);
147145
if (sTy && DL.getStructLayout(sTy)->getSizeInBits() <= structSize)
148146
{
149147
for (const auto* EltTy : sTy->elements())
@@ -163,7 +161,7 @@ inline bool isLegalSignatureType(const Module& M, Type* ty, bool isStackCall)
163161
{
164162
if (ty->isStructTy())
165163
{
166-
return isLegalStructType(M, ty, MAX_STRUCT_SIZE_IN_BITS);
164+
return isLegalStructType(M, cast<StructType>(ty), MAX_STRUCT_SIZE_IN_BITS);
167165
}
168166
else if (ty->isArrayTy())
169167
{
@@ -174,16 +172,14 @@ inline bool isLegalSignatureType(const Module& M, Type* ty, bool isStackCall)
174172
return true;
175173
}
176174

177-
// Check if a struct pointer argument is promotable to pass-by-value
178-
inline bool isPromotableStructType(const Module& M, const Type* ty, bool isStackCall, bool isReturnValue = false)
175+
inline bool isPromotableStructType(const Module& M, Type* pointeeType, bool isStackCall)
179176
{
180177
if (IGC_IS_FLAG_DISABLED(EnableByValStructArgPromotion))
181178
return false;
182-
183179
const unsigned int maxSize = isStackCall ? MAX_STRUCT_SIZE_IN_BITS : MAX_SUBROUTINE_STRUCT_SIZE_IN_BITS;
184-
if (ty->isPointerTy() && IGCLLVM::getNonOpaquePtrEltTy(ty)->isStructTy())
180+
if (isa<StructType>(pointeeType))
185181
{
186-
return isLegalStructType(M, IGCLLVM::getNonOpaquePtrEltTy(ty), maxSize);
182+
return isLegalStructType(M, cast<StructType>(pointeeType), maxSize);
187183
}
188184
return false;
189185
}
@@ -194,18 +190,29 @@ inline bool FunctionHasPromotableSRetArg(const Module& M, const Function* F)
194190
if (F->getReturnType()->isVoidTy() &&
195191
!F->arg_empty() &&
196192
F->arg_begin()->hasStructRetAttr() &&
197-
isPromotableStructType(M, F->arg_begin()->getType(), F->hasFnAttribute("visaStackCall"), true))
193+
isPromotableStructType(M, F->arg_begin()->getParamStructRetType(), F->hasFnAttribute("visaStackCall")))
198194
{
199195
return true;
200196
}
201197
return false;
202198
}
203199

204200
// Promotes struct pointer to struct type
205-
inline Type* PromotedStructValueType(const Module& M, const Type* ty)
201+
inline StructType* PromotedStructValueType(const Module& M, const Argument* arg)
206202
{
207-
IGC_ASSERT(ty->isPointerTy() && IGCLLVM::getNonOpaquePtrEltTy(ty)->isStructTy());
208-
return cast<StructType>(IGCLLVM::getNonOpaquePtrEltTy(ty));
203+
if (arg->getType()->isPointerTy())
204+
{
205+
if (arg->hasStructRetAttr() && arg->getParamStructRetType()->isStructTy())
206+
{
207+
return cast<StructType>(arg->getParamStructRetType());
208+
}
209+
else if (arg->hasByValAttr() && arg->getParamByValType()->isStructTy())
210+
{
211+
return cast<StructType>(arg->getParamByValType());
212+
}
213+
}
214+
IGC_ASSERT_MESSAGE(0, "Not implemented case");
215+
return nullptr;
209216
}
210217

211218
// BE does not handle struct load/store, so instead store each element of the struct value to the GEP of the struct pointer
@@ -218,7 +225,7 @@ inline void StoreToStruct(IGCLLVM::IRBuilder<>& builder, Value* strVal, Value* s
218225
for (unsigned i = 0; i < sTy->getNumElements(); i++)
219226
{
220227
Value* indices[] = { builder.getInt32(0), builder.getInt32(i) };
221-
Value* elementPtr = builder.CreateInBoundsGEP(strPtr, indices);
228+
Value* elementPtr = builder.CreateInBoundsGEP(strVal->getType(), strPtr, indices);
222229
Value* element = builder.CreateExtractValue(strVal, i);
223230
builder.CreateStore(element, elementPtr);
224231
}
@@ -235,7 +242,7 @@ inline Value* LoadFromStruct(IGCLLVM::IRBuilder<>& builder, Value* strPtr, Type*
235242
for (unsigned i = 0; i < sTy->getNumElements(); i++)
236243
{
237244
Value* indices[] = { builder.getInt32(0), builder.getInt32(i) };
238-
Value* elementPtr = builder.CreateInBoundsGEP(strPtr, indices);
245+
Value* elementPtr = builder.CreateInBoundsGEP(ty, strPtr, indices);
239246
Value* element = builder.CreateLoad(sTy->getElementType(i), elementPtr);
240247
strVal = builder.CreateInsertValue(strVal, element, i);
241248
}
@@ -308,10 +315,10 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
308315
argTypes.push_back(LegalizedIntVectorType(M, ai->getType()));
309316
}
310317
else if (ai->hasByValAttr() &&
311-
isPromotableStructType(M, ai->getType(), isStackCall))
318+
isPromotableStructType(M, ai->getParamByValType(), isStackCall))
312319
{
313320
fixArgType = true;
314-
argTypes.push_back(PromotedStructValueType(M, ai->getType()));
321+
argTypes.push_back(PromotedStructValueType(M, ai));
315322
}
316323
else if (!isLegalSignatureType(M, ai->getType(), isStackCall))
317324
{
@@ -329,7 +336,7 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
329336
// Clone function with new signature
330337
Type* returnType =
331338
retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy(M.getContext()) :
332-
retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType(M, pFunc->arg_begin()->getType()) :
339+
retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType(M, pFunc->arg_begin()) :
333340
retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType(M, pFunc->getReturnType()) :
334341
pFunc->getReturnType();
335342
FunctionType* signature = FunctionType::get(returnType, argTypes, false);
@@ -393,13 +400,12 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
393400
if (OldArgIt == pFunc->arg_begin() && retTypeOption == ReturnOpt::RETURN_STRUCT)
394401
{
395402
// Create a temp alloca to map the old argument. This will be removed later by SROA.
396-
tempAllocaForSRetPointerTy = PromotedStructValueType(M, OldArgIt->getType());
403+
tempAllocaForSRetPointerTy = PromotedStructValueType(M, OldArgIt);
397404
tempAllocaForSRetPointer = builder.CreateAlloca(tempAllocaForSRetPointerTy);
398405
tempAllocaForSRetPointer = builder.CreateAddrSpaceCast(tempAllocaForSRetPointer, OldArgIt->getType());
399406
VMap[&*OldArgIt] = tempAllocaForSRetPointer;
400407
continue;
401408
}
402-
403409
NewArgIt->setName(OldArgIt->getName());
404410
if (!isLegalIntVectorType(M, OldArgIt->getType()))
405411
{
@@ -408,24 +414,25 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
408414
VMap[&*OldArgIt] = trunc;
409415
}
410416
else if (OldArgIt->hasByValAttr() &&
411-
isPromotableStructType(M, OldArgIt->getType(), isStackCall))
417+
isPromotableStructType(M, OldArgIt->getParamByValType(), isStackCall))
412418
{
419+
AllocaInst* newArgPtr = builder.CreateAlloca(OldArgIt->getParamByValType());
413420
// remove "byval" attrib since it is now pass-by-value
414421
NewArgIt->removeAttr(llvm::Attribute::ByVal);
415-
Value* newArgPtr = builder.CreateAlloca(NewArgIt->getType());
416422
StoreToStruct(builder, &*NewArgIt, newArgPtr);
417423
// cast back to original addrspace
418424
IGC_ASSERT(OldArgIt->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GENERIC ||
419-
OldArgIt->getType()->getPointerAddressSpace() == ADDRESS_SPACE_PRIVATE);
420-
newArgPtr = builder.CreateAddrSpaceCast(newArgPtr, OldArgIt->getType());
421-
VMap[&*OldArgIt] = newArgPtr;
425+
OldArgIt->getType()->getPointerAddressSpace() == ADDRESS_SPACE_PRIVATE);
426+
llvm::Value* castedNewArgPtr = builder.CreateAddrSpaceCast(newArgPtr, OldArgIt->getType());
427+
VMap[&*OldArgIt] = castedNewArgPtr;
422428
}
423429
else if (!isLegalSignatureType(M, OldArgIt->getType(), isStackCall))
424430
{
425431
// Load from pointer arg
426-
Value* load = builder.CreateLoad(&*NewArgIt);
432+
Value* load = builder.CreateLoad(OldArgIt->getType(), &*NewArgIt);
427433
VMap[&*OldArgIt] = load;
428-
ArgByVal.push_back(&*NewArgIt);
434+
llvm::Attribute byValAttr = llvm::Attribute::getWithByValType(M.getContext(), OldArgIt->getType());
435+
NewArgIt->addAttr(byValAttr);
429436
}
430437
else
431438
{
@@ -444,21 +451,13 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
444451
builder.CreateBr(ClonedEntryBB);
445452
MergeBlockIntoPredecessor(ClonedEntryBB);
446453

447-
// Loop through new args and add 'byval' attributes
448-
for (auto arg : ArgByVal)
449-
{
450-
arg->addAttr(llvm::Attribute::getWithByValType(M.getContext(),
451-
IGCLLVM::getNonOpaquePtrEltTy(arg->getType())));
452-
}
453-
454454
// Now fix the return values
455455
if (retTypeOption == ReturnOpt::RETURN_BY_REF)
456456
{
457457
// Add the 'noalias' and 'sret' attribute to arg0
458458
auto retArg = pNewFunc->arg_begin();
459459
retArg->addAttr(llvm::Attribute::NoAlias);
460-
retArg->addAttr(llvm::Attribute::getWithStructRetType(
461-
M.getContext(), IGCLLVM::getNonOpaquePtrEltTy(retArg->getType())));
460+
retArg->addAttr(llvm::Attribute::getWithStructRetType(M.getContext(), pFunc->getReturnType()));
462461

463462
// Loop through all return instructions and store the old return value into the arg0 pointer
464463
const auto ptrSize = DL.getPointerSize();
@@ -577,7 +576,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
577576
if (callInst->getType()->isVoidTy() &&
578577
IGCLLVM::getNumArgOperands(callInst) > 0 &&
579578
callInst->paramHasAttr(0, llvm::Attribute::StructRet) &&
580-
isPromotableStructType(M, callInst->getArgOperand(0)->getType(), isStackCall, true /* retval */))
579+
isPromotableStructType(M, callInst->getParamAttr(0, llvm::Attribute::StructRet).getValueAsType(), isStackCall))
581580
{
582581
opNum++; // Skip the first call operand
583582
retTypeOption = ReturnOpt::RETURN_STRUCT;
@@ -608,18 +607,17 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
608607
{
609608
// extend the illegal int to a legal type
610609
IGCLLVM::IRBuilder<> builder(callInst);
611-
Value* extend = builder.CreateZExt(arg, LegalizedIntVectorType(M, arg->getType()));
610+
Value* extend = builder.CreateZExt(callInst->getOperand(opNum), LegalizedIntVectorType(M, arg->getType()));
612611
callArgs.push_back(extend);
613612
ArgAttrVec.push_back(AttributeSet());
614613
fixArgType = true;
615614
}
616615
else if (callInst->paramHasAttr(opNum, llvm::Attribute::ByVal) &&
617-
isPromotableStructType(M, arg->getType(), isStackCall))
616+
isPromotableStructType(M, callInst->getParamByValType(opNum), isStackCall))
618617
{
619618
// Map the new operand to the loaded value of the struct pointer
620619
IGCLLVM::IRBuilder<> builder(callInst);
621-
Argument* callArg = IGCLLVM::getArg(*calledFunc, opNum);
622-
Value* newOp = LoadFromStruct(builder, arg, callArg->getParamByValType());
620+
Value* newOp = LoadFromStruct(builder, callInst->getOperand(opNum), callInst->getParamByValType(opNum));
623621
callArgs.push_back(newOp);
624622
ArgAttrVec.push_back(AttributeSet());
625623
fixArgType = true;
@@ -629,7 +627,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
629627
// Create and store operand as an alloca, then pass as argument
630628
IGCLLVM::IRBuilder<> builder(callInst);
631629
Value* allocaV = builder.CreateAlloca(arg->getType());
632-
builder.CreateStore(arg, allocaV);
630+
builder.CreateStore(callInst->getOperand(opNum), allocaV);
633631
callArgs.push_back(allocaV);
634632
auto byValAttr = llvm::Attribute::getWithByValType(M.getContext(), arg->getType());
635633
auto argAttrs = AttributeSet::get(M.getContext(), { byValAttr });
@@ -659,7 +657,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
659657
}
660658
Type* retType =
661659
retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy(callInst->getContext()) :
662-
retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType(M, callInst->getArgOperand(0)->getType()) :
660+
retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType(M, callInst->getFunction()->getArg(0)) :
663661
retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType(M, callInst->getType()) :
664662
callInst->getType();
665663
newFnTy = FunctionType::get(retType, argTypes, false);

0 commit comments

Comments
 (0)