Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ pub enum BridgeRsTypeKind {
StdString {
in_cc_std: bool,
},
DynCallable(Rc<Callable>),
Callable(Rc<Callable>),
/// c9::Co<T>
C9Co {
has_reference_param: bool,
Expand Down Expand Up @@ -625,7 +625,7 @@ impl BridgeRsTypeKind {
BridgeRsTypeKind::StdString { in_cc_std }
}
BridgeType::Callable { backing_type, fn_trait, return_type, param_types } => {
BridgeRsTypeKind::DynCallable(Rc::new(Callable {
BridgeRsTypeKind::Callable(Rc::new(Callable {
backing_type,
fn_trait: match fn_trait {
ir::FnTrait::Fn => FnTrait::Fn,
Expand Down Expand Up @@ -1206,8 +1206,8 @@ impl RsTypeKind {
BridgeRsTypeKind::StdOptional(t) => t.implements_copy(),
BridgeRsTypeKind::StdPair(t1, t2) => t1.implements_copy() && t2.implements_copy(),
BridgeRsTypeKind::StdString { .. } => false,
BridgeRsTypeKind::DynCallable { .. } => {
// DynCallable represents an owned function object, so it is not copyable.
BridgeRsTypeKind::Callable { .. } => {
// Callables represent an owned function object, so they are not copyable.
false
}
BridgeRsTypeKind::C9Co { .. } => false,
Expand Down Expand Up @@ -1708,9 +1708,9 @@ impl RsTypeKind {
quote! { ::cc_std::std::string }
}
}
BridgeRsTypeKind::DynCallable(dyn_callable) => {
let dyn_callable_spelling = dyn_callable.dyn_fn_spelling(&db);
quote! { ::alloc::boxed::Box<#dyn_callable_spelling> }
BridgeRsTypeKind::Callable(callable) => {
let callable_spelling = callable.dyn_fn_spelling(&db);
quote! { ::alloc::boxed::Box<#callable_spelling> }
}
BridgeRsTypeKind::C9Co { has_reference_param, result_type, .. } => {
let result_type_tokens = if result_type.is_void() {
Expand Down Expand Up @@ -1866,7 +1866,7 @@ impl<'ty> Iterator for RsTypeKindIter<'ty> {
self.todo.push(t1);
}
BridgeRsTypeKind::StdString { .. } => {}
BridgeRsTypeKind::DynCallable(callable) => {
BridgeRsTypeKind::Callable(callable) => {
self.todo.push(&callable.return_type);
self.todo.extend(callable.param_types.iter().rev());
}
Expand Down
186 changes: 162 additions & 24 deletions rs_bindings_from_cc/generate_bindings/generate_dyn_callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

use arc_anyhow::{bail, Result};
use crubit_abi_type::{CrubitAbiType, CrubitAbiTypeToCppExprTokens, CrubitAbiTypeToCppTokens};
use crubit_abi_type::{
CrubitAbiType, CrubitAbiTypeToCppExprTokens, CrubitAbiTypeToCppTokens,
CrubitAbiTypeToRustExprTokens, CrubitAbiTypeToRustTokens,
};
use database::db::BindingsGenerator;
use database::rs_snippet::{BackingType, Callable, FnTrait, PassingConvention, RsTypeKind};
use proc_macro2::TokenStream;
Expand Down Expand Up @@ -44,7 +47,7 @@ pub fn dyn_callable_crubit_abi_type(
)
},
BackingType::AnyInvocable => {
let make_cpp_invoker_tokens = generate_make_cpp_invoker_tokens()?;
let make_cpp_invoker_tokens = generate_make_cpp_invoker_tokens(db, callable)?;
quote! {
::any_invocable::AnyInvocableAbi::<#dyn_fn_spelling>::new(
#on_empty_tokens,
Expand Down Expand Up @@ -106,7 +109,7 @@ pub fn dyn_callable_crubit_abi_type(
})
}

/// Generates the function pointer object that DynCallable will use in operator()().
/// Generates the function pointer object that the callable will use in operator()().
///
/// This will often produce tokens of the form:
///
Expand Down Expand Up @@ -142,7 +145,7 @@ fn generate_invoker_function_pointer(
let param_ident = &param_idents[i];

match param_ty.passing_convention() {
PassingConvention::AbiCompatible => {
PassingConvention::AbiCompatible | PassingConvention::OwnedPtr => {
arg_exprs.push(quote! { #param_ident });
}
PassingConvention::LayoutCompatible => {
Expand All @@ -159,23 +162,22 @@ fn generate_invoker_function_pointer(
let crubit_abi_type_expr_tokens = CrubitAbiTypeToCppExprTokens(&crubit_abi_type);
let arg_ident = format_ident!("bridge_param_{i}");
arg_transforms.extend(quote! {
unsigned char #arg_ident[#crubit_abi_type_tokens::kSize];
::crubit::internal::Encode(#crubit_abi_type_expr_tokens, #arg_ident, #param_ident);
});
unsigned char #arg_ident[#crubit_abi_type_tokens::kSize];
::crubit::internal::Encode(#crubit_abi_type_expr_tokens, #arg_ident, #param_ident);
});
arg_exprs.push(quote! { #arg_ident });
}
PassingConvention::Ctor => {
bail!("Ctor not supported");
}
PassingConvention::OwnedPtr => {
bail!("OwnedPtr not supported");
}
PassingConvention::Void => unreachable!("parameter types cannot be void"),
PassingConvention::Void => bail!("parameter types cannot be void"),
}
}

let out_param_arg = match callable.return_type.passing_convention() {
PassingConvention::AbiCompatible | PassingConvention::Void => None,
PassingConvention::AbiCompatible
| PassingConvention::Void
| PassingConvention::OwnedPtr => None,
PassingConvention::LayoutCompatible => {
arg_transforms.extend(quote! {
::crubit::Slot<#cpp_return_type> out;
Expand All @@ -193,17 +195,14 @@ fn generate_invoker_function_pointer(
PassingConvention::Ctor => {
bail!("Ctor not supported");
}
PassingConvention::OwnedPtr => {
bail!("OwnedPtr not supported");
}
};

let mut invoke_ffi_and_transform_to_cpp = quote! {
#invoker_ident(state #(, #arg_exprs)* #out_param_arg);
};

match callable.return_type.passing_convention() {
PassingConvention::AbiCompatible => {
PassingConvention::AbiCompatible | PassingConvention::OwnedPtr => {
// Return the result.
invoke_ffi_and_transform_to_cpp = quote! {
return #invoke_ffi_and_transform_to_cpp
Expand All @@ -221,16 +220,13 @@ fn generate_invoker_function_pointer(
let crubit_abi_type_tokens = CrubitAbiTypeToCppTokens(&crubit_abi_type);
let crubit_abi_type_expr_tokens = CrubitAbiTypeToCppExprTokens(&crubit_abi_type);
invoke_ffi_and_transform_to_cpp.extend(quote! {
// Because our bridge buffer is named `out`
return ::crubit::internal::Decode<#crubit_abi_type_tokens>(#crubit_abi_type_expr_tokens, out);
});
// Because our bridge buffer is named `out`
return ::crubit::internal::Decode<#crubit_abi_type_tokens>(#crubit_abi_type_expr_tokens, out);
});
}
PassingConvention::Ctor => {
bail!("Ctor not supported");
}
PassingConvention::OwnedPtr => {
bail!("OwnedPtr not supported");
}
PassingConvention::Void => {
// No need to return anything.
}
Expand All @@ -251,6 +247,148 @@ fn generate_invoker_function_pointer(
}

/// Generates the `make_cpp_invoker` function for AnyInvocable.
fn generate_make_cpp_invoker_tokens() -> Result<TokenStream> {
bail!("AnyInvocable is not yet supported")
///
/// It's a closure that takes a manager and an invoker, and produces a boxed dyn fn that uses the
/// manager and invoker to do the actual work.
///
/// The produced function needs to know how to convert values to and from C++.
fn generate_make_cpp_invoker_tokens(
db: &BindingsGenerator,
callable: &Callable,
) -> Result<TokenStream> {
let param_idents =
(0..callable.param_types.len()).map(|i| format_ident!("param_{i}")).collect::<Vec<_>>();
let rust_param_types = callable.param_types.iter().map(|param_ty| param_ty.to_token_stream(db));
let rust_return_type_fragment = callable.rust_return_type_fragment(db);

let mut c_param_types = Vec::with_capacity(callable.param_types.len());
let mut arg_exprs = Vec::with_capacity(callable.param_types.len());
// We are the caller
for (i, param_ty) in callable.param_types.iter().enumerate() {
let param_ident = &param_idents[i];

match param_ty.passing_convention() {
PassingConvention::AbiCompatible => {
c_param_types.push(param_ty.to_token_stream(db));
arg_exprs.push(quote! { #param_ident });
}
PassingConvention::LayoutCompatible => {
let param_ty_tokens = param_ty.to_token_stream(db);
c_param_types.push(quote! { &mut #param_ty_tokens });
arg_exprs.push(quote! { &mut #param_ident });
}
PassingConvention::ComposablyBridged => {
let crubit_abi_type = db.crubit_abi_type(param_ty.clone())?;
let crubit_abi_type_tokens = CrubitAbiTypeToRustTokens(&crubit_abi_type);
let crubit_abi_type_expr_tokens = CrubitAbiTypeToRustExprTokens(&crubit_abi_type);
// For arguments that are bridge types, we encode the
// Rust value into a buffer and then the argument is a pointer to that buffer.
c_param_types.push(quote! { *const u8 });
arg_exprs.push(quote! {
::bridge_rust::unstable_encode!(@ #crubit_abi_type_expr_tokens, #crubit_abi_type_tokens, #param_ident)
.as_ptr() as *const u8
});
}
PassingConvention::Ctor => {
bail!("Ctor not supported");
}
PassingConvention::OwnedPtr => {
c_param_types.push(param_ty.to_token_stream_with_owned_ptr_type(db));
arg_exprs.push(quote! {
// SAFETY: Transmuting from a repr(transparent) struct that wraps the pointer.
unsafe { ::core::mem::transmute(#param_ident) }
});
}
PassingConvention::Void => bail!("parameter types cannot be void"),
}
}

// What the extern "C" function should return.
let mut c_return_type_fragment = None;
// Set c_return_type_fragment, or push an out param, or nothing if void.
match callable.return_type.passing_convention() {
PassingConvention::AbiCompatible => {
let c_return_type = callable.return_type.to_token_stream(db);
c_return_type_fragment = Some(quote! { -> #c_return_type });
}
PassingConvention::Void => {}
PassingConvention::LayoutCompatible => {
let return_type_tokens = callable.return_type.to_token_stream(db);
c_param_types.push(quote! { *mut #return_type_tokens });
arg_exprs.push(quote! { &raw mut out });
}
PassingConvention::ComposablyBridged => {
c_param_types.push(quote! { *mut u8 });
arg_exprs.push(quote! { &raw mut out });
}
PassingConvention::Ctor => {
bail!("Ctor not supported");
}
PassingConvention::OwnedPtr => {
let c_return_type = callable.return_type.to_token_stream_with_owned_ptr_type(db);
c_return_type_fragment = Some(quote! { -> #c_return_type });
}
};

let mut invoke_ffi_and_transform_to_rust = quote! {
unsafe { c_invoker(managed.state() #(, #arg_exprs)*) }
};

match callable.return_type.passing_convention() {
PassingConvention::AbiCompatible => {
// invoke_ffi_and_transform_to_rust is already a trailing expr.
}
PassingConvention::LayoutCompatible => {
invoke_ffi_and_transform_to_rust = quote! {
let out = ::core::mem::MaybeUninit::uninit();
#invoke_ffi_and_transform_to_rust;
unsafe { out.assume_init() }
}
}
PassingConvention::ComposablyBridged => {
let crubit_abi_type = db.crubit_abi_type(callable.return_type.as_ref().clone())?;
let crubit_abi_type_tokens = CrubitAbiTypeToRustTokens(&crubit_abi_type);
let crubit_abi_type_expr_tokens = CrubitAbiTypeToRustExprTokens(&crubit_abi_type);
invoke_ffi_and_transform_to_rust.extend(quote! {
::bridge_rust::unstable_return!(@ #crubit_abi_type_expr_tokens, #crubit_abi_type_tokens, |out| {
#invoke_ffi_and_transform_to_rust
})
});
}
PassingConvention::Ctor => {
bail!("Ctor not supported");
}
PassingConvention::OwnedPtr => {
invoke_ffi_and_transform_to_rust = quote! {
// SAFETY: Transmuting to a repr(transparent) struct that wraps the pointer.
unsafe { ::core::mem::transmute(#invoke_ffi_and_transform_to_rust) }
};
}
PassingConvention::Void => {
// Append semicolon to the statement.
invoke_ffi_and_transform_to_rust = quote! {
#invoke_ffi_and_transform_to_rust;
}
}
}

let dyn_fn_spelling = callable.dyn_fn_spelling(db);

Ok(quote! {
|managed: ::any_invocable::ManagedState,
invoker: unsafe extern "C" fn()| -> ::alloc::boxed::Box<#dyn_fn_spelling> {
let c_invoker = unsafe {
::core::mem::transmute::<
unsafe extern "C" fn(),
unsafe extern "C" fn(
*mut ::any_invocable::TypeErasedState
#( , #c_param_types )*
) #c_return_type_fragment
>(invoker)
};
::alloc::boxed::Box::new(move |#( #param_idents: #rust_param_types ),*| #rust_return_type_fragment {
#invoke_ffi_and_transform_to_rust
})
}
})
}
35 changes: 23 additions & 12 deletions rs_bindings_from_cc/generate_bindings/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ pub fn generate_bindings_tokens(
let has_reference_param = false;

// Find records that are template instantiations of `rs_std::DynCallable`.
let Ok(Some(BridgeRsTypeKind::DynCallable(dyn_callable))) =
let Ok(Some(BridgeRsTypeKind::Callable(callable))) =
BridgeRsTypeKind::new(record, has_reference_param, &db)
else {
return None;
Expand All @@ -413,21 +413,21 @@ pub fn generate_bindings_tokens(
// The parameters shall be named `param_0`, `param_1`, etc.
// These names can be reused across different dyn callables, so we reuse the same vec
// and just grow it when we need more Idents than it currently contains.
while dyn_callable.param_types.len() > param_idents_buffer.len() {
while callable.param_types.len() > param_idents_buffer.len() {
param_idents_buffer.push(format_ident!("param_{}", param_idents_buffer.len()));
}
// Only take as many filled in names as we need.
let param_idents = &param_idents_buffer[..dyn_callable.param_types.len()];
let param_idents = &param_idents_buffer[..callable.param_types.len()];

// If generate_dyn_callable_cpp_thunk fails, skip. We don't need to generate a nice
// error because whoever uses this will also fail and generate an error at the relevant
// site.
let dyn_callable_cpp_decl =
generate_dyn_callable_cpp_thunk(&db, &dyn_callable, param_idents)?;
let dyn_callable_rust_impl =
generate_dyn_callable_rust_thunk_impl(&db, dyn_callable.clone(), param_idents)?;
let callable_cpp_decl =
generate_dyn_callable_cpp_thunk(&db, &callable, param_idents)?;
let callable_rust_impl =
generate_dyn_callable_rust_thunk_impl(&db, callable.clone(), param_idents)?;

Some((dyn_callable_cpp_decl, dyn_callable_rust_impl))
Some((callable_cpp_decl, callable_rust_impl))
})
.unzip();

Expand Down Expand Up @@ -602,8 +602,8 @@ fn rs_type_kind_safety(db: &BindingsGenerator, rs_type_kind: RsTypeKind) -> Safe
}
}
BridgeRsTypeKind::StdString { .. } => Safety::Safe,
BridgeRsTypeKind::DynCallable(dyn_callable) => {
callable_safety(db, &dyn_callable.param_types, &dyn_callable.return_type)
BridgeRsTypeKind::Callable(callable) => {
callable_safety(db, &callable.param_types, &callable.return_type)
}
BridgeRsTypeKind::C9Co { result_type, .. } => {
// A Co<T> logically produces a T, so it is unsafe iff T is unsafe.
Expand Down Expand Up @@ -785,6 +785,17 @@ fn generate_rs_api_impl_includes(
"util/c9/internal/rust/co_crubit_abi.h".into(),
));
}
BridgeRsTypeKind::Callable(callable)
if callable.backing_type == BackingType::AnyInvocable =>
{
internal_includes.insert(CcInclude::SupportLibHeader(
crubit_support_path_format.clone(),
"bridge.h".into(),
));
internal_includes.insert(CcInclude::user_header(
"third_party/absl/functional/any_invocable_crubit_abi.h".into(),
));
}
_ => {
internal_includes.insert(CcInclude::SupportLibHeader(
crubit_support_path_format.clone(),
Expand Down Expand Up @@ -1013,12 +1024,12 @@ fn crubit_abi_type(db: &BindingsGenerator, rs_type_kind: RsTypeKind) -> Result<C
Ok(CrubitAbiType::Pair(Rc::from(first_abi), Rc::from(second_abi)))
}
BridgeRsTypeKind::StdString { in_cc_std } => Ok(CrubitAbiType::StdString { in_cc_std }),
BridgeRsTypeKind::DynCallable(dyn_callable) => {
BridgeRsTypeKind::Callable(callable) => {
ensure!(
db.ir().target_crubit_features(&original_type.owning_target).contains(CrubitFeature::Callables),
"Callables require the `callables` feature, but target `{:?}` does not have it enabled.", original_type.owning_target,
);
generate_dyn_callable::dyn_callable_crubit_abi_type(db, &dyn_callable)
generate_dyn_callable::dyn_callable_crubit_abi_type(db, &callable)
}
BridgeRsTypeKind::C9Co { result_type, .. } => {
let result_type_tokens = if result_type.is_void() {
Expand Down
Loading