diff --git a/Cargo.lock b/Cargo.lock index 318db752e7..c654f11b8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5845,6 +5845,7 @@ dependencies = [ "dioxus-history", "dioxus-hooks", "dioxus-signals", + "dioxus-stores", "futures-channel", "futures-util", "generational-box", @@ -5888,6 +5889,7 @@ dependencies = [ "dioxus", "dioxus-core", "dioxus-signals", + "dioxus-stores", "futures-channel", "futures-util", "generational-box", diff --git a/examples/01-app-demos/hotdog/src/frontend.rs b/examples/01-app-demos/hotdog/src/frontend.rs index 0e3e0d2cb9..63e31600fe 100644 --- a/examples/01-app-demos/hotdog/src/frontend.rs +++ b/examples/01-app-demos/hotdog/src/frontend.rs @@ -7,7 +7,7 @@ use crate::{ #[component] pub fn Favorites() -> Element { - let mut favorites = use_loader(list_dogs)?; + let favorites = use_loader(list_dogs)?; rsx! { div { id: "favorites", @@ -41,7 +41,7 @@ pub fn NavBar() -> Element { #[component] pub fn DogView() -> Element { - let mut img_src = use_loader(|| async move { + let img_src = use_loader(|| async move { anyhow::Ok( reqwest::get("https://dog.ceo/api/breeds/image/random") .await? diff --git a/examples/01-app-demos/weather_app.rs b/examples/01-app-demos/weather_app.rs index 2d0069bbc9..63a3102853 100644 --- a/examples/01-app-demos/weather_app.rs +++ b/examples/01-app-demos/weather_app.rs @@ -1,6 +1,6 @@ #![allow(non_snake_case)] -use dioxus::{fullstack::Loading, prelude::*}; +use dioxus::prelude::*; use serde::{Deserialize, Serialize}; use std::fmt::Display; @@ -36,17 +36,17 @@ fn app() -> Element { } Forecast { weather: weather.cloned() } div { height: "20px", margin_top: "10px", - if weather.loading() { + if weather.pending() { "Fetching weather data..." } } }, - Err(Loading::Pending(_)) => rsx! { + Err(RenderError::Error(_)) => rsx! { + div { "Failed to load weather data." } + }, + Err(RenderError::Suspended(_)) => rsx! { div { "Loading weather data..." } }, - Err(Loading::Failed(_)) => rsx! { - div { "Failed to load weather data." } - } } } } @@ -162,16 +162,16 @@ fn SearchBox(mut country: WriteSignal) -> Element { } } }, - Err(Loading::Pending(_)) => rsx! { + Err(RenderError::Error(error)) => rsx! { li { class: "pl-8 pr-2 py-1 border-b-2 border-gray-100 relative", - "Searching..." + "Failed to search: {error:?}" } }, - Err(Loading::Failed(handle)) => rsx! { + Err(RenderError::Suspended(_)) => rsx! { li { class: "pl-8 pr-2 py-1 border-b-2 border-gray-100 relative", - "Failed to search: {handle.error():?}" + "Searching..." } - } + }, } } } diff --git a/examples/05-using-async/suspense.rs b/examples/05-using-async/suspense.rs index cf8811bbee..876681e6a8 100644 --- a/examples/05-using-async/suspense.rs +++ b/examples/05-using-async/suspense.rs @@ -36,7 +36,7 @@ fn Doggo() -> Element { // // During SSR, `use_loader` will serialize the contents of the fetch, and during hydration, the client will // use the pre-fetched data instead of re-fetching to render. - let mut dog = use_loader(move || async move { + let dog = use_loader(move || async move { #[derive(serde::Deserialize, serde::Serialize, PartialEq)] struct DogApi { message: String, diff --git a/examples/08-apis/eval.rs b/examples/08-apis/eval.rs index d766123f0c..9106d3931c 100644 --- a/examples/08-apis/eval.rs +++ b/examples/08-apis/eval.rs @@ -40,7 +40,7 @@ fn app() -> Element { res }); - match future.value().as_ref() { + match future.as_ref() { Some(v) => rsx!( p { "{v}" } ), _ => rsx!( p { "waiting.." } ), } diff --git a/packages/fullstack-core/Cargo.toml b/packages/fullstack-core/Cargo.toml index 9fc6fae631..d708ea0a6c 100644 --- a/packages/fullstack-core/Cargo.toml +++ b/packages/fullstack-core/Cargo.toml @@ -29,6 +29,7 @@ inventory = { workspace = true } serde_json = { workspace = true } generational-box = { workspace = true } futures-util = { workspace = true, features = ["std"] } +dioxus-stores.workspace = true [features] web = [] diff --git a/packages/fullstack-core/src/loader.rs b/packages/fullstack-core/src/loader.rs index dccb808786..8a3711c381 100644 --- a/packages/fullstack-core/src/loader.rs +++ b/packages/fullstack-core/src/loader.rs @@ -1,15 +1,12 @@ -use dioxus_core::{use_hook, IntoAttributeValue, IntoDynNode, Subscribers}; -use dioxus_core::{CapturedError, RenderError, Result, SuspendedFuture}; -use dioxus_hooks::{use_resource, use_signal, Resource}; -use dioxus_signals::{ - read_impls, ReadSignal, Readable, ReadableBoxExt, ReadableExt, ReadableRef, Signal, Writable, - WritableExt, WriteLock, -}; -use generational_box::{BorrowResult, UnsyncStorage}; +use dioxus_core::{CapturedError, RenderError, Result}; +use dioxus_hooks::Resource; +use dioxus_signals::{MappedMutSignal, WriteSignal}; +use dioxus_stores::MappedStore; use serde::{de::DeserializeOwned, Serialize}; -use std::ops::Deref; use std::{cmp::PartialEq, future::Future}; +use crate::use_server_future; + /// A hook to create a resource that loads data asynchronously. /// /// This hook takes a closure that returns a future. This future will be executed on both the client @@ -29,352 +26,29 @@ use std::{cmp::PartialEq, future::Future}; /// as the component that called `use_loader`. #[allow(clippy::result_large_err)] #[track_caller] -pub fn use_loader(mut future: impl FnMut() -> F + 'static) -> Result, Loading> +pub fn use_loader( + mut future: impl FnMut() -> F + 'static, +) -> Result< + Resource< + MappedStore< + T, + MappedMutSignal< + Result, + WriteSignal>>, + >, + >, + >, + RenderError, +> where F: Future> + 'static, T: 'static + PartialEq + Serialize + DeserializeOwned, E: Into + 'static, { - let serialize_context = use_hook(crate::transport::serialize_context); - - // We always create a storage entry, even if the data isn't ready yet to make it possible to deserialize pending server futures on the client - #[allow(unused)] - let storage_entry: crate::transport::SerializeContextEntry> = - use_hook(|| serialize_context.create_entry()); - - #[cfg(feature = "server")] - let caller = std::panic::Location::caller(); - - // If this is the first run and we are on the web client, the data might be cached - #[cfg(feature = "web")] - let initial_web_result = - use_hook(|| std::rc::Rc::new(std::cell::RefCell::new(Some(storage_entry.get())))); - - let mut error = use_signal(|| None as Option); - let mut value = use_signal(|| None as Option); - let mut loader_state = use_signal(|| LoaderState::Pending); - - let resource = use_resource(move || { - #[cfg(feature = "server")] - let storage_entry = storage_entry.clone(); - - let user_fut = future(); - - #[cfg(feature = "web")] - let initial_web_result = initial_web_result.clone(); - - #[allow(clippy::let_and_return)] - async move { - // If this is the first run and we are on the web client, the data might be cached - #[cfg(feature = "web")] - match initial_web_result.take() { - // The data was deserialized successfully from the server - Some(Ok(o)) => { - match o { - Ok(v) => { - value.set(Some(v)); - loader_state.set(LoaderState::Ready); - } - Err(e) => { - error.set(Some(e)); - loader_state.set(LoaderState::Failed); - } - }; - return; - } - - // The data is still pending from the server. Don't try to resolve it on the client - Some(Err(crate::transport::TakeDataError::DataPending)) => { - std::future::pending::<()>().await - } - - // The data was not available on the server, rerun the future - Some(Err(_)) => {} - - // This isn't the first run, so we don't need do anything - None => {} - } - - // Otherwise just run the future itself - let out = user_fut.await; - - // Remap the error to the captured error type so it's cheap to clone and pass out, just - // slightly more cumbersome to access the inner error. - let out = out.map_err(|e| { - let anyhow_err: anyhow::Error = e.into(); - anyhow_err.into() - }); - - // If this is the first run and we are on the server, cache the data in the slot we reserved for it - #[cfg(feature = "server")] - storage_entry.insert(&out, caller); - - match out { - Ok(v) => { - value.set(Some(v)); - loader_state.set(LoaderState::Ready); - } - Err(e) => { - error.set(Some(e)); - loader_state.set(LoaderState::Failed); - } - }; - } - }); - - // On the first run, force this task to be polled right away in case its value is ready - use_hook(|| { - let _ = resource.task().poll_now(); - }); - - let read_value = use_hook(|| value.map(|f| f.as_ref().unwrap()).boxed()); - - let handle = LoaderHandle { - resource, - error, - state: loader_state, - _marker: std::marker::PhantomData, - }; - - match &*loader_state.read_unchecked() { - LoaderState::Pending => Err(Loading::Pending(handle)), - LoaderState::Failed => Err(Loading::Failed(handle)), - LoaderState::Ready => Ok(Loader { - real_value: value, - read_value, - error, - state: loader_state, - handle, - }), - } -} - -/// A Loader is a signal that represents a value that is loaded asynchronously. -/// -/// Once a `Loader` has been successfully created from `use_loader`, it can be use like a normal signal of type `T`. -/// -/// When the loader is re-reloading its values, it will no longer suspend its component, making it -/// very useful for server-side-rendering. -pub struct Loader { - /// This is a signal that unwraps the inner value. We can't give it out unless we know the inner value is Some(T)! - read_value: ReadSignal, - - /// This is the actual signal. We let the user set this value if they want to, but we can't let them set it to `None`. - real_value: Signal>, - error: Signal>, - state: Signal, - handle: LoaderHandle, -} - -impl Loader { - /// Get the error that occurred during loading, if any. - /// - /// After initial load, this will return `None` until the next reload fails. - pub fn error(&self) -> Option { - self.error.read().as_ref().cloned() - } - - /// Restart the loading task. - /// - /// After initial load, this won't suspend the component, but will reload in the background. - pub fn restart(&mut self) { - self.handle.restart(); - } - - /// Check if the loader has failed. - pub fn is_error(&self) -> bool { - self.error.read().is_some() && matches!(*self.state.read(), LoaderState::Failed) - } - - /// Cancel the current loading task. - pub fn cancel(&mut self) { - self.handle.resource.cancel(); - } - - pub fn loading(&self) -> bool { - !self.handle.resource.finished() - } -} - -impl Writable for Loader { - type WriteMetadata = > as Writable>::WriteMetadata; - - fn try_write_unchecked( - &self, - ) -> std::result::Result< - dioxus_signals::WritableRef<'static, Self>, - generational_box::BorrowMutError, - > - where - Self::Target: 'static, - { - let writer = self.real_value.try_write_unchecked()?; - Ok(WriteLock::map(writer, |f: &mut Option| { - f.as_mut() - .expect("Loader value should be set if the `Loader` exists") - })) - } -} - -impl Readable for Loader { - type Target = T; - type Storage = UnsyncStorage; - - #[track_caller] - fn try_read_unchecked( - &self, - ) -> Result, generational_box::BorrowError> - where - T: 'static, - { - Ok(self.read_value.read_unchecked()) - } - - /// Get the current value of the signal. **Unlike read, this will not subscribe the current scope to the signal which can cause parts of your UI to not update.** - /// - /// If the signal has been dropped, this will panic. - #[track_caller] - fn try_peek_unchecked(&self) -> BorrowResult> - where - T: 'static, - { - Ok(self.peek_unchecked()) - } - - fn subscribers(&self) -> Subscribers - where - T: 'static, - { - self.read_value.subscribers() - } -} - -impl IntoAttributeValue for Loader -where - T: Clone + IntoAttributeValue + PartialEq + 'static, -{ - fn into_value(self) -> dioxus_core::AttributeValue { - self.with(|f| f.clone().into_value()) - } -} - -impl IntoDynNode for Loader -where - T: Clone + IntoDynNode + PartialEq + 'static, -{ - fn into_dyn_node(self) -> dioxus_core::DynamicNode { - self().into_dyn_node() - } -} - -impl PartialEq for Loader { - fn eq(&self, other: &Self) -> bool { - self.read_value == other.read_value - } -} - -impl Deref for Loader -where - T: PartialEq + 'static, -{ - type Target = dyn Fn() -> T; - - fn deref(&self) -> &Self::Target { - unsafe { ReadableExt::deref_impl(self) } - } -} - -read_impls!(Loader where T: PartialEq); - -impl Clone for Loader { - fn clone(&self) -> Self { - *self - } -} - -impl Copy for Loader {} - -#[derive(Clone, Copy, PartialEq, Hash, Eq, Debug)] -pub enum LoaderState { - /// The loader's future is still running - Pending, - - /// The loader's future has completed successfully - Ready, - - /// The loader's future has failed and now the loader is in an error state. - Failed, -} - -#[derive(PartialEq)] -pub struct LoaderHandle { - resource: Resource<()>, - error: Signal>, - state: Signal, - _marker: std::marker::PhantomData, -} - -impl LoaderHandle { - /// Restart the loading task. - pub fn restart(&mut self) { - self.resource.restart(); - } - - /// Get the current state of the loader. - pub fn state(&self) -> LoaderState { - *self.state.read() - } - - pub fn error(&self) -> Option { - self.error.read().as_ref().cloned() - } -} - -impl Clone for LoaderHandle { - fn clone(&self) -> Self { - *self - } -} - -impl Copy for LoaderHandle {} - -#[derive(PartialEq)] -pub enum Loading { - /// The loader is still pending and the component should suspend. - Pending(LoaderHandle), - - /// The loader has failed and an error will be returned up the tree. - Failed(LoaderHandle), -} - -impl std::fmt::Debug for Loading { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Loading::Pending(_) => write!(f, "Loading::Pending"), - Loading::Failed(_) => write!(f, "Loading::Failed"), - } - } -} - -impl std::fmt::Display for Loading { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Loading::Pending(_) => write!(f, "Loading is still pending"), - Loading::Failed(_) => write!(f, "Loading has failed"), - } - } -} - -/// Convert a Loading into a RenderError for use with the `?` operator in components -impl From for RenderError { - fn from(val: Loading) -> Self { - match val { - Loading::Pending(t) => RenderError::Suspended(SuspendedFuture::new(t.resource.task())), - Loading::Failed(err) => RenderError::Error( - err.error - .cloned() - .expect("LoaderHandle in Failed state should always have an error"), - ), - } - } + let resolved = use_server_future(move || { + let fut = future(); + async move { fut.await.map_err(|e| CapturedError::from(e.into())) } + })?; + let ok = resolved.transpose().map_err(|e| RenderError::Error(e()))?; + Ok(ok) } diff --git a/packages/fullstack-core/src/server_future.rs b/packages/fullstack-core/src/server_future.rs index 9cbf14b380..69c5dc369d 100644 --- a/packages/fullstack-core/src/server_future.rs +++ b/packages/fullstack-core/src/server_future.rs @@ -1,7 +1,6 @@ use crate::Transportable; -use dioxus_core::{suspend, use_hook, RenderError}; +use dioxus_core::{use_hook, RenderError}; use dioxus_hooks::*; -use dioxus_signals::ReadableExt; use std::future::Future; /// Runs a future with a manual list of dependencies and returns a resource with the result if the future is finished or a suspended error if it is still running. @@ -61,7 +60,7 @@ use std::future::Future; #[track_caller] pub fn use_server_future( mut future: impl FnMut() -> F + 'static, -) -> Result, RenderError> +) -> Result, RenderError> where F: Future + 'static, T: Transportable, @@ -128,12 +127,5 @@ where }); // Suspend if the value isn't ready - if resource.state().cloned() == UseResourceState::Pending { - let task = resource.task(); - if !task.paused() { - return Err(suspend(task).unwrap_err()); - } - } - - Ok(resource) + resource.suspend() } diff --git a/packages/fullstack/src/payloads/websocket.rs b/packages/fullstack/src/payloads/websocket.rs index b0928b5be1..0c39693ee7 100644 --- a/packages/fullstack/src/payloads/websocket.rs +++ b/packages/fullstack/src/payloads/websocket.rs @@ -30,7 +30,7 @@ use axum_core::response::{IntoResponse, Response}; use bytes::Bytes; use dioxus_core::{use_hook, CapturedError, Result}; use dioxus_fullstack_core::{HttpError, RequestError}; -use dioxus_hooks::{use_resource, Resource, UseWaker}; +use dioxus_hooks::{use_resource, PendingResource, Resource, UseWaker}; use dioxus_hooks::{use_signal, use_waker}; use dioxus_signals::{ReadSignal, ReadableExt, ReadableOptionExt, Signal, WritableExt}; use futures::StreamExt; @@ -109,7 +109,7 @@ where Out: 'static, Enc: 'static, { - connection: Resource, CapturedError>>, + connection: PendingResource, CapturedError>>, waker: UseWaker<()>, status: Signal, status_read: ReadSignal, @@ -120,7 +120,7 @@ impl UseWebsocket { /// `.try_recv()` will not fail due to the connection not being ready. pub async fn connect(&self) -> WebsocketState { // Wait for the connection to be established - while !self.connection.finished() { + while self.connection.running() { _ = self.waker.wait().await; } diff --git a/packages/hooks/Cargo.toml b/packages/hooks/Cargo.toml index 747e4ad366..f0b3e261fc 100644 --- a/packages/hooks/Cargo.toml +++ b/packages/hooks/Cargo.toml @@ -23,6 +23,7 @@ futures-util = { workspace = true, features = ["std"] } generational-box = { workspace = true } rustversion = { workspace = true } warnings = { workspace = true } +dioxus-stores.workspace = true [dev-dependencies] futures-util = { workspace = true, default-features = false } diff --git a/packages/hooks/src/use_future.rs b/packages/hooks/src/use_future.rs index 74431a2bb7..27235e204f 100644 --- a/packages/hooks/src/use_future.rs +++ b/packages/hooks/src/use_future.rs @@ -1,4 +1,3 @@ -#![allow(missing_docs)] use crate::{use_callback, use_hook_did_run, use_signal}; use dioxus_core::{use_hook, Callback, Subscribers, Task}; use dioxus_signals::*; diff --git a/packages/hooks/src/use_resource.rs b/packages/hooks/src/use_resource.rs index b9f7abf9bb..62c8e14fe6 100644 --- a/packages/hooks/src/use_resource.rs +++ b/packages/hooks/src/use_resource.rs @@ -1,100 +1,71 @@ -#![allow(missing_docs)] - -use crate::{use_callback, use_signal, use_waker, UseWaker}; - -use dioxus_core::{ - spawn, use_hook, Callback, IntoAttributeValue, IntoDynNode, ReactiveContext, RenderError, - Subscribers, SuspendedFuture, Task, +use std::{ + fmt::{Debug, Display}, + future::Future, + ops::Deref, }; -use dioxus_signals::*; -use futures_util::{ - future::{self}, - pin_mut, FutureExt, StreamExt, + +use dioxus_core::{spawn, use_hook, ReactiveContext, RenderError, SuspendedFuture, Task}; +use dioxus_signals::{ + BorrowError, CopyValue, Global, InitializeFromFunction, MappedMutSignal, Readable, ReadableExt, + ReadableRef, WritableExt, WriteSignal, }; -use std::{cell::Cell, future::Future, rc::Rc}; -use std::{fmt::Debug, ops::Deref}; - -#[doc = include_str!("../docs/use_resource.md")] -#[doc = include_str!("../docs/rules_of_hooks.md")] -#[doc = include_str!("../docs/moving_state_around.md")] -#[doc(alias = "use_async_memo")] -#[doc(alias = "use_memo_async")] +use dioxus_stores::{MappedStore, Store}; +use futures_util::{pin_mut, FutureExt, StreamExt}; + #[track_caller] -pub fn use_resource(mut future: impl FnMut() -> F + 'static) -> Resource +pub fn use_resource(future: impl FnMut() -> F + 'static) -> PendingResource where T: 'static, F: Future + 'static, { let location = std::panic::Location::caller(); + use_hook(|| Resource::new_with_location(future, location)) +} - let mut value = use_signal(|| None); - let mut state = use_signal(|| UseResourceState::Pending); - let (rc, changed) = use_hook(|| { - let (rc, changed) = ReactiveContext::new_with_origin(location); - (rc, Rc::new(Cell::new(Some(changed)))) - }); - - let mut waker = use_waker::<()>(); - - let cb = use_callback(move |_| { - // Set the state to Pending when the task is restarted - state.set(UseResourceState::Pending); - - // Create the user's task - let fut = rc.reset_and_run_in(&mut future); - - // Spawn a wrapper task that polls the inner future and watches its dependencies - spawn(async move { - // Move the future here and pin it so we can poll it - let fut = fut; - pin_mut!(fut); - - // Run each poll in the context of the reactive scope - // This ensures the scope is properly subscribed to the future's dependencies - let res = future::poll_fn(|cx| { - rc.run_in(|| { - tracing::trace_span!("polling resource", location = %location) - .in_scope(|| fut.poll_unpin(cx)) - }) +fn run_future_in_context( + rc: &ReactiveContext, + mut future: impl FnMut() -> F, + location: &'static std::panic::Location<'static>, +) -> Task +where + T: 'static, + F: Future + 'static, +{ + let rc = rc.clone(); + // Create the user's task + let fut = rc.reset_and_run_in(&mut future); + + // Spawn a wrapper task that polls the inner future and watches its dependencies + spawn(async move { + // Move the future here and pin it so we can poll it + let fut = fut; + pin_mut!(fut); + + // Run each poll in the context of the reactive scope + // This ensures the scope is properly subscribed to the future's dependencies + std::future::poll_fn(|cx| { + rc.run_in(|| { + tracing::trace_span!("polling resource", location = %location) + .in_scope(|| fut.poll_unpin(cx)) }) - .await; - - // Set the value and state - state.set(UseResourceState::Ready); - value.set(Some(res)); - - // Notify that the value has changed - waker.wake(()); }) - }); - - let mut task = use_hook(|| Signal::new(cb(()))); - - use_hook(|| { - let mut changed = changed.take().unwrap(); - spawn(async move { - loop { - // Wait for the dependencies to change - let _ = changed.next().await; - - // Stop the old task - task.write().cancel(); + .await; + }) +} - // Start a new task - task.set(cb(())); - } - }) - }); - - Resource { - task, - value, - state, - waker, - callback: cb, - } +struct ResourceHandle { + task: Task, + rc: ReactiveContext, + wakers: Vec, } +pub type PendingResource = Resource>>; +pub type ResolvedResource>> = Resource>; +pub type OkResource>>> = + Resource, Lens>>>; +pub type ErrResource>>> = + Resource, Lens>>>; + /// A handle to a reactive future spawned with [`use_resource`] that can be used to modify or read the result of the future. /// /// ## Example @@ -120,49 +91,33 @@ where /// } /// } /// ``` -#[derive(Debug)] -pub struct Resource { - waker: UseWaker<()>, - value: Signal>, - task: Signal, - state: Signal, - callback: Callback<(), Task>, -} - -impl PartialEq for Resource { - fn eq(&self, other: &Self) -> bool { - self.value == other.value - && self.state == other.state - && self.task == other.task - && self.callback == other.callback - } +pub struct Resource { + state: S, + handle: CopyValue, } -impl Clone for Resource { +impl Clone for Resource +where + S: Clone, +{ fn clone(&self) -> Self { - *self + Resource { + state: self.state.clone(), + handle: self.handle, + } } } -impl Copy for Resource {} - -/// A signal that represents the state of the resource -// we might add more states (panicked, etc) -#[derive(Clone, Copy, PartialEq, Hash, Eq, Debug)] -pub enum UseResourceState { - /// The resource's future is still running - Pending, - /// The resource's future has been forcefully stopped - Stopped, +impl Copy for Resource where S: Copy {} - /// The resource's future has been paused, tempoarily - Paused, - - /// The resource's future has completed - Ready, -} +impl Resource { + fn replace_state(self) -> impl FnOnce(S2) -> Resource { + move |new_state| Resource { + state: new_state, + handle: self.handle, + } + } -impl Resource { /// Restart the resource's future. /// /// This will cancel the current future and start a new one. @@ -188,10 +143,8 @@ impl Resource { /// } /// } /// ``` - pub fn restart(&mut self) { - self.task.write().cancel(); - let new_task = self.callback.call(()); - self.task.set(new_task); + pub fn restart(&self) { + self.handle.read().rc.mark_dirty(); } /// Forcefully cancel the resource's future. @@ -216,9 +169,8 @@ impl Resource { /// } /// } /// ``` - pub fn cancel(&mut self) { - self.state.set(UseResourceState::Stopped); - self.task.write().cancel(); + pub fn cancel(&self) { + self.handle.read().task.cancel(); } /// Pause the resource's future. @@ -249,9 +201,8 @@ impl Resource { /// } /// } /// ``` - pub fn pause(&mut self) { - self.state.set(UseResourceState::Paused); - self.task.write().pause(); + pub fn pause(&self) { + self.handle.read().task.pause(); } /// Resume the resource's future. @@ -282,51 +233,14 @@ impl Resource { /// } /// } /// ``` - pub fn resume(&mut self) { - if self.finished() { - return; - } - - self.state.set(UseResourceState::Pending); - self.task.write().resume(); - } - - /// Clear the resource's value. This will just reset the value. It will not modify any running tasks. - /// - /// ## Example - /// ```rust, no_run - /// # use dioxus::prelude::*; - /// # use std::time::Duration; - /// fn App() -> Element { - /// let mut revision = use_signal(|| "1d03b42"); - /// let mut resource = use_resource(move || async move { - /// // This will run every time the revision signal changes because we read the count inside the future - /// reqwest::get(format!("https://github.com/DioxusLabs/awesome-dioxus/blob/{revision}/awesome.json")).await - /// }); - /// - /// rsx! { - /// button { - /// // We clear the value without modifying any running tasks with the `clear` method - /// onclick: move |_| resource.clear(), - /// "Clear" - /// } - /// "{resource:?}" - /// } - /// } - /// ``` - pub fn clear(&mut self) { - self.value.write().take(); + pub fn resume(&self) { + self.handle.read().task.resume(); } /// Get a handle to the inner task backing this resource /// Modify the task through this handle will cause inconsistent state pub fn task(&self) -> Task { - self.task.cloned() - } - - /// Is the resource's future currently running? - pub fn pending(&self) -> bool { - matches!(*self.state.peek(), UseResourceState::Pending) + self.handle.read().task } /// Is the resource's future currently finished running? @@ -357,13 +271,12 @@ impl Resource { /// } /// ``` pub fn finished(&self) -> bool { - matches!( - *self.state.peek(), - UseResourceState::Ready | UseResourceState::Stopped - ) + !self.handle.read().task.paused() } - /// Get the current state of the resource's future. This method returns a [`ReadSignal`] which can be read to get the current state of the resource or passed to other hooks and components. + /// Is the resource's future currently running? + /// + /// Reading this does not subscribe to the future's state /// /// ## Example /// ```rust, no_run @@ -376,25 +289,20 @@ impl Resource { /// reqwest::get(format!("https://github.com/DioxusLabs/awesome-dioxus/blob/{revision}/awesome.json")).await /// }); /// - /// // We can read the current state of the future with the `state` method - /// match resource.state().cloned() { - /// UseResourceState::Pending => rsx! { - /// "The resource is still pending" - /// }, - /// UseResourceState::Paused => rsx! { - /// "The resource has been paused" - /// }, - /// UseResourceState::Stopped => rsx! { - /// "The resource has been stopped" - /// }, - /// UseResourceState::Ready => rsx! { - /// "The resource is ready!" - /// }, + /// // We can use the `finished` method to check if the future is finished + /// if resource.pending() { + /// rsx! { + /// "The resource is still running" + /// } + /// } else { + /// rsx! { + /// "The resource is finished" + /// } /// } /// } /// ``` - pub fn state(&self) -> ReadSignal { - self.state.into() + pub fn pending(&self) -> bool { + self.handle.read().task.paused() } /// Get the current value of the resource's future. This method returns a [`ReadSignal`] which can be read to get the current value of the resource or passed to other hooks and components. @@ -424,136 +332,293 @@ impl Resource { /// } /// } /// ``` - pub fn value(&self) -> ReadSignal> { - self.value.into() + pub fn value(&self) -> S + where + S: Clone, + { + self.state.clone() } +} - /// Suspend the resource's future and only continue rendering when the future is ready - pub fn suspend(&self) -> std::result::Result>>, RenderError> { - match self.state.cloned() { - UseResourceState::Stopped | UseResourceState::Paused | UseResourceState::Pending => { - let task = self.task(); - if task.paused() { - Ok(self.value.map(|v| v.as_ref().unwrap())) - } else { - Err(RenderError::Suspended(SuspendedFuture::new(task))) - } +impl PendingResource { + #[track_caller] + pub fn new(future: impl FnMut() -> F + 'static) -> Self + where + F: Future + 'static, + { + let location = std::panic::Location::caller(); + Self::new_with_location(future, location) + } + + pub fn new_with_location( + mut future: impl FnMut() -> F + 'static, + location: &'static std::panic::Location<'static>, + ) -> Self + where + F: Future + 'static, + { + let mut state = Store::new(None); + let mut future = move || { + let fut = future(); + async move { + let result = fut.await; + state.set(Some(result)); } - _ => Ok(self.value.map(|v| v.as_ref().unwrap())), - } + }; + let (rc, mut changed) = ReactiveContext::new(); + + // Start the initial task + let mut task = run_future_in_context(&rc, &mut future, location); + let handle = ResourceHandle { + task: task.clone(), + wakers: Vec::new(), + rc: rc.clone(), + }; + let mut handle = CopyValue::new(handle); + + // Spawn a task to watch for changes + spawn(async move { + loop { + // Wait for the dependencies to change + let _ = changed.next().await; + + // Stop the old task + task.cancel(); + + // Start a new task + task = run_future_in_context( + &rc, + &mut || { + let future = future(); + async move { + let result = future.await; + let wakers = std::mem::take(&mut handle.write().wakers); + for waker in wakers { + waker.wake(); + } + result + } + }, + location, + ); + let mut handle = handle.write(); + handle.task = task.clone(); + } + }); + Resource { state, handle } + } + + /// Clear the resource's value. This will just reset the value. It will not modify any running tasks. + /// + /// ## Example + /// ```rust, no_run + /// # use dioxus::prelude::*; + /// # use std::time::Duration; + /// fn App() -> Element { + /// let mut revision = use_signal(|| "1d03b42"); + /// let mut resource = use_resource(move || async move { + /// // This will run every time the revision signal changes because we read the count inside the future + /// reqwest::get(format!("https://github.com/DioxusLabs/awesome-dioxus/blob/{revision}/awesome.json")).await + /// }); + /// + /// rsx! { + /// button { + /// // We clear the value without modifying any running tasks with the `clear` method + /// onclick: move |_| resource.clear(), + /// "Clear" + /// } + /// "{resource:?}" + /// } + /// } + /// ``` + pub fn clear(&mut self) { + self.state.set(None); } } -impl Resource> { - /// Convert the `Resource>` into an `Option, MappedSignal>>` - #[allow(clippy::type_complexity)] - pub fn result( - &self, - ) -> Option< - Result< - MappedSignal>>>, - MappedSignal>>>, - >, - > { - let value: MappedSignal>>> = self.value.map(|v| match v { - Some(Ok(ref res)) => res, - _ => panic!("Resource is not ready"), - }); +impl + 'static> Deref for Resource { + type Target = dyn Fn() -> S::Target; - let error: MappedSignal>>> = self.value.map(|v| match v { - Some(Err(ref err)) => err, - _ => panic!("Resource is not ready"), - }); + fn deref(&self) -> &Self::Target { + unsafe { ReadableExt::deref_impl(&self.state) } + } +} - match &*self.value.peek() { - Some(Ok(_)) => Some(Ok(value)), - Some(Err(_)) => Some(Err(error)), - None => None, - } +impl Display for Resource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.state.fmt(f) } } -impl From> for ReadSignal> { - fn from(val: Resource) -> Self { - val.value.into() +impl Debug for Resource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.state.fmt(f) } } -impl Readable for Resource { - type Target = Option; - type Storage = UnsyncStorage; +impl Readable for Resource { + type Target = S::Target; + type Storage = S::Storage; - #[track_caller] - fn try_read_unchecked( - &self, - ) -> Result, generational_box::BorrowError> { - self.value.try_read_unchecked() + fn try_read_unchecked(&self) -> std::result::Result, BorrowError> + where + Self::Target: 'static, + { + self.state.try_read_unchecked() } - #[track_caller] - fn try_peek_unchecked( - &self, - ) -> Result, generational_box::BorrowError> { - self.value.try_peek_unchecked() + fn try_peek_unchecked(&self) -> std::result::Result, BorrowError> + where + Self::Target: 'static, + { + self.state.try_peek_unchecked() } - fn subscribers(&self) -> Subscribers { - self.value.subscribers() + fn subscribers(&self) -> dioxus_core::Subscribers + where + Self::Target: 'static, + { + self.state.subscribers() } } -impl Writable for Resource { - type WriteMetadata = > as Writable>::WriteMetadata; +impl Resource>, Lens>> +where + Lens: Readable>> + Copy + 'static, +{ + /// Convert the `Resource>` into an `Option, ErrResource>>` + pub fn result(&self) -> Option, ErrResource>> { + self.transpose() + .map(|store_transposed| store_transposed.transpose()) + } - fn try_write_unchecked( - &self, - ) -> Result, generational_box::BorrowMutError> + pub fn ready(&self) -> Result, RenderError> where - Self::Target: 'static, + E: Clone + Into, + Lens: 'static, + { + self.suspend()? + .transpose() + .map_err(|err_store| err_store().into()) + } + + pub fn ok(&self) -> Result>, RenderError> + where + E: Clone + Into, + Lens: 'static, { - self.value.try_write_unchecked() + match self.result() { + None => Ok(None), + Some(Ok(ok_store)) => Ok(Some(ok_store)), + Some(Err(err_store)) => Err(err_store().into()), + } } } -impl IntoAttributeValue for Resource +impl Resource, Lens>> where - T: Clone + IntoAttributeValue, + Lens: Readable> + Copy + 'static, { - fn into_value(self) -> dioxus_core::AttributeValue { - self.with(|f| f.clone().into_value()) + /// Is the resource's value ready? + pub fn resolved(&self) -> bool { + self.state.is_some() + } + + /// Is the resource's value currently running? + pub fn running(&self) -> bool { + self.state.is_none() + } + + /// Suspend the resource's future and only continue rendering when the future is ready + pub fn suspend(&self) -> Result, RenderError> { + self.transpose() + .ok_or_else(|| RenderError::Suspended(SuspendedFuture::new(self.handle.read().task))) + } + + pub fn transpose(&self) -> Option> { + self.state.transpose().map(self.replace_state()) } } -impl IntoDynNode for Resource +impl Resource, Lens>> where - T: Clone + IntoDynNode, + Lens: Readable> + Copy + 'static, { - fn into_dyn_node(self) -> dioxus_core::DynamicNode { - self().into_dyn_node() + pub fn transpose( + &self, + ) -> Result>, Resource>> { + self.state + .transpose() + .map(self.replace_state()) + .map_err(self.replace_state()) } } -/// Allow calling a signal with signal() syntax -/// -/// Currently only limited to copy types, though could probably specialize for string/arc/rc -impl Deref for Resource { - type Target = dyn Fn() -> Option; +impl std::future::IntoFuture for Resource { + type Output = (); + type IntoFuture = ResourceFuture; - fn deref(&self) -> &Self::Target { - unsafe { ReadableExt::deref_impl(self) } + fn into_future(self) -> Self::IntoFuture { + ResourceFuture { + resource: self.handle, + } } } -impl std::future::Future for Resource { +/// A future that is awaiting the resolution of a resource +pub struct ResourceFuture { + resource: CopyValue, +} + +impl std::future::Future for ResourceFuture { type Output = (); fn poll( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll { - match self.waker.clone().poll_unpin(cx) { - std::task::Poll::Ready(_) => std::task::Poll::Ready(()), - std::task::Poll::Pending => std::task::Poll::Pending, + let myself = self.get_mut(); + let mut handle = myself.resource.write(); + if !handle.task.paused() { + std::task::Poll::Ready(()) + } else { + handle.wakers.push(cx.waker().clone()); + std::task::Poll::Pending } } } + +/// A type alias for global stores +/// +/// # Example +/// ```rust, no_run +/// use dioxus::prelude::*; +/// +/// static DOGS: GlobalResource, _> = Global::new(|| async { +/// Ok(reqwest::get("https://dog.ceo/api/breeds/list/all") +/// .await? +/// .text() +/// .await?) +/// }); +/// +/// fn app() -> Element { +/// let dogs = DOGS.resolve(); +/// match dogs.result() { +/// None => rsx! { "Loading..." }, +/// Some(Ok(dogs)) => rsx! { "Dogs: {dogs}" }, +/// Some(Err(err)) => rsx! { "Error: {err}" }, +/// } +/// } +/// ``` +pub type GlobalResource = Global, F>; + +impl InitializeFromFunction for PendingResource +where + F: Future + 'static, + T: 'static, +{ + #[track_caller] + fn initialize_from_function(f: fn() -> F) -> Self { + Resource::new(f) + } +} diff --git a/packages/playwright-tests/default-features-disabled/src/main.rs b/packages/playwright-tests/default-features-disabled/src/main.rs index ded9a7e7a9..67f10dafcf 100644 --- a/packages/playwright-tests/default-features-disabled/src/main.rs +++ b/packages/playwright-tests/default-features-disabled/src/main.rs @@ -9,7 +9,7 @@ fn main() { } fn app() -> Element { - let server_features = use_server_future(get_server_features)?.unwrap().unwrap(); + let server_features = use_server_future(get_server_features)?.unwrap(); let mut client_features = use_signal(Vec::new); use_effect(move || { diff --git a/packages/playwright-tests/fullstack-errors/src/main.rs b/packages/playwright-tests/fullstack-errors/src/main.rs index b016f3acbc..f13c215ccf 100644 --- a/packages/playwright-tests/fullstack-errors/src/main.rs +++ b/packages/playwright-tests/fullstack-errors/src/main.rs @@ -60,7 +60,7 @@ pub fn ErrorFallbackButton() -> Element { #[component] pub fn ThrowsError() -> Element { - use_server_future(server_error)?.unwrap()?; + use_server_future(server_error)?()?; rsx! { "success" } diff --git a/packages/playwright-tests/fullstack-hydration-order/src/main.rs b/packages/playwright-tests/fullstack-hydration-order/src/main.rs index 93fbc59de2..00c11abde7 100644 --- a/packages/playwright-tests/fullstack-hydration-order/src/main.rs +++ b/packages/playwright-tests/fullstack-hydration-order/src/main.rs @@ -28,7 +28,7 @@ pub fn Home() -> Element { fn MyStrings() -> Element { let strings = use_server_future(get_strings)?; let data = match &*strings.read() { - Some(Ok(data)) => data.clone(), + Ok(data) => data.clone(), _ => vec![], }; @@ -50,7 +50,7 @@ pub async fn get_strings() -> Result, ServerFnError> { fn MyFloats() -> Element { let floats = use_server_future(get_floats)?; let data = match &*floats.read() { - Some(Ok(data)) => data.clone(), + Ok(data) => data.clone(), _ => vec![], }; diff --git a/packages/playwright-tests/fullstack-routing/src/main.rs b/packages/playwright-tests/fullstack-routing/src/main.rs index 1ea1b2b631..a78880bd6d 100644 --- a/packages/playwright-tests/fullstack-routing/src/main.rs +++ b/packages/playwright-tests/fullstack-routing/src/main.rs @@ -61,7 +61,7 @@ fn ThrowsAsyncError() -> Element { Err(ServerFnError::new("Async error from a server function")) } - use_server_future(error_after_delay)?().unwrap()?; + use_server_future(error_after_delay)?()?; rsx! { "Hello, world!" } diff --git a/packages/playwright-tests/fullstack/src/main.rs b/packages/playwright-tests/fullstack/src/main.rs index b36fc8e028..d66067bab7 100644 --- a/packages/playwright-tests/fullstack/src/main.rs +++ b/packages/playwright-tests/fullstack/src/main.rs @@ -66,7 +66,7 @@ fn OnMounted() -> Element { #[component] fn DefaultServerFnCodec() -> Element { let resource = use_server_future(|| get_server_data_empty_vec(Vec::new()))?; - let empty_vec = resource.unwrap().unwrap(); + let empty_vec = resource.unwrap(); assert!(empty_vec.is_empty()); rsx! {} @@ -141,7 +141,7 @@ fn Errors() -> Element { #[component] pub fn ThrowsError() -> Element { - use_server_future(server_error)?.unwrap()?; + use_server_future(server_error)?()?; rsx! { "success" } diff --git a/packages/playwright-tests/nested-suspense/src/lib.rs b/packages/playwright-tests/nested-suspense/src/lib.rs index 4d7d351f31..520207d4e9 100644 --- a/packages/playwright-tests/nested-suspense/src/lib.rs +++ b/packages/playwright-tests/nested-suspense/src/lib.rs @@ -40,9 +40,7 @@ fn MessageWithLoader(id: usize) -> Element { #[component] fn LoadTitle() -> Element { - let title = use_server_future(move || server_content(0))?() - .unwrap() - .unwrap(); + let title = use_server_future(move || server_content(0))?().unwrap(); rsx! { "title loaded" @@ -52,9 +50,7 @@ fn LoadTitle() -> Element { #[component] fn Message(id: usize) -> Element { - let message = use_server_future(move || server_content(id))?() - .unwrap() - .unwrap(); + let message = use_server_future(move || server_content(id))?().unwrap(); rsx! { h2 { diff --git a/packages/playwright-tests/suspense-carousel/src/main.rs b/packages/playwright-tests/suspense-carousel/src/main.rs index 4fe52be795..5a4f931a54 100644 --- a/packages/playwright-tests/suspense-carousel/src/main.rs +++ b/packages/playwright-tests/suspense-carousel/src/main.rs @@ -62,8 +62,7 @@ fn SuspendedComponent(id: i32) -> Element { let resolved_on = use_server_future(move || async move { async_std::task::sleep(std::time::Duration::from_secs(1)).await; ResolvedOn::CURRENT - })?() - .unwrap(); + })?(); let mut count = use_signal(|| 0); @@ -93,8 +92,8 @@ fn NestedSuspendedComponent(id: i32) -> Element { let resolved_on = use_server_future(move || async move { async_std::task::sleep(std::time::Duration::from_secs(1)).await; ResolvedOn::CURRENT - })?() - .unwrap(); + })?(); + let mut count = use_signal(|| 0); rsx! { div { diff --git a/packages/stores-macro/src/derive.rs b/packages/stores-macro/src/derive.rs index d30c757e37..aa79b21747 100644 --- a/packages/stores-macro/src/derive.rs +++ b/packages/stores-macro/src/derive.rs @@ -60,7 +60,7 @@ fn derive_store_struct( let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); - let (extension_impl_generics, extension_generics, extension_where_clause) = + let (extension_impl_generics, extension_ty_generics, extension_where_clause) = extension_generics.split_for_impl(); // We collect the definitions and implementations for the extension trait methods along with the types of the fields in the transposed struct @@ -85,7 +85,7 @@ fn derive_store_struct( let definition = quote! { fn transpose( self, - ) -> #transposed_name #extension_generics where Self: ::std::marker::Copy; + ) -> #transposed_name #extension_ty_generics where Self: ::std::marker::Copy; }; definitions.push(definition); let field_names = fields @@ -108,7 +108,7 @@ fn derive_store_struct( let implementation = quote! { fn transpose( self, - ) -> #transposed_name #extension_generics where Self: ::std::marker::Copy { + ) -> #transposed_name #extension_ty_generics where Self: ::std::marker::Copy { // Convert each field into the corresponding store #( let #field_names = self.#field_names(); @@ -119,17 +119,15 @@ fn derive_store_struct( implementations.push(implementation); // Generate the transposed struct definition - let transposed_struct = match &structure.fields { - Fields::Named(_) => { - quote! { #visibility struct #transposed_name #extension_impl_generics #extension_where_clause {#(#transposed_fields),*} } - } - Fields::Unnamed(_) => { - quote! { #visibility struct #transposed_name #extension_impl_generics (#(#transposed_fields),*) #extension_where_clause; } - } - Fields::Unit => { - quote! {#visibility struct #transposed_name #extension_impl_generics #extension_where_clause;} - } - }; + let transposed_struct = transposed_struct( + visibility, + struct_name, + &transposed_name, + structure, + generics, + &extension_generics, + &transposed_fields, + ); // Expand to the extension trait and its implementation for the store alongside the transposed struct Ok(quote! { @@ -139,7 +137,7 @@ fn derive_store_struct( )* } - impl #extension_impl_generics #extension_trait_name #extension_generics for dioxus_stores::Store<#struct_name #ty_generics, __Lens> #extension_where_clause { + impl #extension_impl_generics #extension_trait_name #extension_ty_generics for dioxus_stores::Store<#struct_name #ty_generics, __Lens> #extension_where_clause { #( #implementations )* @@ -149,6 +147,72 @@ fn derive_store_struct( }) } +fn field_type_generic(field: &Field, generics: &syn::Generics) -> bool { + generics.type_params().any(|param| { + matches!(&field.ty, syn::Type::Path(type_path) if type_path.path.is_ident(¶m.ident)) + }) +} + +fn transposed_struct( + visibility: &syn::Visibility, + struct_name: &Ident, + transposed_name: &Ident, + structure: &DataStruct, + generics: &syn::Generics, + extension_generics: &syn::Generics, + transposed_fields: &[TokenStream2], +) -> TokenStream2 { + let (extension_impl_generics, _, extension_where_clause) = extension_generics.split_for_impl(); + // Only use a type alias if: + // - There are no bounds on the type generics + // - All fields are generic types + let use_type_alias = generics.type_params().all(|param| param.bounds.is_empty()) + && structure + .fields + .iter() + .all(|field| field_type_generic(field, generics)); + if use_type_alias { + let generics = transpose_generics(struct_name, generics); + return quote! {#visibility type #transposed_name #extension_impl_generics = #struct_name #generics;}; + } + match &structure.fields { + Fields::Named(fields) => { + let fields = fields.named.iter(); + let fields = fields.zip(transposed_fields.iter()).map(|(f, t)| { + let vis = &f.vis; + let ident = &f.ident; + let colon = f.colon_token.as_ref(); + quote! { #vis #ident #colon #t } + }); + quote! { + #visibility struct #transposed_name #extension_impl_generics #extension_where_clause { + #( + #fields + ),* + } + } + } + Fields::Unnamed(fields) => { + let fields = fields.unnamed.iter(); + let fields = fields.zip(transposed_fields.iter()).map(|(f, t)| { + let vis = &f.vis; + quote! { #vis #t } + }); + quote! { + #visibility struct #transposed_name #extension_impl_generics ( + #( + #fields + ),* + ) + #extension_where_clause; + } + } + Fields::Unit => { + quote! {#visibility struct #transposed_name #extension_impl_generics #extension_where_clause} + } + } +} + fn generate_field_methods( field_index: usize, field: &syn::Field, @@ -158,9 +222,7 @@ fn generate_field_methods( definitions: &mut Vec, implementations: &mut Vec, ) { - let vis = &field.vis; let field_name = &field.ident; - let colon = field.colon_token.as_ref(); // When we map the field, we need to use either the field name for named fields or the index for unnamed fields. let field_accessor = field_name.as_ref().map_or_else( @@ -171,7 +233,7 @@ fn generate_field_methods( let field_type = &field.ty; let store_type = mapped_type(struct_name, ty_generics, field_type); - transposed_fields.push(quote! { #vis #field_name #colon #store_type }); + transposed_fields.push(store_type.clone()); // Each field gets its own reactive scope within the child based on the field's index let ordinal = LitInt::new(&field_index.to_string(), field.span()); @@ -218,7 +280,7 @@ fn derive_store_enum( let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); - let (extension_impl_generics, extension_generics, extension_where_clause) = + let (extension_impl_generics, extension_ty_generics, extension_where_clause) = extension_generics.split_for_impl(); // We collect the definitions and implementations for the extension trait methods along with the types of the fields in the transposed enum @@ -249,14 +311,11 @@ fn derive_store_enum( let mut transposed_field_selectors = Vec::new(); let fields = &variant.fields; for (i, field) in fields.iter().enumerate() { - let vis = &field.vis; - let field_name = &field.ident; - let colon = field.colon_token.as_ref(); let field_type = &field.ty; let store_type = mapped_type(enum_name, &ty_generics, field_type); // Push the field for the transposed enum - transposed_fields.push(quote! { #vis #field_name #colon #store_type }); + transposed_fields.push(store_type.clone()); // Generate the code to get Store from the enum let select_field = select_enum_variant_field( @@ -321,11 +380,31 @@ fn derive_store_enum( // Push the type definition of the variant to the transposed enum let transposed_variant = match &fields { - Fields::Named(_) => { - quote! { #variant_name {#(#transposed_fields),*} } + Fields::Named(named) => { + let fields = named.named.iter(); + let fields = fields.zip(transposed_fields.iter()).map(|(f, t)| { + let vis = &f.vis; + let ident = &f.ident; + let colon = f.colon_token.as_ref(); + quote! { #vis #ident #colon #t } + }); + quote! { #variant_name { + #( + #fields + ),* + } } } - Fields::Unnamed(_) => { - quote! { #variant_name (#(#transposed_fields),*) } + Fields::Unnamed(unnamed) => { + let fields = unnamed.unnamed.iter(); + let fields = fields.zip(transposed_fields.iter()).map(|(f, t)| { + let vis = &f.vis; + quote! { #vis #t } + }); + quote! { #variant_name ( + #( + #fields + ),* + ) } } Fields::Unit => { quote! { #variant_name } @@ -337,13 +416,13 @@ fn derive_store_enum( let definition = quote! { fn transpose( self, - ) -> #transposed_name #extension_generics where #readable_bounds, Self: ::std::marker::Copy; + ) -> #transposed_name #extension_ty_generics where #readable_bounds, Self: ::std::marker::Copy; }; definitions.push(definition); let implementation = quote! { fn transpose( self, - ) -> #transposed_name #extension_generics where #readable_bounds, Self: ::std::marker::Copy { + ) -> #transposed_name #extension_ty_generics where #readable_bounds, Self: ::std::marker::Copy { // We only do a shallow read of the store to get the current variant. We only need to rerun // this match when the variant changes, not when the fields change self.selector().track_shallow(); @@ -358,7 +437,23 @@ fn derive_store_enum( }; implementations.push(implementation); - let transposed_enum = quote! { #visibility enum #transposed_name #extension_impl_generics #extension_where_clause {#(#transposed_variants),*} }; + // Only use a type alias if: + // - There are no bounds on the type generics + // - All fields are generic types + let use_type_alias = generics.type_params().all(|param| param.bounds.is_empty()) + && structure + .variants + .iter() + .flat_map(|variant| variant.fields.iter()) + .all(|field| field_type_generic(field, generics)); + + let transposed_enum = if use_type_alias { + let generics = transpose_generics(enum_name, generics); + + quote! {#visibility type #transposed_name #extension_generics = #enum_name #generics;} + } else { + quote! { #visibility enum #transposed_name #extension_impl_generics #extension_where_clause {#(#transposed_variants),*} } + }; // Expand to the extension trait and its implementation for the store alongside the transposed enum Ok(quote! { @@ -368,7 +463,7 @@ fn derive_store_enum( )* } - impl #extension_impl_generics #extension_trait_name #extension_generics for dioxus_stores::Store<#enum_name #ty_generics, __Lens> #extension_where_clause { + impl #extension_impl_generics #extension_trait_name #extension_ty_generics for dioxus_stores::Store<#enum_name #ty_generics, __Lens> #extension_where_clause { #( #implementations )* @@ -491,3 +586,31 @@ fn mapped_type( let write_type = quote! { dioxus_stores::macro_helpers::dioxus_signals::MappedMutSignal<#field_type, __Lens, fn(&#item #ty_generics) -> &#field_type, fn(&mut #item #ty_generics) -> &mut #field_type> }; quote! { dioxus_stores::Store<#field_type, #write_type> } } + +/// Take the generics from the original type with only generic fields into the generics for the transposed type +fn transpose_generics(name: &Ident, generics: &syn::Generics) -> TokenStream2 { + let (_, ty_generics, _) = generics.split_for_impl(); + let mut transposed_generics = generics.clone(); + let mut generics = Vec::new(); + for gen in transposed_generics.params.iter_mut() { + match gen { + // Map type generics into Store> + syn::GenericParam::Type(type_param) => { + let ident = &type_param.ident; + let ty = mapped_type(name, &ty_generics, &parse_quote!(#ident)); + generics.push(ty); + } + // Forward const and lifetime generics as-is + syn::GenericParam::Const(const_param) => { + let ident = &const_param.ident; + generics.push(quote! { #ident }); + } + syn::GenericParam::Lifetime(lt_param) => { + let ident = <_param.lifetime; + generics.push(quote! { #ident }); + } + } + } + + quote!(<#(#generics),*> ) +} diff --git a/packages/stores/src/impls/result.rs b/packages/stores/src/impls/result.rs index 312a590e65..dbf016c924 100644 --- a/packages/stores/src/impls/result.rs +++ b/packages/stores/src/impls/result.rs @@ -118,10 +118,7 @@ where /// None => panic!("Expected Err"), /// } /// ``` - pub fn err(self) -> Option> - where - Lens: Writable> + 'static, - { + pub fn err(self) -> Option> { self.is_err().then(|| { let map: fn(&Result) -> &E = |value| match value { Ok(_) => panic!("Tried to access `err` on an Ok value"), @@ -149,10 +146,7 @@ where /// } /// ``` #[allow(clippy::result_large_err)] - pub fn transpose(self) -> Result, MappedStore> - where - Lens: Writable> + 'static, - { + pub fn transpose(self) -> Result, MappedStore> { if self.is_ok() { let map: fn(&Result) -> &T = |value| match value { Ok(t) => t, @@ -188,7 +182,6 @@ where /// ``` pub fn unwrap(self) -> MappedStore where - Lens: Writable> + 'static, E: Debug, { self.transpose().unwrap() diff --git a/packages/stores/src/store.rs b/packages/stores/src/store.rs index cfc420ed54..cc75dc1c17 100644 --- a/packages/stores/src/store.rs +++ b/packages/stores/src/store.rs @@ -14,7 +14,7 @@ use dioxus_signals::{ use std::marker::PhantomData; /// A type alias for a store that has been mapped with a function -pub(crate) type MappedStore< +pub type MappedStore< T, Lens, F = fn(&::Target) -> &T, diff --git a/packages/stores/tests/marco.rs b/packages/stores/tests/marco.rs index e972c43bd6..d177b0023e 100644 --- a/packages/stores/tests/marco.rs +++ b/packages/stores/tests/marco.rs @@ -138,6 +138,20 @@ mod macro_tests { store.check(); } + fn derive_generic_struct_transposed_passthrough() { + #[derive(Store)] + struct Item { + contents: T, + } + + let mut store = use_store(|| Item::<0, _> { + contents: "Learn about stores".to_string(), + }); + + let Item { contents } = store.transpose(); + let contents: String = contents(); + } + fn derive_tuple() { #[derive(Store, PartialEq, Clone, Debug)] struct Item(bool, String); @@ -342,4 +356,28 @@ mod macro_tests { } } } + + fn derive_generic_enum_transpose_passthrough() { + #[derive(Store, PartialEq, Clone, Debug)] + #[non_exhaustive] + enum Enum { + Foo, + Bar(T), + BarFoo { foo: T }, + } + + let mut store = use_store(|| Enum::<0, _>::Bar("Hello".to_string())); + + let transposed = store.transpose(); + use Enum::*; + match transposed { + Enum::Foo => {} + Bar(bar) => { + let bar: String = bar(); + } + BarFoo { foo } => { + let foo: String = foo(); + } + } + } }