33// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44
55use arc_anyhow:: { bail, Result } ;
6- use crubit_abi_type:: { CrubitAbiType , CrubitAbiTypeToCppExprTokens , CrubitAbiTypeToCppTokens } ;
6+ use crubit_abi_type:: {
7+ CrubitAbiType , CrubitAbiTypeToCppExprTokens , CrubitAbiTypeToCppTokens ,
8+ CrubitAbiTypeToRustExprTokens , CrubitAbiTypeToRustTokens ,
9+ } ;
710use database:: db:: BindingsGenerator ;
811use database:: rs_snippet:: { BackingType , Callable , FnTrait , PassingConvention , RsTypeKind } ;
912use proc_macro2:: TokenStream ;
@@ -44,7 +47,7 @@ pub fn dyn_callable_crubit_abi_type(
4447 )
4548 } ,
4649 BackingType :: AnyInvocable => {
47- let make_cpp_invoker_tokens = generate_make_cpp_invoker_tokens ( ) ?;
50+ let make_cpp_invoker_tokens = generate_make_cpp_invoker_tokens ( db , callable ) ?;
4851 quote ! {
4952 :: any_invocable:: AnyInvocableAbi :: <#dyn_fn_spelling>:: new(
5053 #on_empty_tokens,
@@ -106,7 +109,7 @@ pub fn dyn_callable_crubit_abi_type(
106109 } )
107110}
108111
109- /// Generates the function pointer object that DynCallable will use in operator()().
112+ /// Generates the function pointer object that the callable will use in operator()().
110113///
111114/// This will often produce tokens of the form:
112115///
@@ -142,7 +145,7 @@ fn generate_invoker_function_pointer(
142145 let param_ident = & param_idents[ i] ;
143146
144147 match param_ty. passing_convention ( ) {
145- PassingConvention :: AbiCompatible => {
148+ PassingConvention :: AbiCompatible | PassingConvention :: OwnedPtr => {
146149 arg_exprs. push ( quote ! { #param_ident } ) ;
147150 }
148151 PassingConvention :: LayoutCompatible => {
@@ -159,23 +162,22 @@ fn generate_invoker_function_pointer(
159162 let crubit_abi_type_expr_tokens = CrubitAbiTypeToCppExprTokens ( & crubit_abi_type) ;
160163 let arg_ident = format_ident ! ( "bridge_param_{i}" ) ;
161164 arg_transforms. extend ( quote ! {
162- unsigned char #arg_ident[ #crubit_abi_type_tokens:: kSize] ;
163- :: crubit:: internal:: Encode ( #crubit_abi_type_expr_tokens, #arg_ident, #param_ident) ;
164- } ) ;
165+ unsigned char #arg_ident[ #crubit_abi_type_tokens:: kSize] ;
166+ :: crubit:: internal:: Encode ( #crubit_abi_type_expr_tokens, #arg_ident, #param_ident) ;
167+ } ) ;
165168 arg_exprs. push ( quote ! { #arg_ident } ) ;
166169 }
167170 PassingConvention :: Ctor => {
168171 bail ! ( "Ctor not supported" ) ;
169172 }
170- PassingConvention :: OwnedPtr => {
171- bail ! ( "OwnedPtr not supported" ) ;
172- }
173- PassingConvention :: Void => unreachable ! ( "parameter types cannot be void" ) ,
173+ PassingConvention :: Void => bail ! ( "parameter types cannot be void" ) ,
174174 }
175175 }
176176
177177 let out_param_arg = match callable. return_type . passing_convention ( ) {
178- PassingConvention :: AbiCompatible | PassingConvention :: Void => None ,
178+ PassingConvention :: AbiCompatible
179+ | PassingConvention :: Void
180+ | PassingConvention :: OwnedPtr => None ,
179181 PassingConvention :: LayoutCompatible => {
180182 arg_transforms. extend ( quote ! {
181183 :: crubit:: Slot <#cpp_return_type> out;
@@ -193,17 +195,14 @@ fn generate_invoker_function_pointer(
193195 PassingConvention :: Ctor => {
194196 bail ! ( "Ctor not supported" ) ;
195197 }
196- PassingConvention :: OwnedPtr => {
197- bail ! ( "OwnedPtr not supported" ) ;
198- }
199198 } ;
200199
201200 let mut invoke_ffi_and_transform_to_cpp = quote ! {
202201 #invoker_ident( state #( , #arg_exprs) * #out_param_arg) ;
203202 } ;
204203
205204 match callable. return_type . passing_convention ( ) {
206- PassingConvention :: AbiCompatible => {
205+ PassingConvention :: AbiCompatible | PassingConvention :: OwnedPtr => {
207206 // Return the result.
208207 invoke_ffi_and_transform_to_cpp = quote ! {
209208 return #invoke_ffi_and_transform_to_cpp
@@ -221,16 +220,13 @@ fn generate_invoker_function_pointer(
221220 let crubit_abi_type_tokens = CrubitAbiTypeToCppTokens ( & crubit_abi_type) ;
222221 let crubit_abi_type_expr_tokens = CrubitAbiTypeToCppExprTokens ( & crubit_abi_type) ;
223222 invoke_ffi_and_transform_to_cpp. extend ( quote ! {
224- // Because our bridge buffer is named `out`
225- return :: crubit:: internal:: Decode <#crubit_abi_type_tokens>( #crubit_abi_type_expr_tokens, out) ;
226- } ) ;
223+ // Because our bridge buffer is named `out`
224+ return :: crubit:: internal:: Decode <#crubit_abi_type_tokens>( #crubit_abi_type_expr_tokens, out) ;
225+ } ) ;
227226 }
228227 PassingConvention :: Ctor => {
229228 bail ! ( "Ctor not supported" ) ;
230229 }
231- PassingConvention :: OwnedPtr => {
232- bail ! ( "OwnedPtr not supported" ) ;
233- }
234230 PassingConvention :: Void => {
235231 // No need to return anything.
236232 }
@@ -251,6 +247,148 @@ fn generate_invoker_function_pointer(
251247}
252248
253249/// Generates the `make_cpp_invoker` function for AnyInvocable.
254- fn generate_make_cpp_invoker_tokens ( ) -> Result < TokenStream > {
255- bail ! ( "AnyInvocable is not yet supported" )
250+ ///
251+ /// It's a closure that takes a manager and an invoker, and produces a boxed dyn fn that uses the
252+ /// manager and invoker to do the actual work.
253+ ///
254+ /// The produced function needs to know how to convert values to and from C++.
255+ fn generate_make_cpp_invoker_tokens (
256+ db : & BindingsGenerator ,
257+ callable : & Callable ,
258+ ) -> Result < TokenStream > {
259+ let param_idents =
260+ ( 0 ..callable. param_types . len ( ) ) . map ( |i| format_ident ! ( "param_{i}" ) ) . collect :: < Vec < _ > > ( ) ;
261+ let rust_param_types = callable. param_types . iter ( ) . map ( |param_ty| param_ty. to_token_stream ( db) ) ;
262+ let rust_return_type_fragment = callable. rust_return_type_fragment ( db) ;
263+
264+ let mut c_param_types = Vec :: with_capacity ( callable. param_types . len ( ) ) ;
265+ let mut arg_exprs = Vec :: with_capacity ( callable. param_types . len ( ) ) ;
266+ // We are the caller
267+ for ( i, param_ty) in callable. param_types . iter ( ) . enumerate ( ) {
268+ let param_ident = & param_idents[ i] ;
269+
270+ match param_ty. passing_convention ( ) {
271+ PassingConvention :: AbiCompatible => {
272+ c_param_types. push ( param_ty. to_token_stream ( db) ) ;
273+ arg_exprs. push ( quote ! { #param_ident } ) ;
274+ }
275+ PassingConvention :: LayoutCompatible => {
276+ let param_ty_tokens = param_ty. to_token_stream ( db) ;
277+ c_param_types. push ( quote ! { & mut #param_ty_tokens } ) ;
278+ arg_exprs. push ( quote ! { & mut #param_ident } ) ;
279+ }
280+ PassingConvention :: ComposablyBridged => {
281+ let crubit_abi_type = db. crubit_abi_type ( param_ty. clone ( ) ) ?;
282+ let crubit_abi_type_tokens = CrubitAbiTypeToRustTokens ( & crubit_abi_type) ;
283+ let crubit_abi_type_expr_tokens = CrubitAbiTypeToRustExprTokens ( & crubit_abi_type) ;
284+ // For arguments that are bridge types, we encode the
285+ // Rust value into a buffer and then the argument is a pointer to that buffer.
286+ c_param_types. push ( quote ! { * const u8 } ) ;
287+ arg_exprs. push ( quote ! {
288+ :: bridge_rust:: unstable_encode!( @ #crubit_abi_type_expr_tokens, #crubit_abi_type_tokens, #param_ident)
289+ . as_ptr( ) as * const u8
290+ } ) ;
291+ }
292+ PassingConvention :: Ctor => {
293+ bail ! ( "Ctor not supported" ) ;
294+ }
295+ PassingConvention :: OwnedPtr => {
296+ c_param_types. push ( param_ty. to_token_stream_with_owned_ptr_type ( db) ) ;
297+ arg_exprs. push ( quote ! {
298+ // SAFETY: Transmuting from a repr(transparent) struct that wraps the pointer.
299+ unsafe { :: core:: mem:: transmute( #param_ident) }
300+ } ) ;
301+ }
302+ PassingConvention :: Void => bail ! ( "parameter types cannot be void" ) ,
303+ }
304+ }
305+
306+ // What the extern "C" function should return.
307+ let mut c_return_type_fragment = None ;
308+ // Set c_return_type_fragment, or push an out param, or nothing if void.
309+ match callable. return_type . passing_convention ( ) {
310+ PassingConvention :: AbiCompatible => {
311+ let c_return_type = callable. return_type . to_token_stream ( db) ;
312+ c_return_type_fragment = Some ( quote ! { -> #c_return_type } ) ;
313+ }
314+ PassingConvention :: Void => { }
315+ PassingConvention :: LayoutCompatible => {
316+ let return_type_tokens = callable. return_type . to_token_stream ( db) ;
317+ c_param_types. push ( quote ! { * mut #return_type_tokens } ) ;
318+ arg_exprs. push ( quote ! { & raw mut out } ) ;
319+ }
320+ PassingConvention :: ComposablyBridged => {
321+ c_param_types. push ( quote ! { * mut u8 } ) ;
322+ arg_exprs. push ( quote ! { & raw mut out } ) ;
323+ }
324+ PassingConvention :: Ctor => {
325+ bail ! ( "Ctor not supported" ) ;
326+ }
327+ PassingConvention :: OwnedPtr => {
328+ let c_return_type = callable. return_type . to_token_stream_with_owned_ptr_type ( db) ;
329+ c_return_type_fragment = Some ( quote ! { -> #c_return_type } ) ;
330+ }
331+ } ;
332+
333+ let mut invoke_ffi_and_transform_to_rust = quote ! {
334+ unsafe { c_invoker( managed. state. get( ) #( , #arg_exprs) * ) }
335+ } ;
336+
337+ match callable. return_type . passing_convention ( ) {
338+ PassingConvention :: AbiCompatible => {
339+ // invoke_ffi_and_transform_to_rust is already a trailing expr.
340+ }
341+ PassingConvention :: LayoutCompatible => {
342+ invoke_ffi_and_transform_to_rust = quote ! {
343+ let out = :: core:: mem:: MaybeUninit :: uninit( ) ;
344+ #invoke_ffi_and_transform_to_rust;
345+ unsafe { out. assume_init( ) }
346+ }
347+ }
348+ PassingConvention :: ComposablyBridged => {
349+ let crubit_abi_type = db. crubit_abi_type ( callable. return_type . as_ref ( ) . clone ( ) ) ?;
350+ let crubit_abi_type_tokens = CrubitAbiTypeToRustTokens ( & crubit_abi_type) ;
351+ let crubit_abi_type_expr_tokens = CrubitAbiTypeToRustExprTokens ( & crubit_abi_type) ;
352+ invoke_ffi_and_transform_to_rust. extend ( quote ! {
353+ :: bridge_rust:: unstable_return!( @ #crubit_abi_type_expr_tokens, #crubit_abi_type_tokens, |out| {
354+ #invoke_ffi_and_transform_to_rust
355+ } )
356+ } ) ;
357+ }
358+ PassingConvention :: Ctor => {
359+ bail ! ( "Ctor not supported" ) ;
360+ }
361+ PassingConvention :: OwnedPtr => {
362+ invoke_ffi_and_transform_to_rust = quote ! {
363+ // SAFETY: Transmuting to a repr(transparent) struct that wraps the pointer.
364+ unsafe { :: core:: mem:: transmute( #invoke_ffi_and_transform_to_rust) }
365+ } ;
366+ }
367+ PassingConvention :: Void => {
368+ // Append semicolon to the statement.
369+ invoke_ffi_and_transform_to_rust = quote ! {
370+ #invoke_ffi_and_transform_to_rust;
371+ }
372+ }
373+ }
374+
375+ let dyn_fn_spelling = callable. dyn_fn_spelling ( db) ;
376+
377+ Ok ( quote ! {
378+ |managed: :: any_invocable:: ManagedState ,
379+ invoker: unsafe extern "C" fn ( ) | -> :: alloc:: boxed:: Box <#dyn_fn_spelling> {
380+ let c_invoker = unsafe {
381+ :: core:: mem:: transmute:: <
382+ unsafe extern "C" fn ( ) ,
383+ unsafe extern "C" fn (
384+ * mut :: any_invocable:: TypeErasedState
385+ #( , #c_param_types ) *
386+ ) #c_return_type_fragment
387+ >( invoker)
388+ } ;
389+ :: alloc:: boxed:: Box :: new( move |#( #param_idents: #rust_param_types ) , * | #rust_return_type_fragment {
390+ #invoke_ffi_and_transform_to_rust
391+ } )
392+ }
393+ } )
256394}
0 commit comments