Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions examples/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,20 @@ JLCXX_MODULE init_test_module(jlcxx::Module& mod)
{
b = !b;
});

mod.method("test_safe_cfunction_uint", [](jlcxx::SafeCFunction f_data)
{
auto f = jlcxx::make_function_pointer<int(unsigned int*,int)>(f_data);
unsigned int buffer[] = {1,2,3};
return f(buffer, 3);
});

mod.method("test_safe_cfunction_uint64", [](jlcxx::SafeCFunction f_data)
{
auto f = jlcxx::make_function_pointer<uint64_t*(int)>(f_data);
uint64_t* buffer = f(3);
return buffer[0] + buffer[1] + buffer[2];
});
}

}
51 changes: 46 additions & 5 deletions include/jlcxx/functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,29 @@ struct ConvertToCpp<SafeCFunction>

namespace detail
{
// Allow fundamental pointer types to be passed as e.g. Ptr{Int32} instead of CxxPtr{Int32}
template<typename T>
struct FundamentalPtrT
{
static jl_datatype_t* value()
{
return julia_type<T>();
}
};

template<typename T>
struct FundamentalPtrT<T*>
{
static jl_datatype_t* value()
{
if constexpr (std::is_fundamental_v<T>)
{
return (jl_datatype_t*)jl_apply_type1((jl_value_t*)jl_pointer_type, (jl_value_t*)julia_type<T>());
}
return julia_type<T*>();
}
};

template<typename SignatureT>
struct SplitSignature;

Expand All @@ -146,11 +169,28 @@ namespace detail
typedef R return_type;
typedef R(*fptr_t)(ArgsT...);

std::vector<jl_datatype_t*> operator()()
jl_datatype_t* expected_return_type()
{
create_if_not_exists<R>();
return julia_type<return_type>();
}

jl_datatype_t* fundamental_ptr_return_type()
{
return FundamentalPtrT<return_type>::value();
}

std::vector<jl_datatype_t*> arg_types()
{
(create_if_not_exists<ArgsT>(), ...);
return std::vector<jl_datatype_t*>({julia_type<ArgsT>()...});
}

std::vector<jl_datatype_t*> fundamental_ptr_types()
{
return std::vector<jl_datatype_t*>({FundamentalPtrT<ArgsT>::value()...});
}

fptr_t cast_ptr(void* ptr)
{
return reinterpret_cast<fptr_t>(ptr);
Expand All @@ -166,15 +206,16 @@ typename detail::SplitSignature<SignatureT>::fptr_t make_function_pointer(SafeCF
JL_GC_PUSH3(&data.fptr, &data.return_type, &data.argtypes);

// Check return type
jl_datatype_t* expected_rt = julia_type<typename SplitterT::return_type>();
if(expected_rt != data.return_type)
jl_datatype_t* expected_rt = SplitterT().expected_return_type();
if(expected_rt != data.return_type && SplitterT().fundamental_ptr_return_type() != data.return_type)
{
JL_GC_POP();
throw std::runtime_error("Incorrect datatype for cfunction return type, expected " + julia_type_name(expected_rt) + " but got " + julia_type_name(data.return_type));
}

// Check arguments
const std::vector<jl_datatype_t*> expected_argstypes = SplitterT()();
const std::vector<jl_datatype_t*> expected_argstypes = SplitterT().arg_types();
const std::vector<jl_datatype_t*> fundamental_ptr_argstypes = SplitterT().fundamental_ptr_types();
ArrayRef<jl_value_t*> argtypes(data.argtypes);
const int nb_args = expected_argstypes.size();
if(nb_args != static_cast<int>(argtypes.size()))
Expand All @@ -187,7 +228,7 @@ typename detail::SplitSignature<SignatureT>::fptr_t make_function_pointer(SafeCF
for(int i = 0; i != nb_args; ++i)
{
jl_datatype_t* argt = (jl_datatype_t*)argtypes[i];
if(argt != expected_argstypes[i])
if(argt != expected_argstypes[i] && argt != fundamental_ptr_argstypes[i])
{
std::stringstream err_sstr;
err_sstr << "Incorrect argument type for cfunction at position " << i+1 << ", expected: " << julia_type_name(expected_argstypes[i]) << ", obtained: " << julia_type_name(argt);
Expand Down
2 changes: 1 addition & 1 deletion include/jlcxx/jlcxx_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#define JLCXX_VERSION_MAJOR 0
#define JLCXX_VERSION_MINOR 14
#define JLCXX_VERSION_PATCH 7
#define JLCXX_VERSION_PATCH 8

// From https://stackoverflow.com/questions/5459868/concatenate-int-to-string-using-c-preprocessor
#define __JLCXX_STR_HELPER(x) #x
Expand Down