Skip to content

Commit a01f0c6

Browse files
mariusaefacebook-github-bot
authored andcommitted
Actor derive macro (#837)
Summary: Pull Request resolved: #837 We have many trivial `Actor` implementations. (Usually, the the actor's handlers are more involved.) We add a derive macro for actors, supporting two patterns: 1) "empty params", where we use the actors Default::default implementation to create a new actor; 2) "passthrough actor", where the params themselves are simply an actor instance which we immediately return This change introduces the macro and adjusts usages throughout our code base. ghstack-source-id: 302711475 exported-using-ghexport Reviewed By: vidhyav Differential Revision: D80139026 fbshipit-source-id: af2ad6a67038d6decc0e2884f7cbee70481de18a
1 parent 7a10f7a commit a01f0c6

File tree

16 files changed

+132
-163
lines changed

16 files changed

+132
-163
lines changed

controller/src/lib.rs

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ impl ControllerMessageHandler for ControllerActor {
527527
// Randomly pick a failed proc as the failed actor.
528528
let (_, failed_state) = world_state.procs.iter().next().unwrap();
529529
let (failed_actor, failure_reason) =
530-
failed_state.failed_actors.iter().next().map_or_else(
530+
failed_state.failed_actors.first().map_or_else(
531531
|| {
532532
let proc_id = &failed_state.proc_id;
533533
(
@@ -1808,23 +1808,14 @@ mod tests {
18081808
Panic(String),
18091809
}
18101810

1811-
#[derive(Debug)]
1811+
#[derive(Debug, Default, Actor)]
18121812
#[hyperactor::export(
18131813
handlers = [
18141814
PanickingMessage,
18151815
],
18161816
)]
18171817
struct PanickingActor;
18181818

1819-
#[async_trait]
1820-
impl Actor for PanickingActor {
1821-
type Params = ();
1822-
1823-
async fn new(_params: ()) -> Result<Self, anyhow::Error> {
1824-
Ok(Self)
1825-
}
1826-
}
1827-
18281819
#[async_trait]
18291820
#[hyperactor::forward(PanickingMessage)]
18301821
impl PanickingMessageHandler for PanickingActor {

hyper/src/commands/demo.rs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ enum DemoMessage {
210210
Error(String, #[reply] OncePortRef<()>),
211211
}
212212

213-
#[derive(Debug)]
213+
#[derive(Debug, Default, Actor)]
214214
#[hyperactor::export(
215215
spawn = true,
216216
handlers = [
@@ -219,15 +219,6 @@ enum DemoMessage {
219219
)]
220220
struct DemoActor;
221221

222-
#[async_trait]
223-
impl Actor for DemoActor {
224-
type Params = ();
225-
226-
async fn new(_params: ()) -> Result<Self, anyhow::Error> {
227-
Ok(Self)
228-
}
229-
}
230-
231222
#[async_trait]
232223
#[forward(DemoMessage)]
233224
impl DemoMessageHandler for DemoActor {

hyperactor/example/derive.rs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ enum ShoppingList {
3535
}
3636

3737
// Define an actor.
38-
#[derive(Debug)]
38+
#[derive(Debug, Actor, Default)]
3939
#[hyperactor::export(
4040
spawn = true,
4141
handlers = [
@@ -44,15 +44,6 @@ enum ShoppingList {
4444
)]
4545
struct ShoppingListActor(HashSet<String>);
4646

47-
#[async_trait]
48-
impl Actor for ShoppingListActor {
49-
type Params = ();
50-
51-
async fn new(_params: ()) -> Result<Self, anyhow::Error> {
52-
Ok(Self(HashSet::new()))
53-
}
54-
}
55-
5647
// ShoppingListHandler is the trait generated by derive(Handler) above.
5748
// We implement the trait here for the actor, defining a handler for
5849
// each ShoppingList message.

hyperactor/src/actor.rs

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,8 @@ mod tests {
658658
use tokio::time::timeout;
659659

660660
use super::*;
661+
use crate as hyperactor;
662+
use crate::Actor;
661663
use crate::Mailbox;
662664
use crate::OncePortHandle;
663665
use crate::PortRef;
@@ -666,7 +668,7 @@ mod tests {
666668
use crate::test_utils::pingpong::PingPongActor;
667669
use crate::test_utils::pingpong::PingPongActorParams;
668670
use crate::test_utils::pingpong::PingPongMessage;
669-
use crate::test_utils::proc_supervison::ProcSupervisionCoordinator;
671+
use crate::test_utils::proc_supervison::ProcSupervisionCoordinator; // for macros
670672

671673
#[derive(Debug)]
672674
struct EchoActor(PortRef<u64>);
@@ -983,18 +985,9 @@ mod tests {
983985

984986
#[tokio::test]
985987
async fn test_actor_handle_downcast() {
986-
#[derive(Debug)]
988+
#[derive(Debug, Default, Actor)]
987989
struct NothingActor;
988990

989-
#[async_trait]
990-
impl Actor for NothingActor {
991-
type Params = ();
992-
993-
async fn new(_: ()) -> Result<Self, anyhow::Error> {
994-
Ok(Self)
995-
}
996-
}
997-
998991
// Just test that we can round-trip the handle through a downcast.
999992

1000993
let proc = Proc::local();

hyperactor/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ pub use cityhasher;
106106
#[doc(hidden)]
107107
pub use dashmap; // For intern_typename!
108108
pub use data::Named;
109+
#[doc(hidden)]
110+
pub use hyperactor_macros::Actor;
109111
#[doc(inline)]
110112
pub use hyperactor_macros::Bind;
111113
#[doc(inline)]

hyperactor/src/mailbox.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2924,17 +2924,9 @@ mod tests {
29242924
);
29252925
}
29262926

2927-
#[derive(Debug)]
2927+
#[derive(Debug, Default, Actor)]
29282928
struct Foo;
29292929

2930-
#[async_trait]
2931-
impl Actor for Foo {
2932-
type Params = ();
2933-
async fn new(_params: ()) -> Result<Self, anyhow::Error> {
2934-
Ok(Self)
2935-
}
2936-
}
2937-
29382930
// Test that a message delivery failure causes the sending actor
29392931
// to stop running.
29402932
#[tokio::test]

hyperactor/src/proc.rs

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1857,7 +1857,7 @@ mod tests {
18571857
}
18581858
}
18591859

1860-
#[derive(Debug)]
1860+
#[derive(Debug, Default, Actor)]
18611861
#[export]
18621862
struct TestActor;
18631863

@@ -1880,15 +1880,6 @@ mod tests {
18801880
}
18811881
}
18821882

1883-
#[async_trait]
1884-
impl Actor for TestActor {
1885-
type Params = ();
1886-
1887-
async fn new(_params: ()) -> Result<Self, anyhow::Error> {
1888-
Ok(Self)
1889-
}
1890-
}
1891-
18921883
#[async_trait]
18931884
#[crate::forward(TestActorMessage)]
18941885
impl TestActorMessageHandler for TestActor {
@@ -2017,23 +2008,14 @@ mod tests {
20172008
rx.await.unwrap();
20182009
}
20192010

2020-
#[derive(Debug)]
2011+
#[derive(Debug, Default, Actor)]
20212012
struct LookupTestActor;
20222013

20232014
#[derive(Handler, HandleClient, Debug)]
20242015
enum LookupTestMessage {
20252016
ActorExists(ActorRef<TestActor>, #[reply] OncePortRef<bool>),
20262017
}
20272018

2028-
#[async_trait]
2029-
impl Actor for LookupTestActor {
2030-
type Params = ();
2031-
2032-
async fn new(_params: ()) -> Result<Self, anyhow::Error> {
2033-
Ok(Self)
2034-
}
2035-
}
2036-
20372019
#[async_trait]
20382020
#[crate::forward(LookupTestMessage)]
20392021
impl LookupTestMessageHandler for LookupTestActor {
@@ -2745,18 +2727,9 @@ mod tests {
27452727

27462728
#[tokio::test]
27472729
async fn test_instance() {
2748-
#[derive(Debug)]
2730+
#[derive(Debug, Default, Actor)]
27492731
struct TestActor;
27502732

2751-
#[async_trait]
2752-
impl Actor for TestActor {
2753-
type Params = ();
2754-
2755-
async fn new(_params: ()) -> Result<Self, anyhow::Error> {
2756-
Ok(Self)
2757-
}
2758-
}
2759-
27602733
#[async_trait]
27612734
impl Handler<(String, PortRef<String>)> for TestActor {
27622735
async fn handle(
@@ -2834,7 +2807,7 @@ mod tests {
28342807
#[ignore = "until trace recording is turned back on"]
28352808
#[test]
28362809
fn test_handler_logging() {
2837-
#[derive(Debug)]
2810+
#[derive(Debug, Default, Actor)]
28382811
struct LoggingActor;
28392812

28402813
impl LoggingActor {
@@ -2845,15 +2818,6 @@ mod tests {
28452818
}
28462819
}
28472820

2848-
#[async_trait]
2849-
impl Actor for LoggingActor {
2850-
type Params = ();
2851-
2852-
async fn new(_params: ()) -> Result<Self, anyhow::Error> {
2853-
Ok(Self)
2854-
}
2855-
}
2856-
28572821
#[async_trait]
28582822
impl Handler<String> for LoggingActor {
28592823
async fn handle(

hyperactor_macros/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ anyhow = "1.0.98"
3030
async-trait = "0.1.86"
3131
hyperactor = { version = "0.0.0", path = "../hyperactor" }
3232
serde = { version = "1.0.185", features = ["derive", "rc"] }
33+
static_assertions = "1.1.0"
3334
timed_test = { version = "0.0.0", path = "../timed_test" }
3435
tokio = { version = "1.37.0", features = ["full", "test-util", "tracing"] }
3536
tracing = { version = "0.1.41", features = ["attributes", "valuable"] }

hyperactor_macros/src/lib.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,3 +1789,93 @@ pub fn derive_unbind(input: TokenStream) -> TokenStream {
17891789
};
17901790
TokenStream::from(expand)
17911791
}
1792+
1793+
/// Derives the `Actor` trait for a struct. By default, generates an implementation
1794+
/// with no params (`type Params = ()` and `async fn new(_params: ()) -> Result<Self, anyhow::Error>`).
1795+
/// This requires that the Actor implements [`Default`].
1796+
///
1797+
/// If the `#[actor(passthrough)]` attribute is specified, generates an implementation
1798+
/// with where the parameter type is `Self`
1799+
/// (`type Params = Self` and `async fn new(instance: Self) -> Result<Self, anyhow::Error>`).
1800+
///
1801+
/// # Examples
1802+
///
1803+
/// Default behavior:
1804+
/// ```
1805+
/// #[derive(Actor, Default)]
1806+
/// struct MyActor(u64);
1807+
/// ```
1808+
///
1809+
/// Generates:
1810+
/// ```ignore
1811+
/// #[async_trait]
1812+
/// impl Actor for MyActor {
1813+
/// type Params = ();
1814+
///
1815+
/// async fn new(_params: ()) -> Result<Self, anyhow::Error> {
1816+
/// Ok(Default::default())
1817+
/// }
1818+
/// }
1819+
/// ```
1820+
///
1821+
/// Passthrough behavior:
1822+
/// ```
1823+
/// #[derive(Actor, Default)]
1824+
/// #[actor(passthrough)]
1825+
/// struct MyActor(u64);
1826+
/// ```
1827+
///
1828+
/// Generates:
1829+
/// ```ignore
1830+
/// #[async_trait]
1831+
/// impl Actor for MyActor {
1832+
/// type Params = Self;
1833+
///
1834+
/// async fn new(instance: Self) -> Result<Self, anyhow::Error> {
1835+
/// Ok(instance)
1836+
/// }
1837+
/// }
1838+
/// ```
1839+
#[proc_macro_derive(Actor, attributes(actor))]
1840+
pub fn derive_actor(input: TokenStream) -> TokenStream {
1841+
let input = parse_macro_input!(input as DeriveInput);
1842+
let name = &input.ident;
1843+
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1844+
1845+
let is_passthrough = input.attrs.iter().any(|attr| {
1846+
if attr.path().is_ident("actor") {
1847+
if let Ok(meta) = attr.parse_args_with(
1848+
syn::punctuated::Punctuated::<syn::Ident, syn::Token![,]>::parse_terminated,
1849+
) {
1850+
return meta.iter().any(|ident| ident == "passthrough");
1851+
}
1852+
}
1853+
false
1854+
});
1855+
1856+
let expanded = if is_passthrough {
1857+
quote! {
1858+
#[hyperactor::async_trait::async_trait]
1859+
impl #impl_generics hyperactor::Actor for #name #ty_generics #where_clause {
1860+
type Params = Self;
1861+
1862+
async fn new(instance: Self) -> Result<Self, hyperactor::anyhow::Error> {
1863+
Ok(instance)
1864+
}
1865+
}
1866+
}
1867+
} else {
1868+
quote! {
1869+
#[hyperactor::async_trait::async_trait]
1870+
impl #impl_generics hyperactor::Actor for #name #ty_generics #where_clause {
1871+
type Params = ();
1872+
1873+
async fn new(_params: ()) -> Result<Self, hyperactor::anyhow::Error> {
1874+
Ok(Default::default())
1875+
}
1876+
}
1877+
}
1878+
};
1879+
1880+
TokenStream::from(expanded)
1881+
}

hyperactor_macros/tests/basic.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,22 @@ impl GenericArgMessageHandler<usize> for GenericArgActor {
120120
Ok(())
121121
}
122122
}
123+
124+
#[derive(Actor, Default, Debug)]
125+
struct DefaultActorTest {
126+
value: u64,
127+
}
128+
129+
static_assertions::assert_impl_all!(DefaultActorTest: Actor);
130+
131+
#[derive(Actor, Default, Debug)]
132+
#[actor(passthrough)]
133+
struct PassthroughActorTest {
134+
value: u64,
135+
}
136+
137+
static_assertions::assert_impl_all!(PassthroughActorTest: Actor);
138+
static_assertions::assert_type_eq_all!(
139+
<PassthroughActorTest as hyperactor::Actor>::Params,
140+
PassthroughActorTest
141+
);

0 commit comments

Comments
 (0)