diff --git a/examples/functions.cpp b/examples/functions.cpp index fb5ceec..bf19a1d 100644 --- a/examples/functions.cpp +++ b/examples/functions.cpp @@ -353,6 +353,20 @@ JLCXX_MODULE init_test_module(jlcxx::Module& mod) { b = !b; }); + + mod.method("test_safe_cfunction_uint", [](jlcxx::SafeCFunction f_data) + { + auto f = jlcxx::make_function_pointer(f_data); + unsigned int buffer[] = {1,2,3}; + return f(buffer, 3); + }); + + mod.method("test_safe_cfunction_uint64", [](jlcxx::SafeCFunction f_data) + { + auto f = jlcxx::make_function_pointer(f_data); + uint64_t* buffer = f(3); + return buffer[0] + buffer[1] + buffer[2]; + }); } } diff --git a/include/jlcxx/functions.hpp b/include/jlcxx/functions.hpp index 9357e13..695c0d1 100644 --- a/include/jlcxx/functions.hpp +++ b/include/jlcxx/functions.hpp @@ -137,6 +137,29 @@ struct ConvertToCpp namespace detail { + // Allow fundamental pointer types to be passed as e.g. Ptr{Int32} instead of CxxPtr{Int32} + template + struct FundamentalPtrT + { + static jl_datatype_t* value() + { + return julia_type(); + } + }; + + template + struct FundamentalPtrT + { + static jl_datatype_t* value() + { + if constexpr (std::is_fundamental_v) + { + return (jl_datatype_t*)jl_apply_type1((jl_value_t*)jl_pointer_type, (jl_value_t*)julia_type()); + } + return julia_type(); + } + }; + template struct SplitSignature; @@ -146,11 +169,28 @@ namespace detail typedef R return_type; typedef R(*fptr_t)(ArgsT...); - std::vector operator()() + jl_datatype_t* expected_return_type() { + create_if_not_exists(); + return julia_type(); + } + + jl_datatype_t* fundamental_ptr_return_type() + { + return FundamentalPtrT::value(); + } + + std::vector arg_types() + { + (create_if_not_exists(), ...); return std::vector({julia_type()...}); } + std::vector fundamental_ptr_types() + { + return std::vector({FundamentalPtrT::value()...}); + } + fptr_t cast_ptr(void* ptr) { return reinterpret_cast(ptr); @@ -166,15 +206,16 @@ typename detail::SplitSignature::fptr_t make_function_pointer(SafeCF JL_GC_PUSH3(&data.fptr, &data.return_type, &data.argtypes); // Check return type - jl_datatype_t* expected_rt = julia_type(); - if(expected_rt != data.return_type) + jl_datatype_t* expected_rt = SplitterT().expected_return_type(); + if(expected_rt != data.return_type && SplitterT().fundamental_ptr_return_type() != data.return_type) { JL_GC_POP(); throw std::runtime_error("Incorrect datatype for cfunction return type, expected " + julia_type_name(expected_rt) + " but got " + julia_type_name(data.return_type)); } // Check arguments - const std::vector expected_argstypes = SplitterT()(); + const std::vector expected_argstypes = SplitterT().arg_types(); + const std::vector fundamental_ptr_argstypes = SplitterT().fundamental_ptr_types(); ArrayRef argtypes(data.argtypes); const int nb_args = expected_argstypes.size(); if(nb_args != static_cast(argtypes.size())) @@ -187,7 +228,7 @@ typename detail::SplitSignature::fptr_t make_function_pointer(SafeCF for(int i = 0; i != nb_args; ++i) { jl_datatype_t* argt = (jl_datatype_t*)argtypes[i]; - if(argt != expected_argstypes[i]) + if(argt != expected_argstypes[i] && argt != fundamental_ptr_argstypes[i]) { std::stringstream err_sstr; err_sstr << "Incorrect argument type for cfunction at position " << i+1 << ", expected: " << julia_type_name(expected_argstypes[i]) << ", obtained: " << julia_type_name(argt); diff --git a/include/jlcxx/jlcxx_config.hpp b/include/jlcxx/jlcxx_config.hpp index b854a93..1ff5c30 100644 --- a/include/jlcxx/jlcxx_config.hpp +++ b/include/jlcxx/jlcxx_config.hpp @@ -16,7 +16,7 @@ #define JLCXX_VERSION_MAJOR 0 #define JLCXX_VERSION_MINOR 14 -#define JLCXX_VERSION_PATCH 7 +#define JLCXX_VERSION_PATCH 8 // From https://stackoverflow.com/questions/5459868/concatenate-int-to-string-using-c-preprocessor #define __JLCXX_STR_HELPER(x) #x