Skip to content

Conversation

@AlexVlx
Copy link
Contributor

@AlexVlx AlexVlx commented Nov 26, 2025

This patch does two things:

  1. it extends the aggregate arg / ret replacement transform to work on indirect calls / pointers to function. It is somewhat spread out as retrieving the original function type is needed in a few places. In general, we should rethink / rework the entire infrastructure around aggregate arg/ret handling, using an opaque target specific type rather than i32;
  2. it enables global variables of pointer to function type, and, more specifically, global variables of a aggregate type (arrays / structures) with pointer to function elements.

This also exposes some issues in how we handle pointers to function and lowering indirect function calls, primarily around not using the program address space. These will be handled in a subsequent patch as they'll require somewhat more intrusive surgery, possibly involving modifying the data layout.

…nts. Start reworking indirect function calls.
@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2025

@llvm/pr-subscribers-backend-spir-v

Author: Alex Voicu (AlexVlx)

Changes

This patch does two things:

  1. it extends the aggregate arg / ret replacement transform to work on indirect calls / pointers to function. It is somewhat spread out as retrieving the original function type is needed in a few places. In general, we should rethink / rework the entire infrastructure around aggregate arg/ret handling, using an opaque target specific type rather than i32;
  2. it enables global variables of pointer to function type, and, more specifically, global variables of a aggregate type (arrays / structures) with pointer to function elements.

This also exposes some issues in how we handle pointers to function and lowering indirect function calls, primarily around not using the program address space. These will be handled in a subsequent patch as they'll require somewhat more intrusive surgery, possibly involving modifying the data layout.


Patch is 27.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169595.diff

8 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+28-55)
  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+19-6)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+5-3)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp (+93-13)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+64)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+5)
  • (added) llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll (+108)
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index dd57b74d79a5e..6b5c602d4ac93 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -131,46 +131,6 @@ fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
   return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false);
 }
 
-// This code restores function args/retvalue types for composite cases
-// because the final types should still be aggregate whereas they're i32
-// during the translation to cope with aggregate flattening etc.
-static FunctionType *getOriginalFunctionType(const Function &F) {
-  auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
-  if (NamedMD == nullptr)
-    return F.getFunctionType();
-
-  Type *RetTy = F.getFunctionType()->getReturnType();
-  SmallVector<Type *, 4> ArgTypes;
-  for (auto &Arg : F.args())
-    ArgTypes.push_back(Arg.getType());
-
-  auto ThisFuncMDIt =
-      std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
-        return isa<MDString>(N->getOperand(0)) &&
-               cast<MDString>(N->getOperand(0))->getString() == F.getName();
-      });
-  if (ThisFuncMDIt != NamedMD->op_end()) {
-    auto *ThisFuncMD = *ThisFuncMDIt;
-    for (unsigned I = 1; I != ThisFuncMD->getNumOperands(); ++I) {
-      MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(I));
-      assert(MD && "MDNode operand is expected");
-      ConstantInt *Const = getConstInt(MD, 0);
-      if (Const) {
-        auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
-        assert(CMeta && "ConstantAsMetadata operand is expected");
-        assert(Const->getSExtValue() >= -1);
-        // Currently -1 indicates return value, greater values mean
-        // argument numbers.
-        if (Const->getSExtValue() == -1)
-          RetTy = CMeta->getType();
-        else
-          ArgTypes[Const->getSExtValue()] = CMeta->getType();
-      }
-    }
-  }
-
-  return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
-}
 
 static SPIRV::AccessQualifier::AccessQualifier
 getArgAccessQual(const Function &F, unsigned ArgIdx) {
@@ -204,7 +164,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
   SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
       getArgAccessQual(F, ArgIdx);
 
-  Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
+  Type *OriginalArgType =
+      SPIRV::getOriginalFunctionType(F)->getParamType(ArgIdx);
 
   // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
   // be legally reassigned later).
@@ -421,7 +382,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   auto MRI = MIRBuilder.getMRI();
   Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
   MRI->setRegClass(FuncVReg, &SPIRV::iIDRegClass);
-  FunctionType *FTy = getOriginalFunctionType(F);
+  FunctionType *FTy = SPIRV::getOriginalFunctionType(F);
   Type *FRetTy = FTy->getReturnType();
   if (isUntypedPointerTy(FRetTy)) {
     if (Type *FRetElemTy = GR->findDeducedElementType(&F)) {
@@ -506,10 +467,15 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
 // - add a topological sort of IndirectCalls to ensure the best types knowledge
 // - we may need to fix function formal parameter types if they are opaque
 //   pointers used as function pointers in these indirect calls
+// - defaulting to StorageClass::Function in the absence of the
+//   SPV_INTEL_function_pointers extension seems wrong, as that might not be
+//   able to hold a full width pointer to function, and it also does not model
+//   the semantics of a pointer to function in a generic fashion.
 void SPIRVCallLowering::produceIndirectPtrTypes(
     MachineIRBuilder &MIRBuilder) const {
   // Create indirect call data types if any
   MachineFunction &MF = MIRBuilder.getMF();
+  const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
   for (auto const &IC : IndirectCalls) {
     SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(
         IC.RetTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
@@ -527,8 +493,11 @@ void SPIRVCallLowering::produceIndirectPtrTypes(
     SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
         FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
     // SPIR-V pointer to function type:
-    SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
-        SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
+    auto SC = ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)
+        ? SPIRV::StorageClass::CodeSectionINTEL
+        : SPIRV::StorageClass::Function;
+    SPIRVType *IndirectFuncPtrTy =
+        GR->getOrCreateSPIRVPointerType(SpirvFuncTy, MIRBuilder, SC);
     // Correct the Callee type
     GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF);
   }
@@ -556,12 +525,12 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     // TODO: support constexpr casts and indirect calls.
     if (CF == nullptr)
       return false;
-    if (FunctionType *FTy = getOriginalFunctionType(*CF)) {
-      OrigRetTy = FTy->getReturnType();
-      if (isUntypedPointerTy(OrigRetTy)) {
-        if (auto *DerivedRetTy = GR->findReturnType(CF))
-          OrigRetTy = DerivedRetTy;
-      }
+
+    FunctionType *FTy = SPIRV::getOriginalFunctionType(*CF);
+    OrigRetTy = FTy->getReturnType();
+    if (isUntypedPointerTy(OrigRetTy)) {
+      if (auto *DerivedRetTy = GR->findReturnType(CF))
+        OrigRetTy = DerivedRetTy;
     }
   }
 
@@ -683,11 +652,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     if (CalleeReg.isValid()) {
       SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
       IndirectCall.Callee = CalleeReg;
-      IndirectCall.RetTy = OrigRetTy;
-      for (const auto &Arg : Info.OrigArgs) {
-        assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
-        IndirectCall.ArgTys.push_back(Arg.Ty);
-        IndirectCall.ArgRegs.push_back(Arg.Regs[0]);
+      FunctionType *FTy = SPIRV::getOriginalFunctionType(*Info.CB);
+      IndirectCall.RetTy = OrigRetTy = FTy->getReturnType();
+      assert(FTy->getNumParams() == Info.OrigArgs.size() &&
+             "Function types mismatch");
+      for (unsigned I = 0; I != Info.OrigArgs.size(); ++I) {
+        assert(Info.OrigArgs[I].Regs.size() == 1 &&
+               "Call arg has multiple VRegs");
+        IndirectCall.ArgTys.push_back(FTy->getParamType(I));
+        IndirectCall.ArgRegs.push_back(Info.OrigArgs[I].Regs[0]);
       }
       IndirectCalls.push_back(IndirectCall);
     }
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 8e14fb03127fc..b37e6d8ce4ea3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -358,7 +358,11 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
 
 static void emitAssignName(Instruction *I, IRBuilder<> &B) {
   if (!I->hasName() || I->getType()->isAggregateType() ||
-      expectIgnoredInIRTranslation(I))
+      expectIgnoredInIRTranslation(I) ||
+      // TODO: this is a temporary workaround meant to prevent inserting
+      //       internal noise into the generated binary; remove once we rework
+      //       the entire aggregate removal machinery.
+      I->getName().starts_with("spv.mutated_callsite"))
     return;
   reportFatalOnTokenType(I);
   setInsertPointAfterDef(B, I);
@@ -759,10 +763,15 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
     if (Type *ElemTy = getPointeeType(KnownTy))
       maybeAssignPtrType(Ty, I, ElemTy, UnknownElemTypeI8);
   } else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
-    Ty = deduceElementTypeByValueDeep(
-        Ref->getValueType(),
-        Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited,
-        UnknownElemTypeI8);
+    if (auto *Fn = dyn_cast<Function>(Ref)) {
+      Ty = SPIRV::getOriginalFunctionType(*Fn);
+      GR->addDeducedElementType(I, Ty);
+    } else {
+      Ty = deduceElementTypeByValueDeep(
+          Ref->getValueType(),
+          Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited,
+          UnknownElemTypeI8);
+    }
   } else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
     Type *RefTy = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
                                           UnknownElemTypeI8);
@@ -1062,9 +1071,10 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
   if (!Op || !isPointerTy(Op->getType()))
     return;
   Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
-  FunctionType *FTy = CI->getFunctionType();
+  FunctionType *FTy = SPIRV::getOriginalFunctionType(*CI);
   bool IsNewFTy = false, IsIncomplete = false;
   SmallVector<Type *, 4> ArgTys;
+  unsigned ParmIdx = 0;
   for (Value *Arg : CI->args()) {
     Type *ArgTy = Arg->getType();
     if (ArgTy->isPointerTy()) {
@@ -1076,8 +1086,11 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
       } else {
         IsIncomplete = true;
       }
+    } else {
+      ArgTy = FTy->getFunctionParamType(ParmIdx);
     }
     ArgTys.push_back(ArgTy);
+    ++ParmIdx;
   }
   Type *RetTy = FTy->getReturnType();
   if (CI->getType()->isPointerTy()) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 09c77f0cfd4f5..16f3260bf4ffc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -214,6 +214,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
       if (Value *GlobalElem =
               Global->getNumOperands() > 0 ? Global->getOperand(0) : nullptr)
         ElementTy = findDeducedCompositeType(GlobalElem);
+      else if (const Function *Fn = dyn_cast<Function>(Global))
+        ElementTy = SPIRV::getOriginalFunctionType(*Fn);
     }
     return ElementTy ? ElementTy : Global->getValueType();
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 0f4b3d59b904a..b273599596a35 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -257,9 +257,11 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
       Register Def = MI.getOperand(0).getReg();
       Register Source = MI.getOperand(2).getReg();
       Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
-      SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
-          ElemTy, MI,
-          addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
+      auto SC = isa<FunctionType>(ElemTy)
+          ? SPIRV::StorageClass::CodeSectionINTEL
+          : addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST);
+      SPIRVType *AssignedPtrType =
+          GR->getOrCreateSPIRVPointerType(ElemTy, MI, SC);
 
       // If the ptrcast would be redundant, replace all uses with the source
       // register.
diff --git a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
index be88f334d2171..8fd261cfff25b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
@@ -26,6 +26,7 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/IntrinsicLowering.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
@@ -41,6 +42,7 @@ class SPIRVPrepareFunctions : public ModulePass {
   const SPIRVTargetMachine &TM;
   bool substituteIntrinsicCalls(Function *F);
   Function *removeAggregateTypesFromSignature(Function *F);
+  bool removeAggregateTypesFromCalls(Function *F);
 
 public:
   static char ID;
@@ -469,6 +471,23 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
   return Changed;
 }
 
+static void addFunctionTypeMutation(
+    NamedMDNode *NMD,
+    SmallVector<std::pair<int, Type *>> ChangedTys, StringRef Name) {
+
+    LLVMContext &Ctx = NMD->getParent()->getContext();
+    Type *I32Ty = IntegerType::getInt32Ty(Ctx);
+
+    SmallVector<Metadata *> MDArgs;
+    MDArgs.push_back(MDString::get(Ctx, Name));
+    transform(ChangedTys, std::back_inserter(MDArgs), [=, &Ctx](auto &&CTy) {
+      return MDNode::get(
+          Ctx,
+          {ConstantAsMetadata::get(ConstantInt::get(I32Ty, CTy.first, true)),
+           ValueAsMetadata::get(Constant::getNullValue(CTy.second))});
+    });
+    NMD->addOperand(MDNode::get(Ctx, MDArgs));
+}
 // Returns F if aggregate argument/return types are not present or cloned F
 // function with the types replaced by i32 types. The change in types is
 // noted in 'spv.cloned_funcs' metadata for later restoration.
@@ -503,7 +522,8 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
   FunctionType *NewFTy =
       FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
   Function *NewF =
-      Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());
+      Function::Create(NewFTy, F->getLinkage(), F->getAddressSpace(),
+                       F->getName(), F->getParent());
 
   ValueToValueMapTy VMap;
   auto NewFArgIt = NewF->arg_begin();
@@ -518,22 +538,17 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
                     Returns);
   NewF->takeName(F);
 
-  NamedMDNode *FuncMD =
-      F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
-  SmallVector<Metadata *, 2> MDArgs;
-  MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
-  for (auto &ChangedTyP : ChangedTypes)
-    MDArgs.push_back(MDNode::get(
-        B.getContext(),
-        {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
-         ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
-  MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
-  FuncMD->addOperand(ThisFuncMD);
+  addFunctionTypeMutation(
+      NewF->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs"),
+      std::move(ChangedTypes), NewF->getName());
 
   for (auto *U : make_early_inc_range(F->users())) {
     if (auto *CI = dyn_cast<CallInst>(U))
       CI->mutateFunctionType(NewF->getFunctionType());
-    U->replaceUsesOfWith(F, NewF);
+    if (auto *C = dyn_cast<Constant>(U))
+      C->handleOperandChange(F, NewF);
+    else
+      U->replaceUsesOfWith(F, NewF);
   }
 
   // register the mutation
@@ -543,11 +558,76 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
   return NewF;
 }
 
+// Mutates indirect callsites iff if aggregate argument/return types are present
+// with the types replaced by i32 types. The change in types is noted in
+// 'spv.mutated_callsites' metadata for later restoration.
+bool SPIRVPrepareFunctions::removeAggregateTypesFromCalls(Function *F) {
+  if (F->isDeclaration() || F->isIntrinsic())
+    return false;
+
+  SmallVector<std::pair<CallBase *, FunctionType *>> Calls;
+  for (auto &&BB : *F) {
+    for (auto &&I : BB) {
+      if (auto *CB = dyn_cast<CallBase>(&I)) {
+        if (!CB->getCalledOperand() || CB->getCalledFunction())
+          continue;
+        if (CB->getType()->isAggregateType() ||
+            any_of(CB->args(),
+                  [](auto &&Arg) { return Arg->getType()->isAggregateType(); }))
+          Calls.emplace_back(CB, nullptr);
+      }
+    }
+  }
+
+  if (Calls.empty())
+    return false;
+
+  IRBuilder<> B(F->getContext());
+
+  for (auto &&[CB, NewFnTy] : Calls) {
+    SmallVector<std::pair<int, Type *>> ChangedTypes;
+    SmallVector<Type *> NewArgTypes;
+
+    if (CB->getType()->isAggregateType())
+      ChangedTypes.emplace_back(-1, CB->getType());
+
+    Type *RetTy = ChangedTypes.empty() ? CB->getType() : B.getInt32Ty();
+    for (auto &&Arg : CB->args()) {
+      if (Arg->getType()->isAggregateType()) {
+        NewArgTypes.push_back(B.getInt32Ty());
+        ChangedTypes.emplace_back(Arg.getOperandNo(), Arg->getType());
+      } else {
+        NewArgTypes.push_back(Arg->getType());
+      }
+    }
+    NewFnTy = FunctionType::get(RetTy, NewArgTypes,
+                                CB->getFunctionType()->isVarArg());
+
+    if (!CB->hasName())
+      CB->setName("spv.mutated_callsite");
+
+    addFunctionTypeMutation(
+      F->getParent()->getOrInsertNamedMetadata("spv.mutated_callsites"),
+      std::move(ChangedTypes),
+      CB->getName());
+  }
+
+  for (auto &&[CB, NewFTy] : Calls) {
+    if (NewFTy->getReturnType() != CB->getType())
+      TM.getSubtarget<SPIRVSubtarget>(*F).getSPIRVGlobalRegistry()->addMutated(
+          CB, CB->getType());
+    CB->mutateFunctionType(NewFTy);
+  }
+
+  return true;
+}
+
 bool SPIRVPrepareFunctions::runOnModule(Module &M) {
   bool Changed = false;
   for (Function &F : M) {
     Changed |= substituteIntrinsicCalls(&F);
     Changed |= sortBlocks(F);
+    Changed |= removeAggregateTypesFromCalls(&F);
   }
 
   std::vector<Function *> FuncsWorklist;
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 8f2fc01da476f..5b6a36de6c526 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -28,6 +28,70 @@
 #include <vector>
 
 namespace llvm {
+namespace SPIRV {
+// This code restores function args/retvalue types for composite cases
+// because the final types should still be aggregate whereas they're i32
+// during the translation to cope with aggregate flattening etc.
+// TODO: should these just return nullptr when there's no metadata?
+static FunctionType *extractFunctionTypeFromMetadata(NamedMDNode *NMD,
+                                                     FunctionType *FTy,
+                                                     StringRef Name) {
+  if (!NMD)
+    return FTy;
+
+  constexpr auto getConstInt = [](MDNode *MD, unsigned OpId) -> ConstantInt * {
+    if (MD->getNumOperands() <= OpId)
+      return nullptr;
+    if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(OpId)))
+      return dyn_cast<ConstantInt>(CMeta->getValue());
+    return nullptr;
+  };
+
+  auto It = find_if(NMD->operands(), [Name](MDNode *N) {
+    if (auto *MDS = dyn_cast_or_null<MDString>(N->getOperand(0)))
+      return MDS->getString() == Name;
+    return false;
+  });
+
+  if (It == NMD->op_end())
+    return FTy;
+
+  Type *RetTy = FTy->getReturnType();
+  SmallVector<Type *, 4> PTys(FTy->params());
+
+  for (unsigned I = 1; I != (*It)->getNumOperands(); ++I) {
+    MDNode *MD = dyn_cast<MDNode>((*It)->getOperand(I));
+    assert(MD && "MDNode operand is expected");
+
+    if (auto *Const = getConstInt(MD, 0)) {
+      auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
+      assert(CMeta && "ConstantAsMetadata operand is expected");
+      assert(Const->getSExtValue() >= -1);
+      // Currently -1 indicates return value, greater values mean
+      // argument numbers.
+      if (Const->getSExtValue() == -1)
+        RetTy = CMeta->getType();
+      else
+        PTys[Const->getSExtValue()] = CMeta->getType();
+    }
+  }
+
+  return FunctionType::get(RetTy, PTys, FTy->isVarArg());
+}
+
+FunctionType *getOriginalFunctionType(const Function &F) {
+  return extractFunctionTypeFromMetadata(
+      F.getParent()->getNamedMetadata("spv.cloned_funcs"),
+      F.getFunctionType(), F.getName());
+}
+
+FunctionType *getOriginalFunctionType(const CallBase &CB) {
+  return extractFunctionTypeFromMetadata(
+      CB.getParent()
+        ->getParent()->getParent()->getNamedMetadata("spv.mutated_callsites"),
+      CB.getFunctionType(), CB.getName());
+}
+} // Namespace SPIRV
 
 // The following functions are used to add these string literals as a series of
 // 32-bit integer operands with the correct format, and unpack them if necessary
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 99d9d403ea70c..3da77f1d6c1ec 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -159,6 +159,11 @@ struct FPFastMathDefaultInfoVector
   }
 };
 
+// This code restores function args/retvalue types for composite cases
+// because the final types should still be aggregate whereas they're i32
+// during the translation to cope with a...
[truncated]

@github-actions
Copy link

github-actions bot commented Nov 26, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@github-actions
Copy link

github-actions bot commented Nov 26, 2025

✅ With the latest revision this PR passed the undef deprecator.

@github-actions
Copy link

github-actions bot commented Nov 26, 2025

🐧 Linux x64 Test Results

  • 186963 tests passed
  • 4919 tests skipped

✅ The build succeeded and all tests passed.

Copy link
Contributor

@jmmartinez jmmartinez left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a first go at it. I have spotted only 1 issue, the rest are cosmetic.


static void addFunctionTypeMutation(
NamedMDNode *NMD,
SmallVector<std::pair<int, Type *>> ChangedTys, StringRef Name) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the interest of passing ChangedTys by doing an std::move in here vs a const& ? (maybe the explanation is in the follow-up PR ?)

Copy link
Contributor

@jmmartinez jmmartinez left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like for somebody else to also take a look and validate before merging, to see if somebody has a better idea than the associate by name (as long as we get rid of it later it's ok for me).

Copy link
Contributor Author

@AlexVlx AlexVlx left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just an idea: Could we use ValueAsMetadata in here instead of the name?

Potentially, but I wanted to more or less retain the existing code / minimise "innovation", so I didn't think profoundly re-factoring in this change:)

@jmmartinez
Copy link
Contributor

Just an idea: Could we use ValueAsMetadata in here instead of the name?

Potentially, but I wanted to more or less retain the existing code / minimise "innovation", so I didn't think profoundly re-factoring in this change:)

No problem!

Comment on lines 546 to 553
for (auto *U : make_early_inc_range(F->users())) {
if (auto *CI = dyn_cast<CallInst>(U))
CI->mutateFunctionType(NewF->getFunctionType());
U->replaceUsesOfWith(F, NewF);
if (auto *C = dyn_cast<Constant>(U))
C->handleOperandChange(F, NewF);
else
U->replaceUsesOfWith(F, NewF);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've stumbled with the Constant failure too :)

  • mutateFunctionType doesn't handle the case where F is not the called function; should we assert / error instead ?
  • In my version I used replaceAllUsesWith which already does the call to handleOperandChange
  for (auto *U : F->users()) {
      auto *CI = dyn_cast<CallInst>(U);
      if (CI && CI->getCalledFunction() == F)
        CI->mutateFunctionType(NewF->getFunctionType());
  }
  F->replaceAllUsesWith(NewF);

Copy link
Contributor Author

@AlexVlx AlexVlx Dec 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely sure what you mean by the first bullet. I think we're interested in mutating the call site's type. Could you say more about your concern?

Ok, I believe I understand the concern, this'll do bad things if F is not the callee, but just an argument to the call (which'd still be an Use). I don't think this is an assertion / error, I'll just adopt the check.

As for using RAUW, I don't quite think we can do that, since it does assert(New->getType() == getType() && "replaceAllUses of value with new value of different type!");. Otherwise stated, it expects the type of the new value to match the type of the old, which is not quite what is happening here, I don' think. We are using the riskier forms because we're using a value with new, aggregate free type.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants