diff --git a/Cargo.toml b/Cargo.toml index 0fd15f2..8090a47 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,6 +73,7 @@ members = [ "fixtures/simple-fns", "fixtures/trait-methods", "fixtures/trait-interfaces", + "fixtures/dart_async", #"fixtures/*", ] diff --git a/fixtures/dart_async/Cargo.toml b/fixtures/dart_async/Cargo.toml index fe397ac..33bd926 100644 --- a/fixtures/dart_async/Cargo.toml +++ b/fixtures/dart_async/Cargo.toml @@ -11,16 +11,18 @@ name = "dart_async" crate-type = ["lib", "cdylib"] [dependencies] -uniffi = { workspace = true, features = ["tokio"]} -tokio = { version = "1.24.1", features = ["time"] } -thiserror = "1.0" +uniffi = { workspace = true, features = ["tokio", "cli", "scaffolding-ffi-buffer-fns"] } +async-trait = "0.1" +futures = "0.3" +tokio = { version = "1.38.2", features = ["time", "sync"] } +once_cell = "1.18.0" +thiserror = "2" [build-dependencies] +uniffi = { workspace = true, features = ["build", "scaffolding-ffi-buffer-fns"] } uniffi-dart = { path = "../../", features = ["build"] } [dev-dependencies] +uniffi = { workspace = true, features = ["bindgen-tests"] } uniffi-dart = { path = "../../", features = ["bindgen-tests"] } -uniffi = { workspace = true, features = [ - "bindgen-tests", -] } anyhow = "1" \ No newline at end of file diff --git a/fixtures/dart_async/src/api.udl b/fixtures/dart_async/src/api.udl index 1b69db7..d43c39b 100644 --- a/fixtures/dart_async/src/api.udl +++ b/fixtures/dart_async/src/api.udl @@ -1,7 +1,7 @@ -namespace dart_async { +namespace dart_async { // UDL-defined async functions (testing UDL vs proc-macro async support) [Async] - boolean udl_always_ready(); + boolean always_ready(); }; // UDL-defined async trait interface diff --git a/fixtures/dart_async/src/lib.rs b/fixtures/dart_async/src/lib.rs index 52624f0..f37a60e 100644 --- a/fixtures/dart_async/src/lib.rs +++ b/fixtures/dart_async/src/lib.rs @@ -1,5 +1,3 @@ -use uniffi; - use std::{ future::Future, pin::Pin, @@ -9,6 +7,9 @@ use std::{ time::Duration, }; +use futures::future::{AbortHandle, Abortable, Aborted}; +use once_cell::sync::Lazy; + /// Non-blocking timer future. pub struct TimerFuture { shared_state: Arc>, @@ -24,7 +25,6 @@ impl Future for TimerFuture { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut shared_state = self.shared_state.lock().unwrap(); - if shared_state.completed { Poll::Ready(()) } else { @@ -42,14 +42,11 @@ impl TimerFuture { })); let thread_shared_state = shared_state.clone(); - // Let's mimic an event coming from somewhere else, like the system. thread::spawn(move || { thread::sleep(duration); - let mut shared_state: MutexGuard<_> = thread_shared_state.lock().unwrap(); shared_state.completed = true; - if let Some(waker) = shared_state.waker.take() { waker.wake(); } @@ -59,7 +56,7 @@ impl TimerFuture { } } -/// Non-blocking timer future. +/// Non-blocking timer future that intentionally misbehaves. pub struct BrokenTimerFuture { shared_state: Arc>, } @@ -69,7 +66,6 @@ impl Future for BrokenTimerFuture { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut shared_state = self.shared_state.lock().unwrap(); - if shared_state.completed { Poll::Ready(()) } else { @@ -87,21 +83,15 @@ impl BrokenTimerFuture { })); let thread_shared_state = shared_state.clone(); - // Let's mimic an event coming from somewhere else, like the system. thread::spawn(move || { thread::sleep(duration); - let mut shared_state: MutexGuard<_> = thread_shared_state.lock().unwrap(); shared_state.completed = true; - if let Some(waker) = shared_state.waker.take() { // Do not consume `waker`. waker.wake_by_ref(); - - // And this is the important part. We are going to call - // `wake()` a second time. That's incorrect, but that's on - // purpose, to see how foreign languages will react. + // And this is the important part. We are going to call `wake()` a second time. if fail_after.is_zero() { waker.wake(); } else { @@ -122,35 +112,37 @@ pub fn greet(who: String) -> String { format!("Hello, {who}") } -#[uniffi::export] +/// Async function that is immediately ready. Declared in the UDL to ensure UDL async works. pub async fn always_ready() -> bool { true } #[uniffi::export] -pub async fn void_function() {} +pub async fn void() {} #[uniffi::export] pub async fn say() -> String { TimerFuture::new(Duration::from_secs(2)).await; - "Hello, Future!".to_string() } #[uniffi::export] pub async fn say_after(ms: u16, who: String) -> String { TimerFuture::new(Duration::from_millis(ms.into())).await; - format!("Hello, {who}!") } #[uniffi::export] pub async fn sleep(ms: u16) -> bool { TimerFuture::new(Duration::from_millis(ms.into())).await; - true } +#[uniffi::export] +pub async fn sleep_no_return(ms: u16) { + TimerFuture::new(Duration::from_millis(ms.into())).await; +} + // Our error. #[derive(thiserror::Error, uniffi::Error, Debug)] pub enum MyError { @@ -168,10 +160,14 @@ pub async fn fallible_me(do_fail: bool) -> Result { } } -#[uniffi::export(async_runtime = "tokio")] -pub async fn say_after_with_tokio(ms: u16, who: String) -> String { - tokio::time::sleep(Duration::from_millis(ms.into())).await; - format!("Hello, {who} (with Tokio)!") +// An async function returning a struct that can throw. +#[uniffi::export] +pub async fn fallible_struct(do_fail: bool) -> Result, MyError> { + if do_fail { + Err(MyError::Foo) + } else { + Ok(new_megaphone()) + } } #[derive(uniffi::Record)] @@ -185,6 +181,7 @@ pub async fn new_my_record(a: String, b: u32) -> MyRecord { MyRecord { a, b } } +/// Non-blocking timer future used to test callback cancellation. #[uniffi::export] pub async fn broken_sleep(ms: u16, fail_after: u16) { BrokenTimerFuture::new( @@ -194,37 +191,7 @@ pub async fn broken_sleep(ms: u16, fail_after: u16) { .await; } -// UDL-defined async function -pub async fn udl_always_ready() -> bool { - true -} - -// UDL-defined async trait -pub trait SayAfterUdlTrait: Send + Sync { - async fn say_after(&self, ms: u16, who: String) -> String; -} - -// UDL-defined object with async methods -pub struct UdlMegaphone; - -impl UdlMegaphone { - pub async fn new() -> Self { - TimerFuture::new(Duration::from_millis(0)).await; - UdlMegaphone - } - - pub async fn secondary() -> Self { - TimerFuture::new(Duration::from_millis(0)).await; - UdlMegaphone - } - - pub async fn say_after(&self, ms: u16, who: String) -> String { - TimerFuture::new(Duration::from_millis(ms.into())).await; - format!("Hello, {who} (from UDL Megaphone)!").to_uppercase() - } -} - -// Proc-macro-defined object with async methods (Megaphone) +/// Proc-macro-defined object with async methods (Megaphone) #[derive(uniffi::Object)] pub struct Megaphone; @@ -246,19 +213,16 @@ impl Megaphone { /// Async method that yells something after a certain time pub async fn say_after(self: Arc, ms: u16, who: String) -> String { - TimerFuture::new(Duration::from_millis(ms.into())).await; - format!("Hello, {who}!").to_uppercase() + say_after(ms, who).await.to_uppercase() } /// Async method without any extra arguments pub async fn silence(&self) -> String { - TimerFuture::new(Duration::from_millis(100)).await; String::new() } /// Async method that can throw pub async fn fallible_me(self: Arc, do_fail: bool) -> Result { - TimerFuture::new(Duration::from_millis(10)).await; if do_fail { Err(MyError::Foo) } else { @@ -267,7 +231,7 @@ impl Megaphone { } } -// Mixed async/sync methods on the same object (using tokio runtime) +/// Mixed async/sync methods on the same object (using tokio runtime) #[uniffi::export(async_runtime = "tokio")] impl Megaphone { /// Sync method that yells something immediately @@ -277,48 +241,46 @@ impl Megaphone { /// Async method using Tokio's timer pub async fn say_after_with_tokio(self: Arc, ms: u16, who: String) -> String { - tokio::time::sleep(Duration::from_millis(ms.into())).await; - format!("Hello, {who} (with Tokio)!").to_uppercase() + say_after_with_tokio(ms, who).await.to_uppercase() } } -// Additional async functions that work with objects +/// Sync function that generates a new `Megaphone`. #[uniffi::export] pub fn new_megaphone() -> Arc { Arc::new(Megaphone) } +/// Async function that generates a new `Megaphone`. #[uniffi::export] pub async fn async_new_megaphone() -> Arc { - Arc::new(Megaphone) + new_megaphone() } +/// Async function that possibly generates a new `Megaphone`. #[uniffi::export] -pub async fn async_maybe_new_megaphone(should_create: bool) -> Option> { - TimerFuture::new(Duration::from_millis(10)).await; - if should_create { - Some(Arc::new(Megaphone)) +pub async fn async_maybe_new_megaphone(y: bool) -> Option> { + if y { + Some(new_megaphone()) } else { None } } +/// Async function that inputs `Megaphone`. #[uniffi::export] pub async fn say_after_with_megaphone(megaphone: Arc, ms: u16, who: String) -> String { megaphone.say_after(ms, who).await } -#[uniffi::export] -pub async fn fallible_struct(do_fail: bool) -> Result, MyError> { - TimerFuture::new(Duration::from_millis(10)).await; - if do_fail { - Err(MyError::Foo) - } else { - Ok(Arc::new(Megaphone)) - } +/// Async function that uses tokio runtime. +#[uniffi::export(async_runtime = "tokio")] +pub async fn say_after_with_tokio(ms: u16, who: String) -> String { + tokio::time::sleep(Duration::from_millis(ms.into())).await; + format!("Hello, {who} (with Tokio)!") } -// Fallible async constructor object +/// Fallible async constructor object #[derive(uniffi::Object)] pub struct FallibleMegaphone; @@ -326,9 +288,184 @@ pub struct FallibleMegaphone; impl FallibleMegaphone { #[uniffi::constructor] pub async fn new() -> Result, MyError> { - TimerFuture::new(Duration::from_millis(10)).await; - Err(MyError::Foo) // Always fails for testing + Err(MyError::Foo) + } +} + +/// Async runtime example that uses shared state to test timeouts. +#[derive(uniffi::Record)] +pub struct SharedResourceOptions { + pub release_after_ms: u16, + pub timeout_ms: u16, +} + +// Our error for async resource usage. +#[derive(thiserror::Error, uniffi::Error, Debug)] +pub enum AsyncError { + #[error("Timeout")] + Timeout, +} + +#[uniffi::export(async_runtime = "tokio")] +pub async fn use_shared_resource(options: SharedResourceOptions) -> Result<(), AsyncError> { + use tokio::{ + sync::Mutex, + time::{sleep, timeout}, + }; + + static MUTEX: Lazy> = Lazy::new(|| Mutex::new(())); + + let _guard = timeout( + Duration::from_millis(options.timeout_ms.into()), + MUTEX.lock(), + ) + .await + .map_err(|_| AsyncError::Timeout)?; + + sleep(Duration::from_millis(options.release_after_ms.into())).await; + Ok(()) +} + +// Example of a trait with async methods. +#[uniffi::export] +#[async_trait::async_trait] +pub trait SayAfterTrait: Send + Sync { + async fn say_after(&self, ms: u16, who: String) -> String; +} + +// Example of async trait defined in the UDL file. +#[async_trait::async_trait] +pub trait SayAfterUdlTrait: Send + Sync { + async fn say_after(&self, ms: u16, who: String) -> String; +} + +struct SayAfterImpl1; +struct SayAfterImpl2; + +#[async_trait::async_trait] +impl SayAfterTrait for SayAfterImpl1 { + async fn say_after(&self, ms: u16, who: String) -> String { + say_after(ms, who).await + } +} + +#[async_trait::async_trait] +impl SayAfterTrait for SayAfterImpl2 { + async fn say_after(&self, ms: u16, who: String) -> String { + say_after(ms, who).await + } +} + +#[uniffi::export] +pub fn get_say_after_traits() -> Vec> { + vec![Arc::new(SayAfterImpl1), Arc::new(SayAfterImpl2)] +} + +#[async_trait::async_trait] +impl SayAfterUdlTrait for SayAfterImpl1 { + async fn say_after(&self, ms: u16, who: String) -> String { + say_after(ms, who).await + } +} + +#[async_trait::async_trait] +impl SayAfterUdlTrait for SayAfterImpl2 { + async fn say_after(&self, ms: u16, who: String) -> String { + say_after(ms, who).await + } +} + +#[uniffi::export] +pub fn get_say_after_udl_traits() -> Vec> { + vec![Arc::new(SayAfterImpl1), Arc::new(SayAfterImpl2)] +} + +/// UDL-defined object with async methods. +pub struct UdlMegaphone; + +impl UdlMegaphone { + pub async fn new() -> Self { + TimerFuture::new(Duration::from_millis(0)).await; + Self {} } + + pub async fn secondary() -> Self { + TimerFuture::new(Duration::from_millis(0)).await; + Self {} + } + + pub async fn say_after(&self, ms: u16, who: String) -> String { + TimerFuture::new(Duration::from_millis(ms.into())).await; + format!("Hello, {who} (from UDL Megaphone)!").to_uppercase() + } +} + +// Async callback interface implemented in foreign code. +#[uniffi::export(with_foreign)] +#[async_trait::async_trait] +pub trait AsyncParser: Send + Sync { + // Simple async method + async fn as_string(&self, delay_ms: i32, value: i32) -> String; + // Async method that can throw + async fn try_from_string(&self, delay_ms: i32, value: String) -> Result; + // Void return, which requires special handling + async fn delay(&self, delay_ms: i32); + // Void return that can also throw + async fn try_delay(&self, delay_ms: String) -> Result<(), ParserError>; +} + +#[derive(thiserror::Error, uniffi::Error, Debug)] +pub enum ParserError { + #[error("NotAnInt")] + NotAnInt, + #[error("UnexpectedError")] + UnexpectedError, +} + +impl From for ParserError { + fn from(_: uniffi::UnexpectedUniFFICallbackError) -> Self { + Self::UnexpectedError + } +} + +#[uniffi::export] +pub async fn as_string_using_trait(obj: Arc, delay_ms: i32, value: i32) -> String { + obj.as_string(delay_ms, value).await +} + +#[uniffi::export] +pub async fn try_from_string_using_trait( + obj: Arc, + delay_ms: i32, + value: String, +) -> Result { + obj.try_from_string(delay_ms, value).await +} + +#[uniffi::export] +pub async fn delay_using_trait(obj: Arc, delay_ms: i32) { + obj.delay(delay_ms).await +} + +#[uniffi::export] +pub async fn try_delay_using_trait( + obj: Arc, + delay_ms: String, +) -> Result<(), ParserError> { + obj.try_delay(delay_ms).await +} + +#[uniffi::export] +pub async fn cancel_delay_using_trait(obj: Arc, delay_ms: i32) { + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + thread::spawn(move || { + // Simulate a different thread aborting the process + thread::sleep(Duration::from_millis(1)); + abort_handle.abort(); + }); + + let future = Abortable::new(obj.delay(delay_ms), abort_registration); + assert_eq!(future.await, Err(Aborted)); } uniffi::include_scaffolding!("api"); diff --git a/fixtures/dart_async/test/futures_test.dart b/fixtures/dart_async/test/futures_test.dart index 77884a1..60b80af 100644 --- a/fixtures/dart_async/test/futures_test.dart +++ b/fixtures/dart_async/test/futures_test.dart @@ -13,7 +13,7 @@ void main() { ensureInitialized(); test('greet', () async { - final result = await greet("Somebody"); + final result = greet("Somebody"); expect(result, "Hello, Somebody"); }); @@ -28,8 +28,7 @@ void main() { test('void', () async { final time = await measureTime(() async { - await voidFunction(); - //expect(result, null); + await void_(); }); // Less than or equal to time expect(time.inMilliseconds <= 10, true); @@ -78,7 +77,7 @@ void main() { test('fallible_function_and_method', () async { final time1 = await measureTime(() async { try { - fallibleMe(false); + await fallibleMe(false); expect(true, true); } catch (exception) { expect(false, true); // should never be reached @@ -88,7 +87,7 @@ void main() { final time2 = await measureTime(() async { try { - fallibleMe(true); + await fallibleMe(true); expect(false, true); // should never be reached } catch (exception) { expect(true, true); @@ -121,7 +120,7 @@ void main() { test('udl_async_function', () async { final time = await measureTime(() async { - final result = await udlAlwaysReady(); + final result = await alwaysReady(); expect(result, true); }); expect(time.inMilliseconds < 100, true); @@ -129,7 +128,7 @@ void main() { test('proc_macro_megaphone_async_constructor', () async { final time = await measureTime(() async { - final megaphone = await Megaphone(); + final megaphone = await Megaphone.new_(); expect(megaphone, isNotNull); }); expect(time.inMilliseconds < 100, true); @@ -144,7 +143,7 @@ void main() { }); test('proc_macro_megaphone_async_methods', () async { - final megaphone = await Megaphone(); + final megaphone = await Megaphone.new_(); // Test async method with timing final time = await measureTime(() async { @@ -158,14 +157,11 @@ void main() { final result = await megaphone.silence(); expect(result, ''); }); - expect( - silenceTime.inMilliseconds >= 100 && silenceTime.inMilliseconds < 200, - true, - ); + expect(silenceTime.inMilliseconds < 50, true); }); test('proc_macro_megaphone_sync_method', () async { - final megaphone = await Megaphone(); + final megaphone = await Megaphone.new_(); // Test sync method (should be immediate) final time = await measureTime(() async { @@ -176,7 +172,7 @@ void main() { }); test('proc_macro_megaphone_tokio_method', () async { - final megaphone = await Megaphone(); + final megaphone = await Megaphone.new_(); final time = await measureTime(() async { final result = await megaphone.sayAfterWithTokio(100, 'Charlie'); @@ -186,7 +182,7 @@ void main() { }); test('proc_macro_megaphone_fallible_method', () async { - final megaphone = await Megaphone(); + final megaphone = await Megaphone.new_(); // Test success case final result = await megaphone.fallibleMe(false); @@ -204,7 +200,7 @@ void main() { test('udl_megaphone_async_constructors', () async { // Test primary constructor final time1 = await measureTime(() async { - final udlMegaphone = await UdlMegaphone(); + final udlMegaphone = await UdlMegaphone.new_(); expect(udlMegaphone, isNotNull); }); expect(time1.inMilliseconds < 100, true); @@ -218,7 +214,7 @@ void main() { }); test('udl_megaphone_async_method', () async { - final udlMegaphone = await UdlMegaphone(); + final udlMegaphone = await UdlMegaphone.new_(); final time = await measureTime(() async { final result = await udlMegaphone.sayAfter(100, 'Dave'); @@ -245,7 +241,7 @@ void main() { }); test('async_function_with_object_parameter', () async { - final megaphone = await Megaphone(); + final megaphone = await Megaphone.new_(); final time = await measureTime(() async { final result = await sayAfterWithMegaphone(megaphone, 100, 'Eve'); @@ -271,7 +267,7 @@ void main() { test('fallible_async_constructor', () async { // This constructor always fails try { - await FallibleMegaphone(); + await FallibleMegaphone.new_(); expect(false, true); // Should never reach here } catch (e) { expect(true, true); // Expected to throw diff --git a/src/gen/callback_interface.rs b/src/gen/callback_interface.rs index 136f3d5..3f64eac 100644 --- a/src/gen/callback_interface.rs +++ b/src/gen/callback_interface.rs @@ -1,7 +1,12 @@ use crate::gen::CodeType; use genco::prelude::*; +use heck::ToUpperCamelCase; +use std::collections::BTreeSet; use uniffi_bindgen::interface::Type; -use uniffi_bindgen::interface::{AsType, Method}; +use uniffi_bindgen::interface::{ + ffi::{FfiStruct, FfiType}, + AsType, Method, +}; use crate::gen::oracle::{AsCodeType, DartCodeOracle}; use crate::gen::render::AsRenderable; @@ -85,6 +90,42 @@ pub fn generate_callback_interface( let ffi_conv_name = &DartCodeOracle::class_name(ffi_converter_name); let init_fn_name = &format!("init{callback_name}VTable"); + let mut seen_async_structs = BTreeSet::new(); + let mut seen_async_callbacks = BTreeSet::new(); + let mut async_struct_defs: Vec = Vec::new(); + let mut async_completion_typedefs: Vec = Vec::new(); + + for method in methods { + if method.is_async() { + let struct_def = method.foreign_future_ffi_result_struct(); + let struct_name = struct_def.name().to_string(); + + if seen_async_structs.insert(struct_name.clone()) { + async_struct_defs.push(generate_foreign_future_struct_definition( + &struct_def, + type_helper, + )); + } + + let completion_name = foreign_future_completion_name(method); + if seen_async_callbacks.insert(completion_name.clone()) { + async_completion_typedefs.push(generate_foreign_future_completion_typedef( + &completion_name, + &struct_name, + )); + } + } + } + + let async_support = if !async_struct_defs.is_empty() || !async_completion_typedefs.is_empty() { + quote! { + $(for typedef in &async_completion_typedefs => $typedef) + $(for struct_def in &async_struct_defs => $struct_def) + } + } else { + quote!() + }; + let tokens = quote! { // This is the abstract class to be implemented abstract class $cls_name { @@ -132,6 +173,9 @@ pub fn generate_callback_interface( } } + // Additional support definitions for async callbacks + $async_support + // We must define callback signatures $(generate_callback_methods_signatures(cls_name, methods, type_helper)) }; @@ -155,7 +199,14 @@ fn generate_callback_methods_definitions( }) .collect::>(); - let ret_type = if let Some(ret) = method.return_type() { + let ret_type = if method.is_async() { + if let Some(ret) = method.return_type() { + let rendered = ret.as_renderable().render_type(ret, type_helper); + quote!(Future<$rendered>) + } else { + quote!(Future) + } + } else if let Some(ret) = method.return_type() { ret.as_renderable().render_type(ret, type_helper) } else { quote!(void) @@ -176,24 +227,54 @@ fn generate_callback_methods_signatures( //let method_name = DartCodeOracle::fn_name(method.name()); let ffi_method_type = format!("UniffiCallbackInterface{callback_name}Method{method_index}"); - let dart_method_type = format!("UniffiCallbackInterface{callback_name}Method{method_index}Dart"); - let method_return_type = if let Some(ret) = method.return_type() { - DartCodeOracle::native_type_label(Some(ret), type_helper.get_ci()) + let arg_native_types: Vec = method + .arguments() + .iter() + .map(|arg| { + DartCodeOracle::native_type_label(Some(&arg.as_type()), type_helper.get_ci()) + }) + .collect(); + + let arg_dart_types: Vec = method + .arguments() + .iter() + .map(|arg| { + DartCodeOracle::native_dart_type_label(Some(&arg.as_type()), type_helper.get_ci()) + }) + .collect(); + + if method.is_async() { + let completion_base = foreign_future_completion_name(method); + let completion_native = format!("Uniffi{}", completion_base.to_upper_camel_case()); + let completion_pointer = format!("Pointer>", completion_native); + + tokens.append(quote! { + typedef $ffi_method_type = Void Function( + Uint64, $(for arg in &arg_native_types => $arg,) + $(&completion_pointer), Uint64, Pointer); + typedef $dart_method_type = void Function( + int, $(for arg in &arg_dart_types => $arg,) + $(&completion_pointer), int, Pointer); + }); } else { - quote!(Void) - }; - - tokens.append(quote! { - typedef $ffi_method_type = Void Function( - Uint64, $(for arg in &method.arguments() => $(DartCodeOracle::native_type_label(Some(&arg.as_type()), type_helper.get_ci())),) - Pointer<$(&method_return_type)>, Pointer); - typedef $dart_method_type = void Function( - int, $(for arg in &method.arguments() => $(DartCodeOracle::native_dart_type_label(Some(&arg.as_type()), type_helper.get_ci())),) - Pointer<$(&method_return_type)>, Pointer); - }); + let method_return_type = if let Some(ret) = method.return_type() { + DartCodeOracle::native_type_label(Some(ret), type_helper.get_ci()) + } else { + quote!(Void) + }; + + tokens.append(quote! { + typedef $ffi_method_type = Void Function( + Uint64, $(for arg in &arg_native_types => $arg,) + Pointer<$(&method_return_type)>, Pointer); + typedef $dart_method_type = void Function( + int, $(for arg in &arg_dart_types => $arg,) + Pointer<$(&method_return_type)>, Pointer); + }); + } } tokens.append(quote! { @@ -234,51 +315,142 @@ pub fn generate_callback_functions( let _dart_method_type = &format!("UniffiCallbackInterface{callback_name}Method{index}Dart"); // Get parameter types using the oracle - let param_types: Vec = m.arguments().iter().map(|arg| { - let arg_name = DartCodeOracle::var_name(arg.name()); - DartCodeOracle::callback_param_type(&arg.as_type(), &arg_name, type_helper.get_ci()) - }).collect(); + let param_types: Vec = m + .arguments() + .iter() + .map(|arg| { + let arg_name = DartCodeOracle::var_name(arg.name()); + DartCodeOracle::callback_param_type(&arg.as_type(), &arg_name, type_helper.get_ci()) + }) + .collect(); // Get argument lifts using the oracle - let arg_lifts: Vec = m.arguments().iter().enumerate().map(|(arg_idx, arg)| { - let arg_name = DartCodeOracle::var_name(arg.name()); - DartCodeOracle::callback_arg_lift_indexed(&arg.as_type(), &arg_name, arg_idx) - }).collect(); + let arg_lifts: Vec = m + .arguments() + .iter() + .enumerate() + .map(|(arg_idx, arg)| { + let arg_name = DartCodeOracle::var_name(arg.name()); + DartCodeOracle::callback_arg_lift_indexed(&arg.as_type(), &arg_name, arg_idx) + }) + .collect(); // Prepare arg names for the method call using indexes - let arg_names: Vec = m.arguments().iter().enumerate().map(|(arg_idx, arg)| { - DartCodeOracle::callback_arg_name(&arg.as_type(), arg_idx) - }).collect(); - - // Handle return value using the oracle - let call_dart_method = if let Some(ret) = m.return_type() { - DartCodeOracle::callback_return_handling(ret, method_name, arg_names) - } else { - // Handle void return types - DartCodeOracle::callback_void_handling(method_name, arg_names) - }; - - // Get the appropriate out return type - let out_return_type = DartCodeOracle::callback_out_return_type(m.return_type()); + let arg_names: Vec = m + .arguments() + .iter() + .enumerate() + .map(|(arg_idx, arg)| DartCodeOracle::callback_arg_name(&arg.as_type(), arg_idx)) + .collect(); // Generate the function body - let callback_method_name = &format!("{}{}", &DartCodeOracle::fn_name(callback_name), &DartCodeOracle::class_name(m.name())); + let callback_method_name = + &format!("{}{}", &DartCodeOracle::fn_name(callback_name), &DartCodeOracle::class_name(m.name())); + + if m.is_async() { + let completion_base = foreign_future_completion_name(m); + let completion_native = format!("Uniffi{}", completion_base.to_upper_camel_case()); + let completion_pointer = format!("Pointer>", completion_native); + let completion_dart = format!("{completion_native}Dart"); + let result_struct = m.foreign_future_ffi_result_struct(); + let struct_tokens = DartCodeOracle::ffi_struct_name(result_struct.name()); + let struct_tokens_alt = struct_tokens.clone(); + + if let Some(ret) = m.return_type() { + type_helper.include_once_check(&ret.as_codetype().canonical_name(), ret); + } - quote! { - void $callback_method_name(int uniffiHandle, $(for param in ¶m_types => $param,) $out_return_type outReturn, Pointer callStatus) { - final status = callStatus.ref; - try { + let success_return = if let Some(ret) = m.return_type() { + let converter = ret.as_codetype().ffi_converter_name(); + quote!(resultStructPtr.ref.returnValue = $(&converter).lower(result);) + } else { + quote!() + }; + + quote! { + void $callback_method_name( + int uniffiHandle, + $(for param in ¶m_types => $param,) + $(&completion_pointer) uniffiFutureCallback, + int uniffiCallbackData, + Pointer outReturn, + ) { final obj = FfiConverterCallbackInterface$cls_name._handleMap.get(uniffiHandle); $(arg_lifts) - $call_dart_method - } catch (e) { - status.code = CALL_UNEXPECTED_ERROR; - status.errorBuf = FfiConverterString.lower(e.toString()); + final callback = uniffiFutureCallback.asFunction<$(&completion_dart)>(); + final state = _UniffiForeignFutureState(); + final handle = _uniffiForeignFutureHandleMap.insert(state); + outReturn.ref.handle = handle; + outReturn.ref.free = _uniffiForeignFutureFreePointer; + + () async { + try { + final result = await obj.$method_name($(for arg in &arg_names => $arg,)); + final removedState = _uniffiForeignFutureHandleMap.maybeRemove(handle); + final effectiveState = removedState ?? state; + if (effectiveState.cancelled) { + return; + } + effectiveState.cancelled = true; + final resultStructPtr = calloc<$struct_tokens>(); + try { + $success_return + resultStructPtr.ref.callStatus.code = CALL_SUCCESS; + callback(uniffiCallbackData, resultStructPtr.ref); + } finally { + calloc.free(resultStructPtr); + } + } catch (e) { + final removedState = _uniffiForeignFutureHandleMap.maybeRemove(handle); + final effectiveState = removedState ?? state; + if (effectiveState.cancelled) { + return; + } + effectiveState.cancelled = true; + final resultStructPtr = calloc<$struct_tokens_alt>(); + try { + resultStructPtr.ref.callStatus.code = CALL_UNEXPECTED_ERROR; + resultStructPtr.ref.callStatus.errorBuf = + FfiConverterString.lower(e.toString()); + callback(uniffiCallbackData, resultStructPtr.ref); + } finally { + calloc.free(resultStructPtr); + } + } + }(); } + + final Pointer> $(callback_method_name)Pointer = + Pointer.fromFunction<$ffi_method_type>($callback_method_name); } + } else { + // Handle return value using the oracle + let call_dart_method = if let Some(ret) = m.return_type() { + DartCodeOracle::callback_return_handling(ret, method_name, arg_names) + } else { + // Handle void return types + DartCodeOracle::callback_void_handling(method_name, arg_names) + }; + + // Get the appropriate out return type + let out_return_type = DartCodeOracle::callback_out_return_type(m.return_type()); + + quote! { + void $callback_method_name(int uniffiHandle, $(for param in ¶m_types => $param,) $out_return_type outReturn, Pointer callStatus) { + final status = callStatus.ref; + try { + final obj = FfiConverterCallbackInterface$cls_name._handleMap.get(uniffiHandle); + $(arg_lifts) + $call_dart_method + } catch (e) { + status.code = CALL_UNEXPECTED_ERROR; + status.errorBuf = FfiConverterString.lower(e.toString()); + } + } - final Pointer> $(callback_method_name)Pointer = - Pointer.fromFunction<$ffi_method_type>($callback_method_name); + final Pointer> $(callback_method_name)Pointer = + Pointer.fromFunction<$ffi_method_type>($callback_method_name); + } } }).collect(); @@ -303,6 +475,80 @@ pub fn generate_callback_functions( } } +fn generate_foreign_future_struct_definition( + ffi_struct: &FfiStruct, + type_helper: &dyn TypeHelperRenderer, +) -> dart::Tokens { + let struct_name = DartCodeOracle::ffi_struct_name(ffi_struct.name()); + let fields: Vec = ffi_struct + .fields() + .iter() + .map(|field| { + let field_name = DartCodeOracle::var_name(field.name()); + let ffi_field_type = field.type_(); + let field_type = match &ffi_field_type { + FfiType::RustCallStatus => quote!(RustCallStatus), + _ => { + DartCodeOracle::ffi_dart_type_label(Some(&ffi_field_type), type_helper.get_ci()) + } + }; + if let Some(annotation) = foreign_future_field_annotation(&ffi_field_type) { + quote! { + $annotation + external $field_type $field_name; + } + } else { + quote! { + external $field_type $field_name; + } + } + }) + .collect(); + + quote! { + final class $struct_name extends Struct { + $(for field in fields => $field) + } + } +} + +fn generate_foreign_future_completion_typedef( + callback_base: &str, + struct_name: &str, +) -> dart::Tokens { + let native_callback_name = format!("Uniffi{}", callback_base.to_upper_camel_case()); + let dart_callback_name = format!("{native_callback_name}Dart"); + let struct_tokens = DartCodeOracle::ffi_struct_name(struct_name); + let struct_tokens_alt = struct_tokens.clone(); + + let mut tokens = dart::Tokens::new(); + tokens.append(quote!(typedef $native_callback_name = Void Function(Uint64, $struct_tokens);)); + tokens.append(quote!(typedef $dart_callback_name = void Function(int, $struct_tokens_alt);)); + tokens +} + +fn foreign_future_completion_name(method: &Method) -> String { + let return_ffi_type = method.return_type().cloned().map(FfiType::from); + let suffix = FfiType::return_type_name(return_ffi_type.as_ref()).to_upper_camel_case(); + format!("ForeignFutureComplete{suffix}") +} + +fn foreign_future_field_annotation(field_type: &FfiType) -> Option { + match field_type { + FfiType::Int8 => Some(quote!(@Int8())), + FfiType::UInt8 => Some(quote!(@Uint8())), + FfiType::Int16 => Some(quote!(@Int16())), + FfiType::UInt16 => Some(quote!(@Uint16())), + FfiType::Int32 => Some(quote!(@Int32())), + FfiType::UInt32 => Some(quote!(@Uint32())), + FfiType::Int64 => Some(quote!(@Int64())), + FfiType::UInt64 => Some(quote!(@Uint64())), + FfiType::Float32 => Some(quote!(@Float())), + FfiType::Float64 => Some(quote!(@Double())), + _ => None, + } +} + pub fn generate_callback_interface_vtable_init_function( callback_name: &str, methods: &[&Method], diff --git a/src/gen/objects.rs b/src/gen/objects.rs index 27a66f5..1d187c7 100644 --- a/src/gen/objects.rs +++ b/src/gen/objects.rs @@ -114,7 +114,10 @@ pub fn generate_object(obj: &Object, type_helper: &dyn TypeHelperRenderer) -> da quote!() }; - let constructor_definitions = obj.constructors().into_iter().map(|constructor| { + let mut constructor_definitions: Vec = Vec::new(); + let mut async_constructor_factories: Vec = Vec::new(); + + for constructor in obj.constructors() { let ffi_func_name = constructor.ffi_func().name(); let constructor_name = constructor.name(); @@ -126,7 +129,8 @@ pub fn generate_object(obj: &Object, type_helper: &dyn TypeHelperRenderer) -> da // Check if function can throw errors let error_handler = if let Some(error_type) = constructor.throws_type() { - let error_name = DartCodeOracle::class_name(error_type.name().unwrap_or("UnknownError")); + let error_name = + DartCodeOracle::class_name(error_type.name().unwrap_or("UnknownError")); // Use the consistent Exception naming for error handlers let handler_name = format!("{}ErrorHandler", error_name.to_lower_camel_case()); quote!($(handler_name)) @@ -147,18 +151,35 @@ pub fn generate_object(obj: &Object, type_helper: &dyn TypeHelperRenderer) -> da type_helper.include_once_check(&arg.as_codetype().canonical_name(), &arg.as_type()); } - quote! { - // Public constructor - $dart_constructor_decl($dart_params) : _ptr = rustCall((status) => - $lib_instance.$ffi_func_name( - $ffi_call_args status - ), - $error_handler - ) { - _$finalizer_cls_name.attach(this, _ptr, detach: this); - } + if constructor.is_async() { + async_constructor_factories.push(quote! { + static Future<$cls_name> $(DartCodeOracle::fn_name(constructor_name))($dart_params) { + return uniffiRustCallAsync( + () => $lib_instance.$ffi_func_name( + $ffi_call_args + ), + $(DartCodeOracle::async_poll(constructor, type_helper.get_ci())), + $(DartCodeOracle::async_complete(constructor, type_helper.get_ci())), + $(DartCodeOracle::async_free(constructor, type_helper.get_ci())), + (Pointer ptr) => $cls_name._(ptr), + $error_handler, + ); + } + }); + } else { + constructor_definitions.push(quote! { + // Public constructor + $dart_constructor_decl($dart_params) : _ptr = rustCall((status) => + $lib_instance.$ffi_func_name( + $ffi_call_args status + ), + $error_handler + ) { + _$finalizer_cls_name.attach(this, _ptr, detach: this); + } + }); } - }); + } // For interface objects that are used as error types, generate error handlers let is_error_interface = type_helper.get_ci().is_name_used_as_error(obj.name()); @@ -242,6 +263,7 @@ pub fn generate_object(obj: &Object, type_helper: &dyn TypeHelperRenderer) -> da // Public constructors generated from UDL $( for ctor_def in constructor_definitions => $ctor_def ) + $( for factory_def in async_constructor_factories => $factory_def ) // Factory for lifting pointers factory $cls_name.lift(Pointer ptr) { diff --git a/src/gen/oracle.rs b/src/gen/oracle.rs index 5baed9b..403ef16 100644 --- a/src/gen/oracle.rs +++ b/src/gen/oracle.rs @@ -70,7 +70,7 @@ impl DartCodeOracle { } /// Get the idiomatic Dart rendering of an FFI callback function name - fn ffi_callback_name(nm: &str) -> String { + pub fn ffi_callback_name(nm: &str) -> String { format!( "Pointer>", Self::callback_name(&nm.to_upper_camel_case()) diff --git a/src/gen/types.rs b/src/gen/types.rs index e64ac71..6d729c0 100644 --- a/src/gen/types.rs +++ b/src/gen/types.rs @@ -361,6 +361,8 @@ impl Renderer<(FunctionDefinition, dart::Tokens)> for TypeHelpersRenderer<'_> { typedef UniffiRustFutureContinuationCallback = Void Function(Uint64, Int8); + final _uniffiRustFutureContinuationHandles = UniffiHandleMap>(); + Future uniffiRustCallAsync( Pointer Function() rustFutureFunc, void Function(Pointer, Pointer>, Pointer) pollFunc, @@ -371,45 +373,94 @@ impl Renderer<(FunctionDefinition, dart::Tokens)> for TypeHelpersRenderer<'_> { ]) async { final rustFuture = rustFutureFunc(); final completer = Completer(); + final handle = _uniffiRustFutureContinuationHandles.insert(completer); + final callbackData = Pointer.fromAddress(handle); late final NativeCallable callback; - void poll() { + void repoll() { pollFunc( rustFuture, callback.nativeFunction, - Pointer.fromAddress(0), + callbackData, ); } - void onResponse(int _idx, int pollResult) { + + void onResponse(int data, int pollResult) { if (pollResult == UNIFFI_RUST_FUTURE_POLL_READY) { - completer.complete(pollResult); + final readyCompleter = + _uniffiRustFutureContinuationHandles.maybeRemove(data); + if (readyCompleter != null && !readyCompleter.isCompleted) { + readyCompleter.complete(pollResult); + } + } else if (pollResult == UNIFFI_RUST_FUTURE_POLL_MAYBE_READY) { + repoll(); } else { - poll(); + final errorCompleter = + _uniffiRustFutureContinuationHandles.maybeRemove(data); + if (errorCompleter != null && !errorCompleter.isCompleted) { + errorCompleter.completeError( + UniffiInternalError.panicked( + "Unexpected poll result from Rust future: $pollResult", + ), + ); + } } } - callback = NativeCallable.listener(onResponse); + + callback = NativeCallable.listener( + onResponse, + ); try { - poll(); + repoll(); await completer.future; - callback.close(); - final status = calloc(); try { - final result = completeFunc(rustFuture, status); - + checkCallStatus( + errorHandler ?? NullRustCallStatusErrorHandler(), + status, + ); return liftFunc(result); } finally { calloc.free(status); } } finally { + callback.close(); + _uniffiRustFutureContinuationHandles.maybeRemove(handle); freeFunc(rustFuture); } } + typedef UniffiForeignFutureFree = Void Function(Uint64); + typedef UniffiForeignFutureFreeDart = void Function(int); + + class _UniffiForeignFutureState { + bool cancelled = false; + } + + final _uniffiForeignFutureHandleMap = UniffiHandleMap<_UniffiForeignFutureState>(); + + void _uniffiForeignFutureFree(int handle) { + final state = _uniffiForeignFutureHandleMap.maybeRemove(handle); + if (state != null) { + state.cancelled = true; + } + } + + final Pointer> + _uniffiForeignFutureFreePointer = + Pointer.fromFunction(_uniffiForeignFutureFree); + + final class UniffiForeignFuture extends Struct { + @Uint64() + external int handle; + + external Pointer> free; + } + class UniffiHandleMap { final Map _map = {}; int _counter = 0; @@ -430,11 +481,15 @@ impl Renderer<(FunctionDefinition, dart::Tokens)> for TypeHelpersRenderer<'_> { } void remove(int handle) { - if (_map.remove(handle) == null) { + if (maybeRemove(handle) == null) { throw UniffiInternalError( UniffiInternalError.unexpectedStaleHandle, "Handle not found"); } } + + T? maybeRemove(int handle) { + return _map.remove(handle); + } } };