Skip to content

Commit e99a5c0

Browse files
authored
spirv-link: allow linking functions with different pointer arguments (KhronosGroup#5534)
* linker: run dedup earlier Otherwise `linkings_to_do` might end up with stale IDs. * linker: allow linking functions with different pointer arguments Since llvm-17 there are no typed pointers and hte SPIRV-LLVM-Translator doesn't know the function signature of imported functions. I'm investigating different ways of solving this problem and adding an option to work around it inside `spirv-link` is one of those. The code is almost complete, just I'm having troubles constructing the bitcast to cast the pointer parameters to the final type. Closes: KhronosGroup/SPIRV-LLVM-Translator#2153 * test/linker: add tests to test the AllowPtrTypeMismatch feature
1 parent ca37349 commit e99a5c0

File tree

4 files changed

+498
-49
lines changed

4 files changed

+498
-49
lines changed

include/spirv-tools/linker.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#define INCLUDE_SPIRV_TOOLS_LINKER_HPP_
1717

1818
#include <cstdint>
19-
2019
#include <memory>
2120
#include <vector>
2221

@@ -63,11 +62,17 @@ class SPIRV_TOOLS_EXPORT LinkerOptions {
6362
use_highest_version_ = use_highest_vers;
6463
}
6564

65+
bool GetAllowPtrTypeMismatch() const { return allow_ptr_type_mismatch_; }
66+
void SetAllowPtrTypeMismatch(bool allow_ptr_type_mismatch) {
67+
allow_ptr_type_mismatch_ = allow_ptr_type_mismatch;
68+
}
69+
6670
private:
6771
bool create_library_{false};
6872
bool verify_ids_{false};
6973
bool allow_partial_linkage_{false};
7074
bool use_highest_version_{false};
75+
bool allow_ptr_type_mismatch_{false};
7176
};
7277

7378
// Links one or more SPIR-V modules into a new SPIR-V module. That is, combine

source/link/linker.cpp

Lines changed: 107 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "source/opt/build_module.h"
3232
#include "source/opt/compact_ids_pass.h"
3333
#include "source/opt/decoration_manager.h"
34+
#include "source/opt/ir_builder.h"
3435
#include "source/opt/ir_loader.h"
3536
#include "source/opt/pass_manager.h"
3637
#include "source/opt/remove_duplicates_pass.h"
@@ -46,12 +47,14 @@ namespace spvtools {
4647
namespace {
4748

4849
using opt::Instruction;
50+
using opt::InstructionBuilder;
4951
using opt::IRContext;
5052
using opt::Module;
5153
using opt::PassManager;
5254
using opt::RemoveDuplicatesPass;
5355
using opt::analysis::DecorationManager;
5456
using opt::analysis::DefUseManager;
57+
using opt::analysis::Function;
5558
using opt::analysis::Type;
5659
using opt::analysis::TypeManager;
5760

@@ -126,6 +129,7 @@ spv_result_t GetImportExportPairs(const MessageConsumer& consumer,
126129
// checked.
127130
spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
128131
const LinkageTable& linkings_to_do,
132+
bool allow_ptr_type_mismatch,
129133
opt::IRContext* context);
130134

131135
// Remove linkage specific instructions, such as prototypes of imported
@@ -502,6 +506,7 @@ spv_result_t GetImportExportPairs(const MessageConsumer& consumer,
502506

503507
spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
504508
const LinkageTable& linkings_to_do,
509+
bool allow_ptr_type_mismatch,
505510
opt::IRContext* context) {
506511
spv_position_t position = {};
507512

@@ -513,14 +518,42 @@ spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
513518
type_manager.GetType(linking_entry.imported_symbol.type_id);
514519
Type* exported_symbol_type =
515520
type_manager.GetType(linking_entry.exported_symbol.type_id);
516-
if (!(*imported_symbol_type == *exported_symbol_type))
521+
if (!(*imported_symbol_type == *exported_symbol_type)) {
522+
Function* imported_symbol_type_func = imported_symbol_type->AsFunction();
523+
Function* exported_symbol_type_func = exported_symbol_type->AsFunction();
524+
525+
if (imported_symbol_type_func && exported_symbol_type_func) {
526+
const auto& imported_params = imported_symbol_type_func->param_types();
527+
const auto& exported_params = exported_symbol_type_func->param_types();
528+
// allow_ptr_type_mismatch allows linking functions where the pointer
529+
// type of arguments doesn't match. Everything else still needs to be
530+
// equal. This is to workaround LLVM-17+ not having typed pointers and
531+
// generated SPIR-Vs not knowing the actual pointer types in some cases.
532+
if (allow_ptr_type_mismatch &&
533+
imported_params.size() == exported_params.size()) {
534+
bool correct = true;
535+
for (size_t i = 0; i < imported_params.size(); i++) {
536+
const auto& imported_param = imported_params[i];
537+
const auto& exported_param = exported_params[i];
538+
539+
if (!imported_param->IsSame(exported_param) &&
540+
(imported_param->kind() != Type::kPointer ||
541+
exported_param->kind() != Type::kPointer)) {
542+
correct = false;
543+
break;
544+
}
545+
}
546+
if (correct) continue;
547+
}
548+
}
517549
return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
518550
<< "Type mismatch on symbol \""
519551
<< linking_entry.imported_symbol.name
520552
<< "\" between imported variable/function %"
521553
<< linking_entry.imported_symbol.id
522554
<< " and exported variable/function %"
523555
<< linking_entry.exported_symbol.id << ".";
556+
}
524557
}
525558

526559
// Ensure the import and export decorations are similar
@@ -696,6 +729,57 @@ spv_result_t VerifyLimits(const MessageConsumer& consumer,
696729
return SPV_SUCCESS;
697730
}
698731

732+
spv_result_t FixFunctionCallTypes(opt::IRContext& context,
733+
const LinkageTable& linkings) {
734+
auto mod = context.module();
735+
const auto type_manager = context.get_type_mgr();
736+
const auto def_use_mgr = context.get_def_use_mgr();
737+
738+
for (auto& func : *mod) {
739+
func.ForEachInst([&](Instruction* inst) {
740+
if (inst->opcode() != spv::Op::OpFunctionCall) return;
741+
opt::Operand& target = inst->GetInOperand(0);
742+
743+
// only fix calls to imported functions
744+
auto linking = std::find_if(
745+
linkings.begin(), linkings.end(), [&](const auto& entry) {
746+
return entry.exported_symbol.id == target.AsId();
747+
});
748+
if (linking == linkings.end()) return;
749+
750+
auto builder = InstructionBuilder(&context, inst);
751+
for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
752+
auto exported_func_param =
753+
def_use_mgr->GetDef(linking->exported_symbol.parameter_ids[i - 1]);
754+
const Type* target_type =
755+
type_manager->GetType(exported_func_param->type_id());
756+
if (target_type->kind() != Type::kPointer) continue;
757+
758+
opt::Operand& arg = inst->GetInOperand(i);
759+
const Type* param_type =
760+
type_manager->GetType(def_use_mgr->GetDef(arg.AsId())->type_id());
761+
762+
// No need to cast if it already matches
763+
if (*param_type == *target_type) continue;
764+
765+
auto new_id = context.TakeNextId();
766+
767+
// cast to the expected pointer type
768+
builder.AddInstruction(MakeUnique<opt::Instruction>(
769+
&context, spv::Op::OpBitcast, exported_func_param->type_id(),
770+
new_id,
771+
opt::Instruction::OperandList(
772+
{{SPV_OPERAND_TYPE_ID, {arg.AsId()}}})));
773+
774+
inst->SetInOperand(i, {new_id});
775+
}
776+
});
777+
}
778+
context.InvalidateAnalyses(opt::IRContext::kAnalysisDefUse |
779+
opt::IRContext::kAnalysisInstrToBlockMapping);
780+
return SPV_SUCCESS;
781+
}
782+
699783
} // namespace
700784

701785
spv_result_t Link(const Context& context,
@@ -773,26 +857,27 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries,
773857
if (res != SPV_SUCCESS) return res;
774858
}
775859

776-
// Phase 4: Find the import/export pairs
860+
// Phase 4: Remove duplicates
861+
PassManager manager;
862+
manager.SetMessageConsumer(consumer);
863+
manager.AddPass<RemoveDuplicatesPass>();
864+
opt::Pass::Status pass_res = manager.Run(&linked_context);
865+
if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
866+
867+
// Phase 5: Find the import/export pairs
777868
LinkageTable linkings_to_do;
778869
res = GetImportExportPairs(consumer, linked_context,
779870
*linked_context.get_def_use_mgr(),
780871
*linked_context.get_decoration_mgr(),
781872
options.GetAllowPartialLinkage(), &linkings_to_do);
782873
if (res != SPV_SUCCESS) return res;
783874

784-
// Phase 5: Ensure the import and export have the same types and decorations.
785-
res =
786-
CheckImportExportCompatibility(consumer, linkings_to_do, &linked_context);
875+
// Phase 6: Ensure the import and export have the same types and decorations.
876+
res = CheckImportExportCompatibility(consumer, linkings_to_do,
877+
options.GetAllowPtrTypeMismatch(),
878+
&linked_context);
787879
if (res != SPV_SUCCESS) return res;
788880

789-
// Phase 6: Remove duplicates
790-
PassManager manager;
791-
manager.SetMessageConsumer(consumer);
792-
manager.AddPass<RemoveDuplicatesPass>();
793-
opt::Pass::Status pass_res = manager.Run(&linked_context);
794-
if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
795-
796881
// Phase 7: Remove all names and decorations of import variables/functions
797882
for (const auto& linking_entry : linkings_to_do) {
798883
linked_context.KillNamesAndDecorates(linking_entry.imported_symbol.id);
@@ -815,21 +900,27 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries,
815900
&linked_context);
816901
if (res != SPV_SUCCESS) return res;
817902

818-
// Phase 10: Compact the IDs used in the module
903+
// Phase 10: Optionally fix function call types
904+
if (options.GetAllowPtrTypeMismatch()) {
905+
res = FixFunctionCallTypes(linked_context, linkings_to_do);
906+
if (res != SPV_SUCCESS) return res;
907+
}
908+
909+
// Phase 11: Compact the IDs used in the module
819910
manager.AddPass<opt::CompactIdsPass>();
820911
pass_res = manager.Run(&linked_context);
821912
if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
822913

823-
// Phase 11: Recompute EntryPoint variables
914+
// Phase 12: Recompute EntryPoint variables
824915
manager.AddPass<opt::RemoveUnusedInterfaceVariablesPass>();
825916
pass_res = manager.Run(&linked_context);
826917
if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
827918

828-
// Phase 12: Warn if SPIR-V limits were exceeded
919+
// Phase 13: Warn if SPIR-V limits were exceeded
829920
res = VerifyLimits(consumer, linked_context);
830921
if (res != SPV_SUCCESS) return res;
831922

832-
// Phase 13: Output the module
923+
// Phase 14: Output the module
833924
linked_context.module()->ToBinary(linked_binary, true);
834925

835926
return SPV_SUCCESS;

0 commit comments

Comments
 (0)