Skip to content

Commit d7dde5b

Browse files
Googlercopybara-github
authored andcommitted
Internal change.
PiperOrigin-RevId: 869804191
1 parent 706b41d commit d7dde5b

File tree

9 files changed

+458
-53
lines changed

9 files changed

+458
-53
lines changed

rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ pub enum BridgeRsTypeKind {
562562
StdString {
563563
in_cc_std: bool,
564564
},
565-
DynCallable(Rc<Callable>),
565+
Callable(Rc<Callable>),
566566
/// c9::Co<T>
567567
C9Co {
568568
has_reference_param: bool,
@@ -625,7 +625,7 @@ impl BridgeRsTypeKind {
625625
BridgeRsTypeKind::StdString { in_cc_std }
626626
}
627627
BridgeType::Callable { backing_type, fn_trait, return_type, param_types } => {
628-
BridgeRsTypeKind::DynCallable(Rc::new(Callable {
628+
BridgeRsTypeKind::Callable(Rc::new(Callable {
629629
backing_type,
630630
fn_trait: match fn_trait {
631631
ir::FnTrait::Fn => FnTrait::Fn,
@@ -1206,8 +1206,8 @@ impl RsTypeKind {
12061206
BridgeRsTypeKind::StdOptional(t) => t.implements_copy(),
12071207
BridgeRsTypeKind::StdPair(t1, t2) => t1.implements_copy() && t2.implements_copy(),
12081208
BridgeRsTypeKind::StdString { .. } => false,
1209-
BridgeRsTypeKind::DynCallable { .. } => {
1210-
// DynCallable represents an owned function object, so it is not copyable.
1209+
BridgeRsTypeKind::Callable { .. } => {
1210+
// Callables represent an owned function object, so they are not copyable.
12111211
false
12121212
}
12131213
BridgeRsTypeKind::C9Co { .. } => false,
@@ -1708,9 +1708,9 @@ impl RsTypeKind {
17081708
quote! { ::cc_std::std::string }
17091709
}
17101710
}
1711-
BridgeRsTypeKind::DynCallable(dyn_callable) => {
1712-
let dyn_callable_spelling = dyn_callable.dyn_fn_spelling(&db);
1713-
quote! { ::alloc::boxed::Box<#dyn_callable_spelling> }
1711+
BridgeRsTypeKind::Callable(callable) => {
1712+
let callable_spelling = callable.dyn_fn_spelling(&db);
1713+
quote! { ::alloc::boxed::Box<#callable_spelling> }
17141714
}
17151715
BridgeRsTypeKind::C9Co { has_reference_param, result_type, .. } => {
17161716
let result_type_tokens = if result_type.is_void() {
@@ -1866,7 +1866,7 @@ impl<'ty> Iterator for RsTypeKindIter<'ty> {
18661866
self.todo.push(t1);
18671867
}
18681868
BridgeRsTypeKind::StdString { .. } => {}
1869-
BridgeRsTypeKind::DynCallable(callable) => {
1869+
BridgeRsTypeKind::Callable(callable) => {
18701870
self.todo.push(&callable.return_type);
18711871
self.todo.extend(callable.param_types.iter().rev());
18721872
}

rs_bindings_from_cc/generate_bindings/generate_dyn_callable.rs

Lines changed: 162 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55
use arc_anyhow::{bail, Result};
6-
use crubit_abi_type::{CrubitAbiType, CrubitAbiTypeToCppExprTokens, CrubitAbiTypeToCppTokens};
6+
use crubit_abi_type::{
7+
CrubitAbiType, CrubitAbiTypeToCppExprTokens, CrubitAbiTypeToCppTokens,
8+
CrubitAbiTypeToRustExprTokens, CrubitAbiTypeToRustTokens,
9+
};
710
use database::db::BindingsGenerator;
811
use database::rs_snippet::{BackingType, Callable, FnTrait, PassingConvention, RsTypeKind};
912
use proc_macro2::TokenStream;
@@ -44,7 +47,7 @@ pub fn dyn_callable_crubit_abi_type(
4447
)
4548
},
4649
BackingType::AnyInvocable => {
47-
let make_cpp_invoker_tokens = generate_make_cpp_invoker_tokens()?;
50+
let make_cpp_invoker_tokens = generate_make_cpp_invoker_tokens(db, callable)?;
4851
quote! {
4952
::any_invocable::AnyInvocableAbi::<#dyn_fn_spelling>::new(
5053
#on_empty_tokens,
@@ -106,7 +109,7 @@ pub fn dyn_callable_crubit_abi_type(
106109
})
107110
}
108111

109-
/// Generates the function pointer object that DynCallable will use in operator()().
112+
/// Generates the function pointer object that the callable will use in operator()().
110113
///
111114
/// This will often produce tokens of the form:
112115
///
@@ -142,7 +145,7 @@ fn generate_invoker_function_pointer(
142145
let param_ident = &param_idents[i];
143146

144147
match param_ty.passing_convention() {
145-
PassingConvention::AbiCompatible => {
148+
PassingConvention::AbiCompatible | PassingConvention::OwnedPtr => {
146149
arg_exprs.push(quote! { #param_ident });
147150
}
148151
PassingConvention::LayoutCompatible => {
@@ -159,23 +162,22 @@ fn generate_invoker_function_pointer(
159162
let crubit_abi_type_expr_tokens = CrubitAbiTypeToCppExprTokens(&crubit_abi_type);
160163
let arg_ident = format_ident!("bridge_param_{i}");
161164
arg_transforms.extend(quote! {
162-
unsigned char #arg_ident[#crubit_abi_type_tokens::kSize];
163-
::crubit::internal::Encode(#crubit_abi_type_expr_tokens, #arg_ident, #param_ident);
164-
});
165+
unsigned char #arg_ident[#crubit_abi_type_tokens::kSize];
166+
::crubit::internal::Encode(#crubit_abi_type_expr_tokens, #arg_ident, #param_ident);
167+
});
165168
arg_exprs.push(quote! { #arg_ident });
166169
}
167170
PassingConvention::Ctor => {
168171
bail!("Ctor not supported");
169172
}
170-
PassingConvention::OwnedPtr => {
171-
bail!("OwnedPtr not supported");
172-
}
173-
PassingConvention::Void => unreachable!("parameter types cannot be void"),
173+
PassingConvention::Void => bail!("parameter types cannot be void"),
174174
}
175175
}
176176

177177
let out_param_arg = match callable.return_type.passing_convention() {
178-
PassingConvention::AbiCompatible | PassingConvention::Void => None,
178+
PassingConvention::AbiCompatible
179+
| PassingConvention::Void
180+
| PassingConvention::OwnedPtr => None,
179181
PassingConvention::LayoutCompatible => {
180182
arg_transforms.extend(quote! {
181183
::crubit::Slot<#cpp_return_type> out;
@@ -193,17 +195,14 @@ fn generate_invoker_function_pointer(
193195
PassingConvention::Ctor => {
194196
bail!("Ctor not supported");
195197
}
196-
PassingConvention::OwnedPtr => {
197-
bail!("OwnedPtr not supported");
198-
}
199198
};
200199

201200
let mut invoke_ffi_and_transform_to_cpp = quote! {
202201
#invoker_ident(state #(, #arg_exprs)* #out_param_arg);
203202
};
204203

205204
match callable.return_type.passing_convention() {
206-
PassingConvention::AbiCompatible => {
205+
PassingConvention::AbiCompatible | PassingConvention::OwnedPtr => {
207206
// Return the result.
208207
invoke_ffi_and_transform_to_cpp = quote! {
209208
return #invoke_ffi_and_transform_to_cpp
@@ -221,16 +220,13 @@ fn generate_invoker_function_pointer(
221220
let crubit_abi_type_tokens = CrubitAbiTypeToCppTokens(&crubit_abi_type);
222221
let crubit_abi_type_expr_tokens = CrubitAbiTypeToCppExprTokens(&crubit_abi_type);
223222
invoke_ffi_and_transform_to_cpp.extend(quote! {
224-
// Because our bridge buffer is named `out`
225-
return ::crubit::internal::Decode<#crubit_abi_type_tokens>(#crubit_abi_type_expr_tokens, out);
226-
});
223+
// Because our bridge buffer is named `out`
224+
return ::crubit::internal::Decode<#crubit_abi_type_tokens>(#crubit_abi_type_expr_tokens, out);
225+
});
227226
}
228227
PassingConvention::Ctor => {
229228
bail!("Ctor not supported");
230229
}
231-
PassingConvention::OwnedPtr => {
232-
bail!("OwnedPtr not supported");
233-
}
234230
PassingConvention::Void => {
235231
// No need to return anything.
236232
}
@@ -251,6 +247,148 @@ fn generate_invoker_function_pointer(
251247
}
252248

253249
/// Generates the `make_cpp_invoker` function for AnyInvocable.
254-
fn generate_make_cpp_invoker_tokens() -> Result<TokenStream> {
255-
bail!("AnyInvocable is not yet supported")
250+
///
251+
/// It's a closure that takes a manager and an invoker, and produces a boxed dyn fn that uses the
252+
/// manager and invoker to do the actual work.
253+
///
254+
/// The produced function needs to know how to convert values to and from C++.
255+
fn generate_make_cpp_invoker_tokens(
256+
db: &BindingsGenerator,
257+
callable: &Callable,
258+
) -> Result<TokenStream> {
259+
let param_idents =
260+
(0..callable.param_types.len()).map(|i| format_ident!("param_{i}")).collect::<Vec<_>>();
261+
let rust_param_types = callable.param_types.iter().map(|param_ty| param_ty.to_token_stream(db));
262+
let rust_return_type_fragment = callable.rust_return_type_fragment(db);
263+
264+
let mut c_param_types = Vec::with_capacity(callable.param_types.len());
265+
let mut arg_exprs = Vec::with_capacity(callable.param_types.len());
266+
// We are the caller
267+
for (i, param_ty) in callable.param_types.iter().enumerate() {
268+
let param_ident = &param_idents[i];
269+
270+
match param_ty.passing_convention() {
271+
PassingConvention::AbiCompatible => {
272+
c_param_types.push(param_ty.to_token_stream(db));
273+
arg_exprs.push(quote! { #param_ident });
274+
}
275+
PassingConvention::LayoutCompatible => {
276+
let param_ty_tokens = param_ty.to_token_stream(db);
277+
c_param_types.push(quote! { &mut #param_ty_tokens });
278+
arg_exprs.push(quote! { &mut #param_ident });
279+
}
280+
PassingConvention::ComposablyBridged => {
281+
let crubit_abi_type = db.crubit_abi_type(param_ty.clone())?;
282+
let crubit_abi_type_tokens = CrubitAbiTypeToRustTokens(&crubit_abi_type);
283+
let crubit_abi_type_expr_tokens = CrubitAbiTypeToRustExprTokens(&crubit_abi_type);
284+
// For arguments that are bridge types, we encode the
285+
// Rust value into a buffer and then the argument is a pointer to that buffer.
286+
c_param_types.push(quote! { *const u8 });
287+
arg_exprs.push(quote! {
288+
::bridge_rust::unstable_encode!(@ #crubit_abi_type_expr_tokens, #crubit_abi_type_tokens, #param_ident)
289+
.as_ptr() as *const u8
290+
});
291+
}
292+
PassingConvention::Ctor => {
293+
bail!("Ctor not supported");
294+
}
295+
PassingConvention::OwnedPtr => {
296+
c_param_types.push(param_ty.to_token_stream_with_owned_ptr_type(db));
297+
arg_exprs.push(quote! {
298+
// SAFETY: Transmuting from a repr(transparent) struct that wraps the pointer.
299+
unsafe { ::core::mem::transmute(#param_ident) }
300+
});
301+
}
302+
PassingConvention::Void => bail!("parameter types cannot be void"),
303+
}
304+
}
305+
306+
// What the extern "C" function should return.
307+
let mut c_return_type_fragment = None;
308+
// Set c_return_type_fragment, or push an out param, or nothing if void.
309+
match callable.return_type.passing_convention() {
310+
PassingConvention::AbiCompatible => {
311+
let c_return_type = callable.return_type.to_token_stream(db);
312+
c_return_type_fragment = Some(quote! { -> #c_return_type });
313+
}
314+
PassingConvention::Void => {}
315+
PassingConvention::LayoutCompatible => {
316+
let return_type_tokens = callable.return_type.to_token_stream(db);
317+
c_param_types.push(quote! { *mut #return_type_tokens });
318+
arg_exprs.push(quote! { &raw mut out });
319+
}
320+
PassingConvention::ComposablyBridged => {
321+
c_param_types.push(quote! { *mut u8 });
322+
arg_exprs.push(quote! { &raw mut out });
323+
}
324+
PassingConvention::Ctor => {
325+
bail!("Ctor not supported");
326+
}
327+
PassingConvention::OwnedPtr => {
328+
let c_return_type = callable.return_type.to_token_stream_with_owned_ptr_type(db);
329+
c_return_type_fragment = Some(quote! { -> #c_return_type });
330+
}
331+
};
332+
333+
let mut invoke_ffi_and_transform_to_rust = quote! {
334+
unsafe { c_invoker(managed.state.get() #(, #arg_exprs)*) }
335+
};
336+
337+
match callable.return_type.passing_convention() {
338+
PassingConvention::AbiCompatible => {
339+
// invoke_ffi_and_transform_to_rust is already a trailing expr.
340+
}
341+
PassingConvention::LayoutCompatible => {
342+
invoke_ffi_and_transform_to_rust = quote! {
343+
let out = ::core::mem::MaybeUninit::uninit();
344+
#invoke_ffi_and_transform_to_rust;
345+
unsafe { out.assume_init() }
346+
}
347+
}
348+
PassingConvention::ComposablyBridged => {
349+
let crubit_abi_type = db.crubit_abi_type(callable.return_type.as_ref().clone())?;
350+
let crubit_abi_type_tokens = CrubitAbiTypeToRustTokens(&crubit_abi_type);
351+
let crubit_abi_type_expr_tokens = CrubitAbiTypeToRustExprTokens(&crubit_abi_type);
352+
invoke_ffi_and_transform_to_rust.extend(quote! {
353+
::bridge_rust::unstable_return!(@ #crubit_abi_type_expr_tokens, #crubit_abi_type_tokens, |out| {
354+
#invoke_ffi_and_transform_to_rust
355+
})
356+
});
357+
}
358+
PassingConvention::Ctor => {
359+
bail!("Ctor not supported");
360+
}
361+
PassingConvention::OwnedPtr => {
362+
invoke_ffi_and_transform_to_rust = quote! {
363+
// SAFETY: Transmuting to a repr(transparent) struct that wraps the pointer.
364+
unsafe { ::core::mem::transmute(#invoke_ffi_and_transform_to_rust) }
365+
};
366+
}
367+
PassingConvention::Void => {
368+
// Append semicolon to the statement.
369+
invoke_ffi_and_transform_to_rust = quote! {
370+
#invoke_ffi_and_transform_to_rust;
371+
}
372+
}
373+
}
374+
375+
let dyn_fn_spelling = callable.dyn_fn_spelling(db);
376+
377+
Ok(quote! {
378+
|managed: ::any_invocable::ManagedState,
379+
invoker: unsafe extern "C" fn()| -> ::alloc::boxed::Box<#dyn_fn_spelling> {
380+
let c_invoker = unsafe {
381+
::core::mem::transmute::<
382+
unsafe extern "C" fn(),
383+
unsafe extern "C" fn(
384+
*mut ::any_invocable::TypeErasedState
385+
#( , #c_param_types )*
386+
) #c_return_type_fragment
387+
>(invoker)
388+
};
389+
::alloc::boxed::Box::new(move |#( #param_idents: #rust_param_types ),*| #rust_return_type_fragment {
390+
#invoke_ffi_and_transform_to_rust
391+
})
392+
}
393+
})
256394
}

rs_bindings_from_cc/generate_bindings/lib.rs

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ pub fn generate_bindings_tokens(
404404
let has_reference_param = false;
405405

406406
// Find records that are template instantiations of `rs_std::DynCallable`.
407-
let Ok(Some(BridgeRsTypeKind::DynCallable(dyn_callable))) =
407+
let Ok(Some(BridgeRsTypeKind::Callable(callable))) =
408408
BridgeRsTypeKind::new(record, has_reference_param, &db)
409409
else {
410410
return None;
@@ -413,21 +413,21 @@ pub fn generate_bindings_tokens(
413413
// The parameters shall be named `param_0`, `param_1`, etc.
414414
// These names can be reused across different dyn callables, so we reuse the same vec
415415
// and just grow it when we need more Idents than it currently contains.
416-
while dyn_callable.param_types.len() > param_idents_buffer.len() {
416+
while callable.param_types.len() > param_idents_buffer.len() {
417417
param_idents_buffer.push(format_ident!("param_{}", param_idents_buffer.len()));
418418
}
419419
// Only take as many filled in names as we need.
420-
let param_idents = &param_idents_buffer[..dyn_callable.param_types.len()];
420+
let param_idents = &param_idents_buffer[..callable.param_types.len()];
421421

422422
// If generate_dyn_callable_cpp_thunk fails, skip. We don't need to generate a nice
423423
// error because whoever uses this will also fail and generate an error at the relevant
424424
// site.
425-
let dyn_callable_cpp_decl =
426-
generate_dyn_callable_cpp_thunk(&db, &dyn_callable, param_idents)?;
427-
let dyn_callable_rust_impl =
428-
generate_dyn_callable_rust_thunk_impl(&db, dyn_callable.clone(), param_idents)?;
425+
let callable_cpp_decl =
426+
generate_dyn_callable_cpp_thunk(&db, &callable, param_idents)?;
427+
let callable_rust_impl =
428+
generate_dyn_callable_rust_thunk_impl(&db, callable.clone(), param_idents)?;
429429

430-
Some((dyn_callable_cpp_decl, dyn_callable_rust_impl))
430+
Some((callable_cpp_decl, callable_rust_impl))
431431
})
432432
.unzip();
433433

@@ -602,8 +602,8 @@ fn rs_type_kind_safety(db: &BindingsGenerator, rs_type_kind: RsTypeKind) -> Safe
602602
}
603603
}
604604
BridgeRsTypeKind::StdString { .. } => Safety::Safe,
605-
BridgeRsTypeKind::DynCallable(dyn_callable) => {
606-
callable_safety(db, &dyn_callable.param_types, &dyn_callable.return_type)
605+
BridgeRsTypeKind::Callable(callable) => {
606+
callable_safety(db, &callable.param_types, &callable.return_type)
607607
}
608608
BridgeRsTypeKind::C9Co { result_type, .. } => {
609609
// A Co<T> logically produces a T, so it is unsafe iff T is unsafe.
@@ -785,6 +785,17 @@ fn generate_rs_api_impl_includes(
785785
"util/c9/internal/rust/co_crubit_abi.h".into(),
786786
));
787787
}
788+
BridgeRsTypeKind::Callable(callable)
789+
if callable.backing_type == BackingType::AnyInvocable =>
790+
{
791+
internal_includes.insert(CcInclude::SupportLibHeader(
792+
crubit_support_path_format.clone(),
793+
"bridge.h".into(),
794+
));
795+
internal_includes.insert(CcInclude::user_header(
796+
"third_party/absl/functional/any_invocable_crubit_abi.h".into(),
797+
));
798+
}
788799
_ => {
789800
internal_includes.insert(CcInclude::SupportLibHeader(
790801
crubit_support_path_format.clone(),
@@ -1013,12 +1024,12 @@ fn crubit_abi_type(db: &BindingsGenerator, rs_type_kind: RsTypeKind) -> Result<C
10131024
Ok(CrubitAbiType::Pair(Rc::from(first_abi), Rc::from(second_abi)))
10141025
}
10151026
BridgeRsTypeKind::StdString { in_cc_std } => Ok(CrubitAbiType::StdString { in_cc_std }),
1016-
BridgeRsTypeKind::DynCallable(dyn_callable) => {
1027+
BridgeRsTypeKind::Callable(callable) => {
10171028
ensure!(
10181029
db.ir().target_crubit_features(&original_type.owning_target).contains(CrubitFeature::Callables),
10191030
"Callables require the `callables` feature, but target `{:?}` does not have it enabled.", original_type.owning_target,
10201031
);
1021-
generate_dyn_callable::dyn_callable_crubit_abi_type(db, &dyn_callable)
1032+
generate_dyn_callable::dyn_callable_crubit_abi_type(db, &callable)
10221033
}
10231034
BridgeRsTypeKind::C9Co { result_type, .. } => {
10241035
let result_type_tokens = if result_type.is_void() {

0 commit comments

Comments
 (0)