Skip to content

Commit 858fc93

Browse files
AlexVlxHoney Goyal
authored andcommitted
[SPIRV] Add support for pointers to functions with aggregate args/returns as global variables / constant initialisers (llvm#169595)
This patch does two things: 1. it extends the aggregate arg / ret replacement transform to work on indirect calls / pointers to function. It is somewhat spread out as retrieving the original function type is needed in a few places. In general, we should rethink / rework the entire infrastructure around aggregate arg/ret handling, using an opaque target specific type rather than i32; 2. it enables global variables of pointer to function type, and, more specifically, global variables of a aggregate type (arrays / structures) with pointer to function elements. This also exposes some issues in how we handle pointers to function and lowering indirect function calls, primarily around not using the program address space. These will be handled in a subsequent patch as they'll require somewhat more intrusive surgery, possibly involving modifying the data layout.
1 parent 9313a06 commit 858fc93

File tree

8 files changed

+333
-79
lines changed

8 files changed

+333
-79
lines changed

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 28 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -131,47 +131,6 @@ fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
131131
return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false);
132132
}
133133

134-
// This code restores function args/retvalue types for composite cases
135-
// because the final types should still be aggregate whereas they're i32
136-
// during the translation to cope with aggregate flattening etc.
137-
static FunctionType *getOriginalFunctionType(const Function &F) {
138-
auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
139-
if (NamedMD == nullptr)
140-
return F.getFunctionType();
141-
142-
Type *RetTy = F.getFunctionType()->getReturnType();
143-
SmallVector<Type *, 4> ArgTypes;
144-
for (auto &Arg : F.args())
145-
ArgTypes.push_back(Arg.getType());
146-
147-
auto ThisFuncMDIt =
148-
std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
149-
return isa<MDString>(N->getOperand(0)) &&
150-
cast<MDString>(N->getOperand(0))->getString() == F.getName();
151-
});
152-
if (ThisFuncMDIt != NamedMD->op_end()) {
153-
auto *ThisFuncMD = *ThisFuncMDIt;
154-
for (unsigned I = 1; I != ThisFuncMD->getNumOperands(); ++I) {
155-
MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(I));
156-
assert(MD && "MDNode operand is expected");
157-
ConstantInt *Const = getConstInt(MD, 0);
158-
if (Const) {
159-
auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
160-
assert(CMeta && "ConstantAsMetadata operand is expected");
161-
assert(Const->getSExtValue() >= -1);
162-
// Currently -1 indicates return value, greater values mean
163-
// argument numbers.
164-
if (Const->getSExtValue() == -1)
165-
RetTy = CMeta->getType();
166-
else
167-
ArgTypes[Const->getSExtValue()] = CMeta->getType();
168-
}
169-
}
170-
}
171-
172-
return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
173-
}
174-
175134
static SPIRV::AccessQualifier::AccessQualifier
176135
getArgAccessQual(const Function &F, unsigned ArgIdx) {
177136
if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
@@ -204,7 +163,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
204163
SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
205164
getArgAccessQual(F, ArgIdx);
206165

207-
Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
166+
Type *OriginalArgType =
167+
SPIRV::getOriginalFunctionType(F)->getParamType(ArgIdx);
208168

209169
// If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
210170
// be legally reassigned later).
@@ -429,7 +389,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
429389
auto MRI = MIRBuilder.getMRI();
430390
Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
431391
MRI->setRegClass(FuncVReg, &SPIRV::iIDRegClass);
432-
FunctionType *FTy = getOriginalFunctionType(F);
392+
FunctionType *FTy = SPIRV::getOriginalFunctionType(F);
433393
Type *FRetTy = FTy->getReturnType();
434394
if (isUntypedPointerTy(FRetTy)) {
435395
if (Type *FRetElemTy = GR->findDeducedElementType(&F)) {
@@ -514,10 +474,15 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
514474
// - add a topological sort of IndirectCalls to ensure the best types knowledge
515475
// - we may need to fix function formal parameter types if they are opaque
516476
// pointers used as function pointers in these indirect calls
477+
// - defaulting to StorageClass::Function in the absence of the
478+
// SPV_INTEL_function_pointers extension seems wrong, as that might not be
479+
// able to hold a full width pointer to function, and it also does not model
480+
// the semantics of a pointer to function in a generic fashion.
517481
void SPIRVCallLowering::produceIndirectPtrTypes(
518482
MachineIRBuilder &MIRBuilder) const {
519483
// Create indirect call data types if any
520484
MachineFunction &MF = MIRBuilder.getMF();
485+
const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
521486
for (auto const &IC : IndirectCalls) {
522487
SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(
523488
IC.RetTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
@@ -535,8 +500,11 @@ void SPIRVCallLowering::produceIndirectPtrTypes(
535500
SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
536501
FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
537502
// SPIR-V pointer to function type:
538-
SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
539-
SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
503+
auto SC = ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)
504+
? SPIRV::StorageClass::CodeSectionINTEL
505+
: SPIRV::StorageClass::Function;
506+
SPIRVType *IndirectFuncPtrTy =
507+
GR->getOrCreateSPIRVPointerType(SpirvFuncTy, MIRBuilder, SC);
540508
// Correct the Callee type
541509
GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF);
542510
}
@@ -564,12 +532,12 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
564532
// TODO: support constexpr casts and indirect calls.
565533
if (CF == nullptr)
566534
return false;
567-
if (FunctionType *FTy = getOriginalFunctionType(*CF)) {
568-
OrigRetTy = FTy->getReturnType();
569-
if (isUntypedPointerTy(OrigRetTy)) {
570-
if (auto *DerivedRetTy = GR->findReturnType(CF))
571-
OrigRetTy = DerivedRetTy;
572-
}
535+
536+
FunctionType *FTy = SPIRV::getOriginalFunctionType(*CF);
537+
OrigRetTy = FTy->getReturnType();
538+
if (isUntypedPointerTy(OrigRetTy)) {
539+
if (auto *DerivedRetTy = GR->findReturnType(CF))
540+
OrigRetTy = DerivedRetTy;
573541
}
574542
}
575543

@@ -691,11 +659,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
691659
if (CalleeReg.isValid()) {
692660
SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
693661
IndirectCall.Callee = CalleeReg;
694-
IndirectCall.RetTy = OrigRetTy;
695-
for (const auto &Arg : Info.OrigArgs) {
696-
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
697-
IndirectCall.ArgTys.push_back(Arg.Ty);
698-
IndirectCall.ArgRegs.push_back(Arg.Regs[0]);
662+
FunctionType *FTy = SPIRV::getOriginalFunctionType(*Info.CB);
663+
IndirectCall.RetTy = OrigRetTy = FTy->getReturnType();
664+
assert(FTy->getNumParams() == Info.OrigArgs.size() &&
665+
"Function types mismatch");
666+
for (unsigned I = 0; I != Info.OrigArgs.size(); ++I) {
667+
assert(Info.OrigArgs[I].Regs.size() == 1 &&
668+
"Call arg has multiple VRegs");
669+
IndirectCall.ArgTys.push_back(FTy->getParamType(I));
670+
IndirectCall.ArgRegs.push_back(Info.OrigArgs[I].Regs[0]);
699671
}
700672
IndirectCalls.push_back(IndirectCall);
701673
}

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,17 @@ static void emitAssignName(Instruction *I, IRBuilder<> &B) {
360360
if (!I->hasName() || I->getType()->isAggregateType() ||
361361
expectIgnoredInIRTranslation(I))
362362
return;
363+
364+
if (isa<CallBase>(I)) {
365+
// TODO: this is a temporary workaround meant to prevent inserting internal
366+
// noise into the generated binary; remove once we rework the entire
367+
// aggregate removal machinery.
368+
StringRef Name = I->getName();
369+
if (Name.starts_with("spv.mutated_callsite"))
370+
return;
371+
if (Name.starts_with("spv.named_mutated_callsite"))
372+
I->setName(Name.substr(Name.rfind('.') + 1));
373+
}
363374
reportFatalOnTokenType(I);
364375
setInsertPointAfterDef(B, I);
365376
LLVMContext &Ctx = I->getContext();
@@ -759,10 +770,15 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
759770
if (Type *ElemTy = getPointeeType(KnownTy))
760771
maybeAssignPtrType(Ty, I, ElemTy, UnknownElemTypeI8);
761772
} else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
762-
Ty = deduceElementTypeByValueDeep(
763-
Ref->getValueType(),
764-
Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited,
765-
UnknownElemTypeI8);
773+
if (auto *Fn = dyn_cast<Function>(Ref)) {
774+
Ty = SPIRV::getOriginalFunctionType(*Fn);
775+
GR->addDeducedElementType(I, Ty);
776+
} else {
777+
Ty = deduceElementTypeByValueDeep(
778+
Ref->getValueType(),
779+
Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited,
780+
UnknownElemTypeI8);
781+
}
766782
} else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
767783
Type *RefTy = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
768784
UnknownElemTypeI8);
@@ -1063,10 +1079,10 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
10631079
if (!Op || !isPointerTy(Op->getType()))
10641080
return;
10651081
Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
1066-
FunctionType *FTy = CI->getFunctionType();
1082+
FunctionType *FTy = SPIRV::getOriginalFunctionType(*CI);
10671083
bool IsNewFTy = false, IsIncomplete = false;
10681084
SmallVector<Type *, 4> ArgTys;
1069-
for (Value *Arg : CI->args()) {
1085+
for (auto &&[ParmIdx, Arg] : llvm::enumerate(CI->args())) {
10701086
Type *ArgTy = Arg->getType();
10711087
if (ArgTy->isPointerTy()) {
10721088
if (Type *ElemTy = GR->findDeducedElementType(Arg)) {
@@ -1077,6 +1093,8 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
10771093
} else {
10781094
IsIncomplete = true;
10791095
}
1096+
} else {
1097+
ArgTy = FTy->getFunctionParamType(ParmIdx);
10801098
}
10811099
ArgTys.push_back(ArgTy);
10821100
}

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
214214
if (Value *GlobalElem =
215215
Global->getNumOperands() > 0 ? Global->getOperand(0) : nullptr)
216216
ElementTy = findDeducedCompositeType(GlobalElem);
217+
else if (const Function *Fn = dyn_cast<Function>(Global))
218+
ElementTy = SPIRV::getOriginalFunctionType(*Fn);
217219
}
218220
return ElementTy ? ElementTy : Global->getValueType();
219221
}

llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,12 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
257257
Register Def = MI.getOperand(0).getReg();
258258
Register Source = MI.getOperand(2).getReg();
259259
Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
260-
SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
261-
ElemTy, MI,
262-
addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
260+
auto SC =
261+
isa<FunctionType>(ElemTy)
262+
? SPIRV::StorageClass::CodeSectionINTEL
263+
: addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST);
264+
SPIRVType *AssignedPtrType =
265+
GR->getOrCreateSPIRVPointerType(ElemTy, MI, SC);
263266

264267
// If the ptrcast would be redundant, replace all uses with the source
265268
// register.

llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp

Lines changed: 98 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
#include "llvm/Analysis/ValueTracking.h"
2727
#include "llvm/CodeGen/IntrinsicLowering.h"
2828
#include "llvm/IR/IRBuilder.h"
29+
#include "llvm/IR/InstIterator.h"
30+
#include "llvm/IR/Instructions.h"
2931
#include "llvm/IR/IntrinsicInst.h"
3032
#include "llvm/IR/Intrinsics.h"
3133
#include "llvm/IR/IntrinsicsSPIRV.h"
@@ -41,6 +43,7 @@ class SPIRVPrepareFunctions : public ModulePass {
4143
const SPIRVTargetMachine &TM;
4244
bool substituteIntrinsicCalls(Function *F);
4345
Function *removeAggregateTypesFromSignature(Function *F);
46+
bool removeAggregateTypesFromCalls(Function *F);
4447

4548
public:
4649
static char ID;
@@ -469,6 +472,23 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
469472
return Changed;
470473
}
471474

475+
static void
476+
addFunctionTypeMutation(NamedMDNode *NMD,
477+
SmallVector<std::pair<int, Type *>> ChangedTys,
478+
StringRef Name) {
479+
480+
LLVMContext &Ctx = NMD->getParent()->getContext();
481+
Type *I32Ty = IntegerType::getInt32Ty(Ctx);
482+
483+
SmallVector<Metadata *> MDArgs;
484+
MDArgs.push_back(MDString::get(Ctx, Name));
485+
transform(ChangedTys, std::back_inserter(MDArgs), [=, &Ctx](auto &&CTy) {
486+
return MDNode::get(
487+
Ctx, {ConstantAsMetadata::get(ConstantInt::get(I32Ty, CTy.first, true)),
488+
ValueAsMetadata::get(Constant::getNullValue(CTy.second))});
489+
});
490+
NMD->addOperand(MDNode::get(Ctx, MDArgs));
491+
}
472492
// Returns F if aggregate argument/return types are not present or cloned F
473493
// function with the types replaced by i32 types. The change in types is
474494
// noted in 'spv.cloned_funcs' metadata for later restoration.
@@ -503,7 +523,8 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
503523
FunctionType *NewFTy =
504524
FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
505525
Function *NewF =
506-
Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());
526+
Function::Create(NewFTy, F->getLinkage(), F->getAddressSpace(),
527+
F->getName(), F->getParent());
507528

508529
ValueToValueMapTy VMap;
509530
auto NewFArgIt = NewF->arg_begin();
@@ -518,22 +539,18 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
518539
Returns);
519540
NewF->takeName(F);
520541

521-
NamedMDNode *FuncMD =
522-
F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
523-
SmallVector<Metadata *, 2> MDArgs;
524-
MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
525-
for (auto &ChangedTyP : ChangedTypes)
526-
MDArgs.push_back(MDNode::get(
527-
B.getContext(),
528-
{ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
529-
ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
530-
MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
531-
FuncMD->addOperand(ThisFuncMD);
542+
addFunctionTypeMutation(
543+
NewF->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs"),
544+
std::move(ChangedTypes), NewF->getName());
532545

533546
for (auto *U : make_early_inc_range(F->users())) {
534-
if (auto *CI = dyn_cast<CallInst>(U))
547+
if (CallInst *CI;
548+
(CI = dyn_cast<CallInst>(U)) && CI->getCalledFunction() == F)
535549
CI->mutateFunctionType(NewF->getFunctionType());
536-
U->replaceUsesOfWith(F, NewF);
550+
if (auto *C = dyn_cast<Constant>(U))
551+
C->handleOperandChange(F, NewF);
552+
else
553+
U->replaceUsesOfWith(F, NewF);
537554
}
538555

539556
// register the mutation
@@ -543,11 +560,78 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
543560
return NewF;
544561
}
545562

563+
// Mutates indirect callsites iff if aggregate argument/return types are present
564+
// with the types replaced by i32 types. The change in types is noted in
565+
// 'spv.mutated_callsites' metadata for later restoration.
566+
bool SPIRVPrepareFunctions::removeAggregateTypesFromCalls(Function *F) {
567+
if (F->isDeclaration() || F->isIntrinsic())
568+
return false;
569+
570+
SmallVector<std::pair<CallBase *, FunctionType *>> Calls;
571+
for (auto &&I : instructions(F)) {
572+
if (auto *CB = dyn_cast<CallBase>(&I)) {
573+
if (!CB->getCalledOperand() || CB->getCalledFunction())
574+
continue;
575+
if (CB->getType()->isAggregateType() ||
576+
any_of(CB->args(),
577+
[](auto &&Arg) { return Arg->getType()->isAggregateType(); }))
578+
Calls.emplace_back(CB, nullptr);
579+
}
580+
}
581+
582+
if (Calls.empty())
583+
return false;
584+
585+
IRBuilder<> B(F->getContext());
586+
587+
for (auto &&[CB, NewFnTy] : Calls) {
588+
SmallVector<std::pair<int, Type *>> ChangedTypes;
589+
SmallVector<Type *> NewArgTypes;
590+
591+
Type *RetTy = CB->getType();
592+
if (RetTy->isAggregateType()) {
593+
ChangedTypes.emplace_back(-1, RetTy);
594+
RetTy = B.getInt32Ty();
595+
}
596+
597+
for (auto &&Arg : CB->args()) {
598+
if (Arg->getType()->isAggregateType()) {
599+
NewArgTypes.push_back(B.getInt32Ty());
600+
ChangedTypes.emplace_back(Arg.getOperandNo(), Arg->getType());
601+
} else {
602+
NewArgTypes.push_back(Arg->getType());
603+
}
604+
}
605+
NewFnTy = FunctionType::get(RetTy, NewArgTypes,
606+
CB->getFunctionType()->isVarArg());
607+
608+
if (!CB->hasName())
609+
CB->setName("spv.mutated_callsite." + F->getName());
610+
else
611+
CB->setName("spv.named_mutated_callsite." + F->getName() + "." +
612+
CB->getName());
613+
614+
addFunctionTypeMutation(
615+
F->getParent()->getOrInsertNamedMetadata("spv.mutated_callsites"),
616+
std::move(ChangedTypes), CB->getName());
617+
}
618+
619+
for (auto &&[CB, NewFTy] : Calls) {
620+
if (NewFTy->getReturnType() != CB->getType())
621+
TM.getSubtarget<SPIRVSubtarget>(*F).getSPIRVGlobalRegistry()->addMutated(
622+
CB, CB->getType());
623+
CB->mutateFunctionType(NewFTy);
624+
}
625+
626+
return true;
627+
}
628+
546629
bool SPIRVPrepareFunctions::runOnModule(Module &M) {
547630
bool Changed = false;
548631
for (Function &F : M) {
549632
Changed |= substituteIntrinsicCalls(&F);
550633
Changed |= sortBlocks(F);
634+
Changed |= removeAggregateTypesFromCalls(&F);
551635
}
552636

553637
std::vector<Function *> FuncsWorklist;

0 commit comments

Comments
 (0)