@@ -3,16 +3,16 @@ use std::collections::HashSet;
3
3
use std:: sync:: Arc ;
4
4
use thiserror:: Error ;
5
5
6
- use crate :: event:: EventExt ;
7
6
use crate :: stream:: EventStream ;
8
7
use crate :: subscriber:: AgentSubscriber ;
9
- use ag_ui_core :: event :: Event ;
8
+
10
9
use ag_ui_core:: types:: context:: Context ;
11
10
use ag_ui_core:: types:: ids:: { AgentId , MessageId , RunId , ThreadId } ;
12
11
use ag_ui_core:: types:: input:: RunAgentInput ;
13
12
use ag_ui_core:: types:: message:: Message ;
14
13
use ag_ui_core:: types:: tool:: Tool ;
15
- use ag_ui_core:: { FwdProps , JsonValue , State } ;
14
+ use ag_ui_core:: { AgentState , FwdProps , JsonValue } ;
15
+ use crate :: event_handler:: EventHandler ;
16
16
17
17
#[ derive( Debug , Clone ) ]
18
18
pub struct AgentConfig < StateT = JsonValue > {
@@ -42,11 +42,13 @@ where
42
42
43
43
/// Parameters for running an agent.
44
44
#[ derive( Debug , Clone , Default ) ]
45
- pub struct RunAgentParams < FwdPropsT = JsonValue > {
45
+ pub struct RunAgentParams < StateT : AgentState , FwdPropsT = JsonValue > {
46
46
pub run_id : Option < RunId > ,
47
47
pub tools : Option < Vec < Tool > > ,
48
48
pub context : Option < Vec < Context > > ,
49
49
pub forwarded_props : Option < FwdPropsT > ,
50
+ pub messages : Vec < Message > ,
51
+ pub state : StateT ,
50
52
}
51
53
52
54
#[ derive( Debug , Clone ) ]
@@ -55,6 +57,8 @@ pub struct RunAgentResult {
55
57
pub new_messages : Vec < Message > ,
56
58
}
57
59
60
+ pub type AgentRunState < StateT , FwdPropsT > = RunAgentInput < StateT , FwdPropsT > ;
61
+
58
62
#[ derive( Debug , Clone ) ]
59
63
pub struct AgentStateMutation < StateT = JsonValue > {
60
64
pub messages : Option < Vec < Message > > ,
@@ -90,227 +94,105 @@ pub enum AgentError {
90
94
#[ async_trait:: async_trait]
91
95
pub trait Agent < StateT = JsonValue , FwdPropsT = JsonValue > : Send + Sync
92
96
where
93
- StateT : State ,
97
+ StateT : AgentState ,
94
98
FwdPropsT : FwdProps ,
95
99
{
96
- fn run < ' a > ( & ' a self , input : & ' a RunAgentInput < StateT , FwdPropsT > ) -> EventStream < ' a > ;
97
-
98
- // Idiomatic accessors for agent state.
99
- fn agent_id ( & self ) -> Option < & AgentId > ;
100
- fn agent_id_mut ( & mut self ) -> & mut Option < AgentId > ;
101
- fn description ( & self ) -> & str ;
102
- fn description_mut ( & mut self ) -> & mut String ;
103
- fn thread_id ( & self ) -> & ThreadId ;
104
- fn thread_id_mut ( & mut self ) -> & mut ThreadId ;
105
- fn messages ( & self ) -> & [ Message ] ;
106
- fn messages_mut ( & mut self ) -> & mut Vec < Message > ;
107
- fn state ( & self ) -> & StateT ;
108
- fn state_mut ( & mut self ) -> & mut StateT ;
109
- fn subscribers ( & self ) -> & [ Arc < dyn AgentSubscriber < StateT , FwdPropsT > > ] ;
110
- fn subscribers_mut ( & mut self ) -> & mut Vec < Arc < dyn AgentSubscriber < StateT , FwdPropsT > > > ;
111
-
112
- /// Adds a subscriber to the agent.
113
- fn add_subscriber ( & mut self , subscriber : Arc < dyn AgentSubscriber < StateT , FwdPropsT > > ) {
114
- self . subscribers_mut ( ) . push ( subscriber) ;
115
- }
100
+ async fn run (
101
+ & self ,
102
+ input : & RunAgentInput < StateT , FwdPropsT > ,
103
+ ) -> Result < EventStream < ' async_trait , StateT > , AgentError > ;
116
104
117
105
/// The main execution method, containing the full pipeline logic.
118
106
async fn run_agent (
119
- & mut self ,
120
- params : & RunAgentParams < FwdPropsT > ,
121
- subscriber : Option < Arc < dyn AgentSubscriber < StateT , FwdPropsT > > > ,
107
+ & self ,
108
+ params : & RunAgentParams < StateT , FwdPropsT > ,
109
+ subscribers : Vec < Arc < dyn AgentSubscriber < StateT , FwdPropsT > > > ,
122
110
) -> Result < RunAgentResult , AgentError > {
123
- if self . agent_id ( ) . is_none ( ) {
124
- * self . agent_id_mut ( ) = Some ( AgentId :: new ( ) ) ;
125
- }
126
-
127
- let mut subscribers = self . subscribers ( ) . to_vec ( ) ;
128
- if let Some ( sub) = subscriber {
129
- subscribers. push ( sub) ;
130
- }
111
+ // TODO: Use Agent ID?
112
+ let agent_id = AgentId :: random ( ) ;
113
+
114
+ let input = RunAgentInput {
115
+ thread_id : ThreadId :: random ( ) ,
116
+ run_id : params. run_id . clone ( ) . unwrap_or_else ( RunId :: random) ,
117
+ state : params. state . clone ( ) ,
118
+ messages : params. messages . clone ( ) ,
119
+ tools : params. tools . clone ( ) . unwrap_or_default ( ) ,
120
+ context : params. context . clone ( ) . unwrap_or_default ( ) ,
121
+ // TODO: Find suitable default value
122
+ forwarded_props : params. forwarded_props . clone ( ) . unwrap ( ) ,
123
+ } ;
124
+ let current_message_ids: HashSet < & MessageId > =
125
+ params. messages . iter ( ) . map ( |m| m. id ( ) ) . collect ( ) ;
131
126
132
- let input = self . prepare_run_agent_input ( params) ;
133
- let messages = self . messages ( ) . to_vec ( ) ;
134
- let current_message_ids: HashSet < & MessageId > = messages. iter ( ) . map ( |m| m. id ( ) ) . collect ( ) ;
135
- let mut result_val = JsonValue :: Null ;
127
+ // Initialize event handler with the current state
128
+ let mut event_handler = EventHandler :: new (
129
+ params. messages . clone ( ) ,
130
+ params. state . clone ( ) ,
131
+ & input,
132
+ subscribers,
133
+ ) ;
136
134
137
- let mut stream = self . run ( & input) . fuse ( ) ;
135
+ let mut stream = self . run ( & input) . await ? . fuse ( ) ;
138
136
139
137
while let Some ( event_result) = stream. next ( ) . await {
140
138
match event_result {
141
139
Ok ( event) => {
142
- let ( mutation, value) = event
143
- . apply_and_process_event ( & input, & messages, & input. state , & subscribers)
144
- . await ?;
145
- result_val = JsonValue :: from ( value) ;
140
+ let mutation = event_handler. handle_event ( & event) . await ?;
141
+ event_handler. apply_mutation ( mutation) . await ?;
146
142
}
147
143
Err ( e) => {
148
- // self .on_error(&input, &e, &subscribers ).await?;
144
+ event_handler . on_error ( & e ) . await ?;
149
145
return Err ( e) ;
150
146
}
151
147
}
152
148
}
153
149
154
- // self.on_finalize(&input, &subscribers).await?;
150
+ // Finalize the run
151
+ event_handler. on_finalize ( ) . await ?;
155
152
156
- let new_messages = self
157
- . messages ( )
153
+ // Collect new messages
154
+ let new_messages = event_handler
155
+ . messages
158
156
. iter ( )
159
157
. filter ( |m| !current_message_ids. contains ( & m. id ( ) ) )
160
158
. cloned ( )
161
159
. collect ( ) ;
162
160
163
161
Ok ( RunAgentResult {
164
- result : result_val ,
162
+ result : event_handler . result ,
165
163
new_messages,
166
164
} )
167
165
}
168
166
169
- /// Helper to construct the input for the `run` method.
170
- fn prepare_run_agent_input (
171
- & self ,
172
- params : & RunAgentParams < FwdPropsT > ,
173
- ) -> RunAgentInput < StateT , FwdPropsT > {
174
- RunAgentInput {
175
- thread_id : self . thread_id ( ) . clone ( ) ,
176
- run_id : params. run_id . clone ( ) . unwrap_or_else ( || RunId :: new ( ) ) ,
177
- state : self . state ( ) . clone ( ) ,
178
- messages : self . messages ( ) . to_vec ( ) ,
179
- tools : params. tools . clone ( ) . unwrap_or_default ( ) ,
180
- context : params. context . clone ( ) . unwrap_or_default ( ) ,
181
- // TODO: Find suitable default value
182
- forwarded_props : params. forwarded_props . clone ( ) . unwrap ( ) ,
183
- }
184
- }
185
-
186
- /// Processes a single event, applying mutations and notifying subscribers.
187
- /// Returns the final result if the event is `Done`.
188
- async fn apply_and_process_event (
189
- & mut self ,
190
- event : Event ,
191
- input : & RunAgentInput < StateT , FwdPropsT > ,
192
- subscribers : & [ Arc < dyn AgentSubscriber < StateT , FwdPropsT > > ] ,
193
- ) -> Result < Option < JsonValue > , AgentError > {
194
- // This is a simplified stand-in for the logic from `defaultApplyEvents` in TS.
195
- // A full implementation would handle each event type to create the correct state mutation.
196
- let ( mutation, result) = match event {
197
- Event :: RunFinished ( e) => {
198
- for sub in subscribers {
199
- sub. on_run_finished (
200
- & e. result . clone ( ) . unwrap ( ) ,
201
- self . messages ( ) ,
202
- self . state ( ) ,
203
- input,
204
- )
205
- . await ?;
206
- }
207
- ( AgentStateMutation :: default ( ) , e. result )
208
- }
209
- // In a real implementation, other events like Text, ToolCall, etc.,
210
- // would create mutations to update messages and state.
211
- _ => ( AgentStateMutation :: default ( ) , None ) ,
212
- } ;
213
-
214
- self . apply_mutation ( mutation, input, subscribers) . await ?;
215
- Ok ( result)
216
- }
217
-
218
- async fn on_initialize (
219
- & mut self ,
220
- input : & mut RunAgentInput < StateT , FwdPropsT > ,
221
- subscribers : & [ Arc < dyn AgentSubscriber < StateT , FwdPropsT > > ] ,
222
- ) -> Result < ( ) , AgentError > {
223
- for subscriber in subscribers {
224
- let mutation = subscriber
225
- . on_run_initialized ( self . messages ( ) , self . state ( ) , input)
226
- . await ?;
227
-
228
- if mutation. messages . is_some ( ) || mutation. state . is_some ( ) {
229
- if let Some ( ref messages) = mutation. messages {
230
- input. messages = messages. clone ( ) ;
231
- }
232
- if let Some ( ref state) = mutation. state {
233
- input. state = state. clone ( ) ;
234
- }
235
- self . apply_mutation ( mutation, input, subscribers) . await ?;
236
- }
237
- }
238
- Ok ( ( ) )
239
- }
240
-
241
- async fn on_error (
242
- & mut self ,
243
- input : & RunAgentInput < StateT , FwdPropsT > ,
244
- error : & AgentError ,
245
- subscribers : & [ Arc < dyn AgentSubscriber < StateT , FwdPropsT > > ] ,
246
- ) -> Result < ( ) , AgentError > {
247
- for subscriber in subscribers {
248
- let mutation = subscriber
249
- . on_run_failed ( error, self . messages ( ) , self . state ( ) , input)
250
- . await ?;
251
-
252
- self . apply_mutation ( mutation, input, subscribers) . await ?;
253
- }
254
- Ok ( ( ) )
255
- }
256
-
257
- async fn on_finalize (
258
- & mut self ,
259
- input : & RunAgentInput < StateT , FwdPropsT > ,
260
- subscribers : & [ Arc < dyn AgentSubscriber < StateT , FwdPropsT > > ] ,
261
- ) -> Result < ( ) , AgentError > {
262
- for subscriber in subscribers {
263
- let mutation = subscriber
264
- . on_run_finalized ( self . messages ( ) , self . state ( ) , input)
265
- . await ?;
266
-
267
- self . apply_mutation ( mutation, input, subscribers) . await ?;
268
- }
269
- Ok ( ( ) )
270
- }
271
-
272
- async fn apply_mutation (
273
- & mut self ,
274
- mutation : AgentStateMutation < StateT > ,
275
- input : & RunAgentInput < StateT , FwdPropsT > ,
276
- subscribers : & [ Arc < dyn AgentSubscriber < StateT , FwdPropsT > > ] ,
277
- ) -> Result < ( ) , AgentError > {
278
- if let Some ( messages) = mutation. messages {
279
- * self . messages_mut ( ) = messages;
280
- self . notify_messages_changed ( input, subscribers) . await ?;
281
- }
282
-
283
- if let Some ( state) = mutation. state {
284
- * self . state_mut ( ) = state;
285
- self . notify_state_changed ( input, subscribers) . await ?;
286
- }
287
-
288
- Ok ( ( ) )
289
- }
290
-
291
- async fn notify_messages_changed (
292
- & self ,
293
- input : & RunAgentInput < StateT , FwdPropsT > ,
294
- subscribers : & [ Arc < dyn AgentSubscriber < StateT , FwdPropsT > > ] ,
295
- ) -> Result < ( ) , AgentError > {
296
- for subscriber in subscribers {
297
- subscriber
298
- . on_messages_changed ( self . messages ( ) , self . state ( ) , input)
299
- . await ?;
300
- }
301
- Ok ( ( ) )
302
- }
303
-
304
- async fn notify_state_changed (
305
- & self ,
306
- input : & RunAgentInput < StateT , FwdPropsT > ,
307
- subscribers : & [ Arc < dyn AgentSubscriber < StateT , FwdPropsT > > ] ,
308
- ) -> Result < ( ) , AgentError > {
309
- for subscriber in subscribers {
310
- subscriber
311
- . on_state_changed ( self . messages ( ) , self . state ( ) , input)
312
- . await ?;
313
- }
314
- Ok ( ( ) )
315
- }
167
+ // Helper function to run subscribers that can return a mutation
168
+ // async fn run_subscribers_with_mutation<F, Fut>(
169
+ // &self,
170
+ // subscribers: &[Arc<dyn AgentSubscriber<StateT, FwdPropsT>>],
171
+ // mut callback: F,
172
+ // ) -> Result<AgentStateMutation<StateT>, AgentError>
173
+ // where
174
+ // F: FnMut(&Arc<dyn AgentSubscriber<StateT, FwdPropsT>>) -> Fut + Send,
175
+ // Fut: std::future::Future<Output = Result<AgentStateMutation<StateT>, AgentError>>,
176
+ // {
177
+ // let mut result = AgentStateMutation::default();
178
+ //
179
+ // for subscriber in subscribers {
180
+ // let mutation = callback(subscriber).await?;
181
+ //
182
+ // if mutation.messages.is_some() {
183
+ // result.messages = mutation.messages;
184
+ // }
185
+ //
186
+ // if mutation.state.is_some() {
187
+ // result.state = mutation.state;
188
+ // }
189
+ //
190
+ // if mutation.stop_propagation {
191
+ // result.stop_propagation = true;
192
+ // break;
193
+ // }
194
+ // }
195
+ //
196
+ // Ok(result)
197
+ // }
316
198
}
0 commit comments