Skip to content

Commit 67ab96d

Browse files
committed
WIP
1 parent c81ceff commit 67ab96d

File tree

19 files changed

+2778
-516
lines changed

19 files changed

+2778
-516
lines changed

rust-sdk/Cargo.lock

Lines changed: 1643 additions & 136 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust-sdk/crates/ag-ui-client/Cargo.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,10 @@ serde_json = { workspace = true }
1111
async-trait = "0.1.88"
1212
uuid = { version = "1.17.0", features = ["v4"] }
1313
futures = "0.3.31"
14+
json-patch = "4.0.0"
15+
log = "0.4.27"
16+
reqwest = { version = "0.12.22" , features = ["json", "stream"]}
17+
sse-client = "1.1.1"
18+
19+
[dev-dependencies]
20+
tokio = { version = "1.36.0", features = ["full"] }
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# /// script
2+
# requires-python = ">=3.12"
3+
# dependencies = [
4+
# "uvicorn == 0.34.3",
5+
# "pydantic-ai==0.4.10"
6+
# ]
7+
# ///
8+
9+
import uvicorn
10+
from pydantic_ai import Agent
11+
from pydantic_ai.models.openai import OpenAIModel
12+
from pydantic_ai.providers.openai import OpenAIProvider
13+
14+
model = OpenAIModel(
15+
model_name="llama3.1",
16+
provider=OpenAIProvider(
17+
base_url="http://localhost:11434/v1", api_key="ollama"
18+
),
19+
)
20+
agent = Agent(model)
21+
22+
23+
@agent.tool_plain
24+
def temperature_celsius(city: str) -> float:
25+
return 21.0
26+
27+
28+
@agent.tool_plain
29+
def temperature_fahrenheit(city: str) -> float:
30+
return 69.8
31+
32+
33+
app = agent.to_ag_ui()
34+
35+
if __name__ == "__main__":
36+
uvicorn.run(app, port=3001)

rust-sdk/crates/ag-ui-client/src/agent.rs

Lines changed: 79 additions & 197 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@ use std::collections::HashSet;
33
use std::sync::Arc;
44
use thiserror::Error;
55

6-
use crate::event::EventExt;
76
use crate::stream::EventStream;
87
use crate::subscriber::AgentSubscriber;
9-
use ag_ui_core::event::Event;
8+
109
use ag_ui_core::types::context::Context;
1110
use ag_ui_core::types::ids::{AgentId, MessageId, RunId, ThreadId};
1211
use ag_ui_core::types::input::RunAgentInput;
1312
use ag_ui_core::types::message::Message;
1413
use ag_ui_core::types::tool::Tool;
15-
use ag_ui_core::{FwdProps, JsonValue, State};
14+
use ag_ui_core::{AgentState, FwdProps, JsonValue};
15+
use crate::event_handler::EventHandler;
1616

1717
#[derive(Debug, Clone)]
1818
pub struct AgentConfig<StateT = JsonValue> {
@@ -42,11 +42,13 @@ where
4242

4343
/// Parameters for running an agent.
4444
#[derive(Debug, Clone, Default)]
45-
pub struct RunAgentParams<FwdPropsT = JsonValue> {
45+
pub struct RunAgentParams<StateT: AgentState, FwdPropsT = JsonValue> {
4646
pub run_id: Option<RunId>,
4747
pub tools: Option<Vec<Tool>>,
4848
pub context: Option<Vec<Context>>,
4949
pub forwarded_props: Option<FwdPropsT>,
50+
pub messages: Vec<Message>,
51+
pub state: StateT,
5052
}
5153

5254
#[derive(Debug, Clone)]
@@ -55,6 +57,8 @@ pub struct RunAgentResult {
5557
pub new_messages: Vec<Message>,
5658
}
5759

60+
pub type AgentRunState<StateT, FwdPropsT> = RunAgentInput<StateT, FwdPropsT>;
61+
5862
#[derive(Debug, Clone)]
5963
pub struct AgentStateMutation<StateT = JsonValue> {
6064
pub messages: Option<Vec<Message>>,
@@ -90,227 +94,105 @@ pub enum AgentError {
9094
#[async_trait::async_trait]
9195
pub trait Agent<StateT = JsonValue, FwdPropsT = JsonValue>: Send + Sync
9296
where
93-
StateT: State,
97+
StateT: AgentState,
9498
FwdPropsT: FwdProps,
9599
{
96-
fn run<'a>(&'a self, input: &'a RunAgentInput<StateT, FwdPropsT>) -> EventStream<'a>;
97-
98-
// Idiomatic accessors for agent state.
99-
fn agent_id(&self) -> Option<&AgentId>;
100-
fn agent_id_mut(&mut self) -> &mut Option<AgentId>;
101-
fn description(&self) -> &str;
102-
fn description_mut(&mut self) -> &mut String;
103-
fn thread_id(&self) -> &ThreadId;
104-
fn thread_id_mut(&mut self) -> &mut ThreadId;
105-
fn messages(&self) -> &[Message];
106-
fn messages_mut(&mut self) -> &mut Vec<Message>;
107-
fn state(&self) -> &StateT;
108-
fn state_mut(&mut self) -> &mut StateT;
109-
fn subscribers(&self) -> &[Arc<dyn AgentSubscriber<StateT, FwdPropsT>>];
110-
fn subscribers_mut(&mut self) -> &mut Vec<Arc<dyn AgentSubscriber<StateT, FwdPropsT>>>;
111-
112-
/// Adds a subscriber to the agent.
113-
fn add_subscriber(&mut self, subscriber: Arc<dyn AgentSubscriber<StateT, FwdPropsT>>) {
114-
self.subscribers_mut().push(subscriber);
115-
}
100+
async fn run(
101+
&self,
102+
input: &RunAgentInput<StateT, FwdPropsT>,
103+
) -> Result<EventStream<'async_trait, StateT>, AgentError>;
116104

117105
/// The main execution method, containing the full pipeline logic.
118106
async fn run_agent(
119-
&mut self,
120-
params: &RunAgentParams<FwdPropsT>,
121-
subscriber: Option<Arc<dyn AgentSubscriber<StateT, FwdPropsT>>>,
107+
&self,
108+
params: &RunAgentParams<StateT, FwdPropsT>,
109+
subscribers: Vec<Arc<dyn AgentSubscriber<StateT, FwdPropsT>>>,
122110
) -> Result<RunAgentResult, AgentError> {
123-
if self.agent_id().is_none() {
124-
*self.agent_id_mut() = Some(AgentId::new());
125-
}
126-
127-
let mut subscribers = self.subscribers().to_vec();
128-
if let Some(sub) = subscriber {
129-
subscribers.push(sub);
130-
}
111+
// TODO: Use Agent ID?
112+
let agent_id = AgentId::random();
113+
114+
let input = RunAgentInput {
115+
thread_id: ThreadId::random(),
116+
run_id: params.run_id.clone().unwrap_or_else(RunId::random),
117+
state: params.state.clone(),
118+
messages: params.messages.clone(),
119+
tools: params.tools.clone().unwrap_or_default(),
120+
context: params.context.clone().unwrap_or_default(),
121+
// TODO: Find suitable default value
122+
forwarded_props: params.forwarded_props.clone().unwrap(),
123+
};
124+
let current_message_ids: HashSet<&MessageId> =
125+
params.messages.iter().map(|m| m.id()).collect();
131126

132-
let input = self.prepare_run_agent_input(params);
133-
let messages = self.messages().to_vec();
134-
let current_message_ids: HashSet<&MessageId> = messages.iter().map(|m| m.id()).collect();
135-
let mut result_val = JsonValue::Null;
127+
// Initialize event handler with the current state
128+
let mut event_handler = EventHandler::new(
129+
params.messages.clone(),
130+
params.state.clone(),
131+
&input,
132+
subscribers,
133+
);
136134

137-
let mut stream = self.run(&input).fuse();
135+
let mut stream = self.run(&input).await?.fuse();
138136

139137
while let Some(event_result) = stream.next().await {
140138
match event_result {
141139
Ok(event) => {
142-
let (mutation, value) = event
143-
.apply_and_process_event(&input, &messages, &input.state, &subscribers)
144-
.await?;
145-
result_val = JsonValue::from(value);
140+
let mutation = event_handler.handle_event(&event).await?;
141+
event_handler.apply_mutation(mutation).await?;
146142
}
147143
Err(e) => {
148-
// self.on_error(&input, &e, &subscribers).await?;
144+
event_handler.on_error(&e).await?;
149145
return Err(e);
150146
}
151147
}
152148
}
153149

154-
// self.on_finalize(&input, &subscribers).await?;
150+
// Finalize the run
151+
event_handler.on_finalize().await?;
155152

156-
let new_messages = self
157-
.messages()
153+
// Collect new messages
154+
let new_messages = event_handler
155+
.messages
158156
.iter()
159157
.filter(|m| !current_message_ids.contains(&m.id()))
160158
.cloned()
161159
.collect();
162160

163161
Ok(RunAgentResult {
164-
result: result_val,
162+
result: event_handler.result,
165163
new_messages,
166164
})
167165
}
168166

169-
/// Helper to construct the input for the `run` method.
170-
fn prepare_run_agent_input(
171-
&self,
172-
params: &RunAgentParams<FwdPropsT>,
173-
) -> RunAgentInput<StateT, FwdPropsT> {
174-
RunAgentInput {
175-
thread_id: self.thread_id().clone(),
176-
run_id: params.run_id.clone().unwrap_or_else(|| RunId::new()),
177-
state: self.state().clone(),
178-
messages: self.messages().to_vec(),
179-
tools: params.tools.clone().unwrap_or_default(),
180-
context: params.context.clone().unwrap_or_default(),
181-
// TODO: Find suitable default value
182-
forwarded_props: params.forwarded_props.clone().unwrap(),
183-
}
184-
}
185-
186-
/// Processes a single event, applying mutations and notifying subscribers.
187-
/// Returns the final result if the event is `Done`.
188-
async fn apply_and_process_event(
189-
&mut self,
190-
event: Event,
191-
input: &RunAgentInput<StateT, FwdPropsT>,
192-
subscribers: &[Arc<dyn AgentSubscriber<StateT, FwdPropsT>>],
193-
) -> Result<Option<JsonValue>, AgentError> {
194-
// This is a simplified stand-in for the logic from `defaultApplyEvents` in TS.
195-
// A full implementation would handle each event type to create the correct state mutation.
196-
let (mutation, result) = match event {
197-
Event::RunFinished(e) => {
198-
for sub in subscribers {
199-
sub.on_run_finished(
200-
&e.result.clone().unwrap(),
201-
self.messages(),
202-
self.state(),
203-
input,
204-
)
205-
.await?;
206-
}
207-
(AgentStateMutation::default(), e.result)
208-
}
209-
// In a real implementation, other events like Text, ToolCall, etc.,
210-
// would create mutations to update messages and state.
211-
_ => (AgentStateMutation::default(), None),
212-
};
213-
214-
self.apply_mutation(mutation, input, subscribers).await?;
215-
Ok(result)
216-
}
217-
218-
async fn on_initialize(
219-
&mut self,
220-
input: &mut RunAgentInput<StateT, FwdPropsT>,
221-
subscribers: &[Arc<dyn AgentSubscriber<StateT, FwdPropsT>>],
222-
) -> Result<(), AgentError> {
223-
for subscriber in subscribers {
224-
let mutation = subscriber
225-
.on_run_initialized(self.messages(), self.state(), input)
226-
.await?;
227-
228-
if mutation.messages.is_some() || mutation.state.is_some() {
229-
if let Some(ref messages) = mutation.messages {
230-
input.messages = messages.clone();
231-
}
232-
if let Some(ref state) = mutation.state {
233-
input.state = state.clone();
234-
}
235-
self.apply_mutation(mutation, input, subscribers).await?;
236-
}
237-
}
238-
Ok(())
239-
}
240-
241-
async fn on_error(
242-
&mut self,
243-
input: &RunAgentInput<StateT, FwdPropsT>,
244-
error: &AgentError,
245-
subscribers: &[Arc<dyn AgentSubscriber<StateT, FwdPropsT>>],
246-
) -> Result<(), AgentError> {
247-
for subscriber in subscribers {
248-
let mutation = subscriber
249-
.on_run_failed(error, self.messages(), self.state(), input)
250-
.await?;
251-
252-
self.apply_mutation(mutation, input, subscribers).await?;
253-
}
254-
Ok(())
255-
}
256-
257-
async fn on_finalize(
258-
&mut self,
259-
input: &RunAgentInput<StateT, FwdPropsT>,
260-
subscribers: &[Arc<dyn AgentSubscriber<StateT, FwdPropsT>>],
261-
) -> Result<(), AgentError> {
262-
for subscriber in subscribers {
263-
let mutation = subscriber
264-
.on_run_finalized(self.messages(), self.state(), input)
265-
.await?;
266-
267-
self.apply_mutation(mutation, input, subscribers).await?;
268-
}
269-
Ok(())
270-
}
271-
272-
async fn apply_mutation(
273-
&mut self,
274-
mutation: AgentStateMutation<StateT>,
275-
input: &RunAgentInput<StateT, FwdPropsT>,
276-
subscribers: &[Arc<dyn AgentSubscriber<StateT, FwdPropsT>>],
277-
) -> Result<(), AgentError> {
278-
if let Some(messages) = mutation.messages {
279-
*self.messages_mut() = messages;
280-
self.notify_messages_changed(input, subscribers).await?;
281-
}
282-
283-
if let Some(state) = mutation.state {
284-
*self.state_mut() = state;
285-
self.notify_state_changed(input, subscribers).await?;
286-
}
287-
288-
Ok(())
289-
}
290-
291-
async fn notify_messages_changed(
292-
&self,
293-
input: &RunAgentInput<StateT, FwdPropsT>,
294-
subscribers: &[Arc<dyn AgentSubscriber<StateT, FwdPropsT>>],
295-
) -> Result<(), AgentError> {
296-
for subscriber in subscribers {
297-
subscriber
298-
.on_messages_changed(self.messages(), self.state(), input)
299-
.await?;
300-
}
301-
Ok(())
302-
}
303-
304-
async fn notify_state_changed(
305-
&self,
306-
input: &RunAgentInput<StateT, FwdPropsT>,
307-
subscribers: &[Arc<dyn AgentSubscriber<StateT, FwdPropsT>>],
308-
) -> Result<(), AgentError> {
309-
for subscriber in subscribers {
310-
subscriber
311-
.on_state_changed(self.messages(), self.state(), input)
312-
.await?;
313-
}
314-
Ok(())
315-
}
167+
// Helper function to run subscribers that can return a mutation
168+
// async fn run_subscribers_with_mutation<F, Fut>(
169+
// &self,
170+
// subscribers: &[Arc<dyn AgentSubscriber<StateT, FwdPropsT>>],
171+
// mut callback: F,
172+
// ) -> Result<AgentStateMutation<StateT>, AgentError>
173+
// where
174+
// F: FnMut(&Arc<dyn AgentSubscriber<StateT, FwdPropsT>>) -> Fut + Send,
175+
// Fut: std::future::Future<Output = Result<AgentStateMutation<StateT>, AgentError>>,
176+
// {
177+
// let mut result = AgentStateMutation::default();
178+
//
179+
// for subscriber in subscribers {
180+
// let mutation = callback(subscriber).await?;
181+
//
182+
// if mutation.messages.is_some() {
183+
// result.messages = mutation.messages;
184+
// }
185+
//
186+
// if mutation.state.is_some() {
187+
// result.state = mutation.state;
188+
// }
189+
//
190+
// if mutation.stop_propagation {
191+
// result.stop_propagation = true;
192+
// break;
193+
// }
194+
// }
195+
//
196+
// Ok(result)
197+
// }
316198
}

0 commit comments

Comments
 (0)