31
31
#include " source/opt/build_module.h"
32
32
#include " source/opt/compact_ids_pass.h"
33
33
#include " source/opt/decoration_manager.h"
34
+ #include " source/opt/ir_builder.h"
34
35
#include " source/opt/ir_loader.h"
35
36
#include " source/opt/pass_manager.h"
36
37
#include " source/opt/remove_duplicates_pass.h"
@@ -46,12 +47,14 @@ namespace spvtools {
46
47
namespace {
47
48
48
49
using opt::Instruction;
50
+ using opt::InstructionBuilder;
49
51
using opt::IRContext;
50
52
using opt::Module;
51
53
using opt::PassManager;
52
54
using opt::RemoveDuplicatesPass;
53
55
using opt::analysis::DecorationManager;
54
56
using opt::analysis::DefUseManager;
57
+ using opt::analysis::Function;
55
58
using opt::analysis::Type;
56
59
using opt::analysis::TypeManager;
57
60
@@ -126,6 +129,7 @@ spv_result_t GetImportExportPairs(const MessageConsumer& consumer,
126
129
// checked.
127
130
spv_result_t CheckImportExportCompatibility (const MessageConsumer& consumer,
128
131
const LinkageTable& linkings_to_do,
132
+ bool allow_ptr_type_mismatch,
129
133
opt::IRContext* context);
130
134
131
135
// Remove linkage specific instructions, such as prototypes of imported
@@ -502,6 +506,7 @@ spv_result_t GetImportExportPairs(const MessageConsumer& consumer,
502
506
503
507
spv_result_t CheckImportExportCompatibility (const MessageConsumer& consumer,
504
508
const LinkageTable& linkings_to_do,
509
+ bool allow_ptr_type_mismatch,
505
510
opt::IRContext* context) {
506
511
spv_position_t position = {};
507
512
@@ -513,14 +518,42 @@ spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
513
518
type_manager.GetType (linking_entry.imported_symbol .type_id );
514
519
Type* exported_symbol_type =
515
520
type_manager.GetType (linking_entry.exported_symbol .type_id );
516
- if (!(*imported_symbol_type == *exported_symbol_type))
521
+ if (!(*imported_symbol_type == *exported_symbol_type)) {
522
+ Function* imported_symbol_type_func = imported_symbol_type->AsFunction ();
523
+ Function* exported_symbol_type_func = exported_symbol_type->AsFunction ();
524
+
525
+ if (imported_symbol_type_func && exported_symbol_type_func) {
526
+ const auto & imported_params = imported_symbol_type_func->param_types ();
527
+ const auto & exported_params = exported_symbol_type_func->param_types ();
528
+ // allow_ptr_type_mismatch allows linking functions where the pointer
529
+ // type of arguments doesn't match. Everything else still needs to be
530
+ // equal. This is to workaround LLVM-17+ not having typed pointers and
531
+ // generated SPIR-Vs not knowing the actual pointer types in some cases.
532
+ if (allow_ptr_type_mismatch &&
533
+ imported_params.size () == exported_params.size ()) {
534
+ bool correct = true ;
535
+ for (size_t i = 0 ; i < imported_params.size (); i++) {
536
+ const auto & imported_param = imported_params[i];
537
+ const auto & exported_param = exported_params[i];
538
+
539
+ if (!imported_param->IsSame (exported_param) &&
540
+ (imported_param->kind () != Type::kPointer ||
541
+ exported_param->kind () != Type::kPointer )) {
542
+ correct = false ;
543
+ break ;
544
+ }
545
+ }
546
+ if (correct) continue ;
547
+ }
548
+ }
517
549
return DiagnosticStream (position, consumer, " " , SPV_ERROR_INVALID_BINARY)
518
550
<< " Type mismatch on symbol \" "
519
551
<< linking_entry.imported_symbol .name
520
552
<< " \" between imported variable/function %"
521
553
<< linking_entry.imported_symbol .id
522
554
<< " and exported variable/function %"
523
555
<< linking_entry.exported_symbol .id << " ." ;
556
+ }
524
557
}
525
558
526
559
// Ensure the import and export decorations are similar
@@ -696,6 +729,57 @@ spv_result_t VerifyLimits(const MessageConsumer& consumer,
696
729
return SPV_SUCCESS;
697
730
}
698
731
732
+ spv_result_t FixFunctionCallTypes (opt::IRContext& context,
733
+ const LinkageTable& linkings) {
734
+ auto mod = context.module ();
735
+ const auto type_manager = context.get_type_mgr ();
736
+ const auto def_use_mgr = context.get_def_use_mgr ();
737
+
738
+ for (auto & func : *mod) {
739
+ func.ForEachInst ([&](Instruction* inst) {
740
+ if (inst->opcode () != spv::Op::OpFunctionCall) return ;
741
+ opt::Operand& target = inst->GetInOperand (0 );
742
+
743
+ // only fix calls to imported functions
744
+ auto linking = std::find_if (
745
+ linkings.begin (), linkings.end (), [&](const auto & entry) {
746
+ return entry.exported_symbol .id == target.AsId ();
747
+ });
748
+ if (linking == linkings.end ()) return ;
749
+
750
+ auto builder = InstructionBuilder (&context, inst);
751
+ for (uint32_t i = 1 ; i < inst->NumInOperands (); ++i) {
752
+ auto exported_func_param =
753
+ def_use_mgr->GetDef (linking->exported_symbol .parameter_ids [i - 1 ]);
754
+ const Type* target_type =
755
+ type_manager->GetType (exported_func_param->type_id ());
756
+ if (target_type->kind () != Type::kPointer ) continue ;
757
+
758
+ opt::Operand& arg = inst->GetInOperand (i);
759
+ const Type* param_type =
760
+ type_manager->GetType (def_use_mgr->GetDef (arg.AsId ())->type_id ());
761
+
762
+ // No need to cast if it already matches
763
+ if (*param_type == *target_type) continue ;
764
+
765
+ auto new_id = context.TakeNextId ();
766
+
767
+ // cast to the expected pointer type
768
+ builder.AddInstruction (MakeUnique<opt::Instruction>(
769
+ &context, spv::Op::OpBitcast, exported_func_param->type_id (),
770
+ new_id,
771
+ opt::Instruction::OperandList (
772
+ {{SPV_OPERAND_TYPE_ID, {arg.AsId ()}}})));
773
+
774
+ inst->SetInOperand (i, {new_id});
775
+ }
776
+ });
777
+ }
778
+ context.InvalidateAnalyses (opt::IRContext::kAnalysisDefUse |
779
+ opt::IRContext::kAnalysisInstrToBlockMapping );
780
+ return SPV_SUCCESS;
781
+ }
782
+
699
783
} // namespace
700
784
701
785
spv_result_t Link (const Context& context,
@@ -773,26 +857,27 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries,
773
857
if (res != SPV_SUCCESS) return res;
774
858
}
775
859
776
- // Phase 4: Find the import/export pairs
860
+ // Phase 4: Remove duplicates
861
+ PassManager manager;
862
+ manager.SetMessageConsumer (consumer);
863
+ manager.AddPass <RemoveDuplicatesPass>();
864
+ opt::Pass::Status pass_res = manager.Run (&linked_context);
865
+ if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
866
+
867
+ // Phase 5: Find the import/export pairs
777
868
LinkageTable linkings_to_do;
778
869
res = GetImportExportPairs (consumer, linked_context,
779
870
*linked_context.get_def_use_mgr (),
780
871
*linked_context.get_decoration_mgr (),
781
872
options.GetAllowPartialLinkage (), &linkings_to_do);
782
873
if (res != SPV_SUCCESS) return res;
783
874
784
- // Phase 5: Ensure the import and export have the same types and decorations.
785
- res =
786
- CheckImportExportCompatibility (consumer, linkings_to_do, &linked_context);
875
+ // Phase 6: Ensure the import and export have the same types and decorations.
876
+ res = CheckImportExportCompatibility (consumer, linkings_to_do,
877
+ options.GetAllowPtrTypeMismatch (),
878
+ &linked_context);
787
879
if (res != SPV_SUCCESS) return res;
788
880
789
- // Phase 6: Remove duplicates
790
- PassManager manager;
791
- manager.SetMessageConsumer (consumer);
792
- manager.AddPass <RemoveDuplicatesPass>();
793
- opt::Pass::Status pass_res = manager.Run (&linked_context);
794
- if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
795
-
796
881
// Phase 7: Remove all names and decorations of import variables/functions
797
882
for (const auto & linking_entry : linkings_to_do) {
798
883
linked_context.KillNamesAndDecorates (linking_entry.imported_symbol .id );
@@ -815,21 +900,27 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries,
815
900
&linked_context);
816
901
if (res != SPV_SUCCESS) return res;
817
902
818
- // Phase 10: Compact the IDs used in the module
903
+ // Phase 10: Optionally fix function call types
904
+ if (options.GetAllowPtrTypeMismatch ()) {
905
+ res = FixFunctionCallTypes (linked_context, linkings_to_do);
906
+ if (res != SPV_SUCCESS) return res;
907
+ }
908
+
909
+ // Phase 11: Compact the IDs used in the module
819
910
manager.AddPass <opt::CompactIdsPass>();
820
911
pass_res = manager.Run (&linked_context);
821
912
if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
822
913
823
- // Phase 11 : Recompute EntryPoint variables
914
+ // Phase 12 : Recompute EntryPoint variables
824
915
manager.AddPass <opt::RemoveUnusedInterfaceVariablesPass>();
825
916
pass_res = manager.Run (&linked_context);
826
917
if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
827
918
828
- // Phase 12 : Warn if SPIR-V limits were exceeded
919
+ // Phase 13 : Warn if SPIR-V limits were exceeded
829
920
res = VerifyLimits (consumer, linked_context);
830
921
if (res != SPV_SUCCESS) return res;
831
922
832
- // Phase 13 : Output the module
923
+ // Phase 14 : Output the module
833
924
linked_context.module ()->ToBinary (linked_binary, true );
834
925
835
926
return SPV_SUCCESS;
0 commit comments