Skip to content
84 changes: 28 additions & 56 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,47 +131,6 @@ fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false);
}

// This code restores function args/retvalue types for composite cases
// because the final types should still be aggregate whereas they're i32
// during the translation to cope with aggregate flattening etc.
static FunctionType *getOriginalFunctionType(const Function &F) {
auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
if (NamedMD == nullptr)
return F.getFunctionType();

Type *RetTy = F.getFunctionType()->getReturnType();
SmallVector<Type *, 4> ArgTypes;
for (auto &Arg : F.args())
ArgTypes.push_back(Arg.getType());

auto ThisFuncMDIt =
std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
return isa<MDString>(N->getOperand(0)) &&
cast<MDString>(N->getOperand(0))->getString() == F.getName();
});
if (ThisFuncMDIt != NamedMD->op_end()) {
auto *ThisFuncMD = *ThisFuncMDIt;
for (unsigned I = 1; I != ThisFuncMD->getNumOperands(); ++I) {
MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(I));
assert(MD && "MDNode operand is expected");
ConstantInt *Const = getConstInt(MD, 0);
if (Const) {
auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
assert(CMeta && "ConstantAsMetadata operand is expected");
assert(Const->getSExtValue() >= -1);
// Currently -1 indicates return value, greater values mean
// argument numbers.
if (Const->getSExtValue() == -1)
RetTy = CMeta->getType();
else
ArgTypes[Const->getSExtValue()] = CMeta->getType();
}
}
}

return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
}

static SPIRV::AccessQualifier::AccessQualifier
getArgAccessQual(const Function &F, unsigned ArgIdx) {
if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
Expand Down Expand Up @@ -204,7 +163,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
getArgAccessQual(F, ArgIdx);

Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
Type *OriginalArgType =
SPIRV::getOriginalFunctionType(F)->getParamType(ArgIdx);

// If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
// be legally reassigned later).
Expand Down Expand Up @@ -421,7 +381,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
auto MRI = MIRBuilder.getMRI();
Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
MRI->setRegClass(FuncVReg, &SPIRV::iIDRegClass);
FunctionType *FTy = getOriginalFunctionType(F);
FunctionType *FTy = SPIRV::getOriginalFunctionType(F);
Type *FRetTy = FTy->getReturnType();
if (isUntypedPointerTy(FRetTy)) {
if (Type *FRetElemTy = GR->findDeducedElementType(&F)) {
Expand Down Expand Up @@ -506,10 +466,15 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
// - add a topological sort of IndirectCalls to ensure the best types knowledge
// - we may need to fix function formal parameter types if they are opaque
// pointers used as function pointers in these indirect calls
// - defaulting to StorageClass::Function in the absence of the
// SPV_INTEL_function_pointers extension seems wrong, as that might not be
// able to hold a full width pointer to function, and it also does not model
// the semantics of a pointer to function in a generic fashion.
void SPIRVCallLowering::produceIndirectPtrTypes(
MachineIRBuilder &MIRBuilder) const {
// Create indirect call data types if any
MachineFunction &MF = MIRBuilder.getMF();
const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
for (auto const &IC : IndirectCalls) {
SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(
IC.RetTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
Expand All @@ -527,8 +492,11 @@ void SPIRVCallLowering::produceIndirectPtrTypes(
SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
// SPIR-V pointer to function type:
SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
auto SC = ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)
? SPIRV::StorageClass::CodeSectionINTEL
: SPIRV::StorageClass::Function;
SPIRVType *IndirectFuncPtrTy =
GR->getOrCreateSPIRVPointerType(SpirvFuncTy, MIRBuilder, SC);
// Correct the Callee type
GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF);
}
Expand Down Expand Up @@ -556,12 +524,12 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
// TODO: support constexpr casts and indirect calls.
if (CF == nullptr)
return false;
if (FunctionType *FTy = getOriginalFunctionType(*CF)) {
OrigRetTy = FTy->getReturnType();
if (isUntypedPointerTy(OrigRetTy)) {
if (auto *DerivedRetTy = GR->findReturnType(CF))
OrigRetTy = DerivedRetTy;
}

FunctionType *FTy = SPIRV::getOriginalFunctionType(*CF);
OrigRetTy = FTy->getReturnType();
if (isUntypedPointerTy(OrigRetTy)) {
if (auto *DerivedRetTy = GR->findReturnType(CF))
OrigRetTy = DerivedRetTy;
}
}

Expand Down Expand Up @@ -683,11 +651,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
if (CalleeReg.isValid()) {
SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
IndirectCall.Callee = CalleeReg;
IndirectCall.RetTy = OrigRetTy;
for (const auto &Arg : Info.OrigArgs) {
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
IndirectCall.ArgTys.push_back(Arg.Ty);
IndirectCall.ArgRegs.push_back(Arg.Regs[0]);
FunctionType *FTy = SPIRV::getOriginalFunctionType(*Info.CB);
IndirectCall.RetTy = OrigRetTy = FTy->getReturnType();
assert(FTy->getNumParams() == Info.OrigArgs.size() &&
"Function types mismatch");
for (unsigned I = 0; I != Info.OrigArgs.size(); ++I) {
assert(Info.OrigArgs[I].Regs.size() == 1 &&
"Call arg has multiple VRegs");
IndirectCall.ArgTys.push_back(FTy->getParamType(I));
IndirectCall.ArgRegs.push_back(Info.OrigArgs[I].Regs[0]);
}
IndirectCalls.push_back(IndirectCall);
}
Expand Down
30 changes: 24 additions & 6 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,17 @@ static void emitAssignName(Instruction *I, IRBuilder<> &B) {
if (!I->hasName() || I->getType()->isAggregateType() ||
expectIgnoredInIRTranslation(I))
return;

if (isa<CallBase>(I)) {
// TODO: this is a temporary workaround meant to prevent inserting internal
// noise into the generated binary; remove once we rework the entire
// aggregate removal machinery.
StringRef Name = I->getName();
if (Name.starts_with("spv.mutated_callsite"))
return;
if (Name.starts_with("spv.named_mutated_callsite"))
I->setName(Name.substr(Name.rfind('.') + 1));
}
reportFatalOnTokenType(I);
setInsertPointAfterDef(B, I);
LLVMContext &Ctx = I->getContext();
Expand Down Expand Up @@ -759,10 +770,15 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
if (Type *ElemTy = getPointeeType(KnownTy))
maybeAssignPtrType(Ty, I, ElemTy, UnknownElemTypeI8);
} else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
Ty = deduceElementTypeByValueDeep(
Ref->getValueType(),
Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited,
UnknownElemTypeI8);
if (auto *Fn = dyn_cast<Function>(Ref)) {
Ty = SPIRV::getOriginalFunctionType(*Fn);
GR->addDeducedElementType(I, Ty);
} else {
Ty = deduceElementTypeByValueDeep(
Ref->getValueType(),
Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited,
UnknownElemTypeI8);
}
} else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
Type *RefTy = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
UnknownElemTypeI8);
Expand Down Expand Up @@ -1062,10 +1078,10 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
if (!Op || !isPointerTy(Op->getType()))
return;
Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
FunctionType *FTy = CI->getFunctionType();
FunctionType *FTy = SPIRV::getOriginalFunctionType(*CI);
bool IsNewFTy = false, IsIncomplete = false;
SmallVector<Type *, 4> ArgTys;
for (Value *Arg : CI->args()) {
for (auto &&[ParmIdx, Arg] : llvm::enumerate(CI->args())) {
Type *ArgTy = Arg->getType();
if (ArgTy->isPointerTy()) {
if (Type *ElemTy = GR->findDeducedElementType(Arg)) {
Expand All @@ -1076,6 +1092,8 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
} else {
IsIncomplete = true;
}
} else {
ArgTy = FTy->getFunctionParamType(ParmIdx);
}
ArgTys.push_back(ArgTy);
}
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
if (Value *GlobalElem =
Global->getNumOperands() > 0 ? Global->getOperand(0) : nullptr)
ElementTy = findDeducedCompositeType(GlobalElem);
else if (const Function *Fn = dyn_cast<Function>(Global))
ElementTy = SPIRV::getOriginalFunctionType(*Fn);
}
return ElementTy ? ElementTy : Global->getValueType();
}
Expand Down
9 changes: 6 additions & 3 deletions llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,12 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
Register Def = MI.getOperand(0).getReg();
Register Source = MI.getOperand(2).getReg();
Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
ElemTy, MI,
addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
auto SC =
isa<FunctionType>(ElemTy)
? SPIRV::StorageClass::CodeSectionINTEL
: addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST);
SPIRVType *AssignedPtrType =
GR->getOrCreateSPIRVPointerType(ElemTy, MI, SC);

// If the ptrcast would be redundant, replace all uses with the source
// register.
Expand Down
109 changes: 96 additions & 13 deletions llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/IntrinsicLowering.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
Expand All @@ -41,6 +43,7 @@ class SPIRVPrepareFunctions : public ModulePass {
const SPIRVTargetMachine &TM;
bool substituteIntrinsicCalls(Function *F);
Function *removeAggregateTypesFromSignature(Function *F);
bool removeAggregateTypesFromCalls(Function *F);

public:
static char ID;
Expand Down Expand Up @@ -469,6 +472,23 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
return Changed;
}

static void
addFunctionTypeMutation(NamedMDNode *NMD,
SmallVector<std::pair<int, Type *>> ChangedTys,
StringRef Name) {

LLVMContext &Ctx = NMD->getParent()->getContext();
Type *I32Ty = IntegerType::getInt32Ty(Ctx);

SmallVector<Metadata *> MDArgs;
MDArgs.push_back(MDString::get(Ctx, Name));
transform(ChangedTys, std::back_inserter(MDArgs), [=, &Ctx](auto &&CTy) {
return MDNode::get(
Ctx, {ConstantAsMetadata::get(ConstantInt::get(I32Ty, CTy.first, true)),
ValueAsMetadata::get(Constant::getNullValue(CTy.second))});
});
NMD->addOperand(MDNode::get(Ctx, MDArgs));
}
// Returns F if aggregate argument/return types are not present or cloned F
// function with the types replaced by i32 types. The change in types is
// noted in 'spv.cloned_funcs' metadata for later restoration.
Expand Down Expand Up @@ -503,7 +523,8 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
FunctionType *NewFTy =
FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
Function *NewF =
Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());
Function::Create(NewFTy, F->getLinkage(), F->getAddressSpace(),
F->getName(), F->getParent());

ValueToValueMapTy VMap;
auto NewFArgIt = NewF->arg_begin();
Expand All @@ -518,22 +539,17 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
Returns);
NewF->takeName(F);

NamedMDNode *FuncMD =
F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
SmallVector<Metadata *, 2> MDArgs;
MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
for (auto &ChangedTyP : ChangedTypes)
MDArgs.push_back(MDNode::get(
B.getContext(),
{ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
FuncMD->addOperand(ThisFuncMD);
addFunctionTypeMutation(
NewF->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs"),
std::move(ChangedTypes), NewF->getName());

for (auto *U : make_early_inc_range(F->users())) {
if (auto *CI = dyn_cast<CallInst>(U))
CI->mutateFunctionType(NewF->getFunctionType());
U->replaceUsesOfWith(F, NewF);
if (auto *C = dyn_cast<Constant>(U))
C->handleOperandChange(F, NewF);
else
U->replaceUsesOfWith(F, NewF);
}
Comment on lines 546 to 554
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've stumbled with the Constant failure too :)

  • mutateFunctionType doesn't handle the case where F is not the called function; should we assert / error instead ?
  • In my version I used replaceAllUsesWith which already does the call to handleOperandChange
  for (auto *U : F->users()) {
      auto *CI = dyn_cast<CallInst>(U);
      if (CI && CI->getCalledFunction() == F)
        CI->mutateFunctionType(NewF->getFunctionType());
  }
  F->replaceAllUsesWith(NewF);

Copy link
Contributor Author

@AlexVlx AlexVlx Dec 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely sure what you mean by the first bullet. I think we're interested in mutating the call site's type. Could you say more about your concern?

Ok, I believe I understand the concern, this'll do bad things if F is not the callee, but just an argument to the call (which'd still be an Use). I don't think this is an assertion / error, I'll just adopt the check.

As for using RAUW, I don't quite think we can do that, since it does assert(New->getType() == getType() && "replaceAllUses of value with new value of different type!");. Otherwise stated, it expects the type of the new value to match the type of the old, which is not quite what is happening here, I don' think. We are using the riskier forms because we're using a value with new, aggregate free type.


// register the mutation
Expand All @@ -543,11 +559,78 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
return NewF;
}

// Mutates indirect callsites iff if aggregate argument/return types are present
// with the types replaced by i32 types. The change in types is noted in
// 'spv.mutated_callsites' metadata for later restoration.
bool SPIRVPrepareFunctions::removeAggregateTypesFromCalls(Function *F) {
if (F->isDeclaration() || F->isIntrinsic())
return false;

SmallVector<std::pair<CallBase *, FunctionType *>> Calls;
for (auto &&I : instructions(F)) {
if (auto *CB = dyn_cast<CallBase>(&I)) {
if (!CB->getCalledOperand() || CB->getCalledFunction())
continue;
if (CB->getType()->isAggregateType() ||
any_of(CB->args(),
[](auto &&Arg) { return Arg->getType()->isAggregateType(); }))
Calls.emplace_back(CB, nullptr);
}
}

if (Calls.empty())
return false;

IRBuilder<> B(F->getContext());

for (auto &&[CB, NewFnTy] : Calls) {
SmallVector<std::pair<int, Type *>> ChangedTypes;
SmallVector<Type *> NewArgTypes;

Type *RetTy = CB->getType();
if (RetTy->isAggregateType()) {
ChangedTypes.emplace_back(-1, RetTy);
RetTy = B.getInt32Ty();
}

for (auto &&Arg : CB->args()) {
if (Arg->getType()->isAggregateType()) {
NewArgTypes.push_back(B.getInt32Ty());
ChangedTypes.emplace_back(Arg.getOperandNo(), Arg->getType());
} else {
NewArgTypes.push_back(Arg->getType());
}
}
NewFnTy = FunctionType::get(RetTy, NewArgTypes,
CB->getFunctionType()->isVarArg());

if (!CB->hasName())
CB->setName("spv.mutated_callsite." + F->getName());
else
CB->setName("spv.named_mutated_callsite." + F->getName() + "." +
CB->getName());

addFunctionTypeMutation(
F->getParent()->getOrInsertNamedMetadata("spv.mutated_callsites"),
std::move(ChangedTypes), CB->getName());
}

for (auto &&[CB, NewFTy] : Calls) {
if (NewFTy->getReturnType() != CB->getType())
TM.getSubtarget<SPIRVSubtarget>(*F).getSPIRVGlobalRegistry()->addMutated(
CB, CB->getType());
CB->mutateFunctionType(NewFTy);
}

return true;
}

bool SPIRVPrepareFunctions::runOnModule(Module &M) {
bool Changed = false;
for (Function &F : M) {
Changed |= substituteIntrinsicCalls(&F);
Changed |= sortBlocks(F);
Changed |= removeAggregateTypesFromCalls(&F);
}

std::vector<Function *> FuncsWorklist;
Expand Down
Loading