Skip to content

Commit c6eedb4

Browse files
authored
JSON RPC Constructors (#4)
* Introduce structured error handling Refactor error creation to use predefined `ErrorCode` constants. This allows for consistent error codes and messages, and provides a mechanism to include additional error details via the new `ErrorData` struct. Replaces hardcoded error numbers and messages with the `ErrorCode` system. * Refactor RPC to use custom `Error` type Replaces `anyhow::Result` with `Result<T, Error>` for Agent Connection Protocol (ACP) operations and schema definitions. This change enhances error handling by mapping failures to specific `ErrorCode` variants, improving diagnostic capabilities and type safety. The `response_from_any` trait method now returns `Result` for more precise error propagation. * Remove error code and simplify error creation * Add TS helpers
1 parent b4eb7c7 commit c6eedb4

File tree

3 files changed

+181
-74
lines changed

3 files changed

+181
-74
lines changed

rust/acp.rs

Lines changed: 87 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
mod acp_tests;
33
mod schema;
44

5-
use anyhow::{Result, anyhow};
65
use futures::{
76
AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
87
StreamExt as _,
@@ -42,7 +41,7 @@ impl AgentConnection {
4241
outgoing_bytes: impl Unpin + AsyncWrite,
4342
incoming_bytes: impl Unpin + AsyncRead,
4443
spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
45-
) -> (Self, impl Future<Output = Result<()>>) {
44+
) -> (Self, impl Future<Output = Result<(), Error>>) {
4645
let handler = Arc::new(handler);
4746
let (connection, io_task) = Connection::new(
4847
Box::new(move |request| {
@@ -60,17 +59,12 @@ impl AgentConnection {
6059
pub fn request<R: AgentRequest + 'static>(
6160
&self,
6261
params: R,
63-
) -> impl Future<Output = Result<R::Response>> {
62+
) -> impl Future<Output = Result<R::Response, Error>> {
6463
let params = params.into_any();
6564
let result = self.0.request(params.method_name(), params);
6665
async move {
6766
let result = result.await?;
68-
R::response_from_any(result).ok_or_else(|| {
69-
anyhow!(crate::Error {
70-
code: -32700,
71-
message: "Unexpected Response".to_string(),
72-
})
73-
})
67+
R::response_from_any(result)
7468
}
7569
}
7670
}
@@ -81,7 +75,7 @@ impl ClientConnection {
8175
outgoing_bytes: impl Unpin + AsyncWrite,
8276
incoming_bytes: impl Unpin + AsyncRead,
8377
spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
84-
) -> (Self, impl Future<Output = Result<()>>) {
78+
) -> (Self, impl Future<Output = Result<(), Error>>) {
8579
let handler = Arc::new(handler);
8680
let (connection, io_task) = Connection::new(
8781
Box::new(move |request| {
@@ -98,17 +92,12 @@ impl ClientConnection {
9892
pub fn request<R: ClientRequest>(
9993
&self,
10094
params: R,
101-
) -> impl use<R> + Future<Output = Result<R::Response>> {
95+
) -> impl use<R> + Future<Output = Result<R::Response, Error>> {
10296
let params = params.into_any();
10397
let result = self.0.request(params.method_name(), params);
10498
async move {
10599
let result = result.await?;
106-
R::response_from_any(result).ok_or_else(|| {
107-
anyhow!(Error {
108-
code: -32700,
109-
message: "Could not parse".to_string(),
110-
})
111-
})
100+
R::response_from_any(result)
112101
}
113102
}
114103
}
@@ -124,7 +113,7 @@ where
124113
}
125114

126115
type ResponseSenders<T> =
127-
Arc<Mutex<HashMap<i32, (&'static str, oneshot::Sender<Result<T, crate::Error>>)>>>;
116+
Arc<Mutex<HashMap<i32, (&'static str, oneshot::Sender<Result<T, Error>>)>>>;
128117

129118
#[derive(Debug, Deserialize)]
130119
struct IncomingMessage<'a> {
@@ -157,15 +146,76 @@ enum OutgoingMessage<Req, Resp> {
157146
pub struct Error {
158147
pub code: i32,
159148
pub message: String,
149+
#[serde(skip_serializing_if = "Option::is_none")]
150+
pub data: Option<ErrorData>,
151+
}
152+
153+
impl Error {
154+
pub fn new(code: i32, message: impl Into<String>) -> Self {
155+
Error {
156+
code,
157+
message: message.into(),
158+
data: None,
159+
}
160+
}
161+
162+
pub fn with_details(mut self, details: impl Into<String>) -> Self {
163+
self.data = Some(ErrorData::new(details));
164+
self
165+
}
166+
167+
/// Invalid JSON was received by the server. An error occurred on the server while parsing the JSON text.
168+
pub fn parse_error() -> Self {
169+
Error::new(-32700, "Parse error")
170+
}
171+
172+
/// The JSON sent is not a valid Request object.
173+
pub fn invalid_request() -> Self {
174+
Error::new(-32600, "Invalid Request")
175+
}
176+
177+
/// The method does not exist / is not available.
178+
pub fn method_not_found() -> Self {
179+
Error::new(-32601, "Method not found")
180+
}
181+
182+
/// Invalid method parameter(s).
183+
pub fn invalid_params() -> Self {
184+
Error::new(-32602, "Invalid params")
185+
}
186+
187+
/// Internal JSON-RPC error.
188+
pub fn internal_error() -> Self {
189+
Error::new(-32603, "Internal error")
190+
}
160191
}
161192

162193
impl std::error::Error for Error {}
163194
impl Display for Error {
164195
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165196
if self.message.is_empty() {
166-
write!(f, "{}", self.code)
197+
write!(f, "{}", self.code)?;
167198
} else {
168-
write!(f, "{}", self.message)
199+
write!(f, "{}", self.message)?;
200+
}
201+
202+
if let Some(data) = &self.data {
203+
write!(f, ": {}", data.details)?;
204+
}
205+
206+
Ok(())
207+
}
208+
}
209+
210+
#[derive(Debug, Clone, Serialize, Deserialize)]
211+
pub struct ErrorData {
212+
pub details: String,
213+
}
214+
215+
impl ErrorData {
216+
pub fn new(details: impl Into<String>) -> Self {
217+
ErrorData {
218+
details: details.into(),
169219
}
170220
}
171221
}
@@ -177,17 +227,20 @@ pub struct JsonRpcMessage<Req, Resp> {
177227
message: OutgoingMessage<Req, Resp>,
178228
}
179229

230+
type ResponseHandler<In, Resp> =
231+
Box<dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<Resp, Error>>>;
232+
180233
impl<In, Out> Connection<In, Out>
181234
where
182235
In: AnyRequest,
183236
Out: AnyRequest,
184237
{
185238
fn new(
186-
request_handler: Box<dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<In::Response>>>,
239+
request_handler: ResponseHandler<In, In::Response>,
187240
outgoing_bytes: impl Unpin + AsyncWrite,
188241
incoming_bytes: impl Unpin + AsyncRead,
189242
spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
190-
) -> (Self, impl Future<Output = Result<()>>) {
243+
) -> (Self, impl Future<Output = Result<(), Error>>) {
191244
let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
192245
let (incoming_tx, incoming_rx) = mpsc::unbounded();
193246
let this = Self {
@@ -210,7 +263,7 @@ where
210263
&self,
211264
method: &'static str,
212265
params: Out,
213-
) -> impl use<In, Out> + Future<Output = Result<Out::Response>> {
266+
) -> impl use<In, Out> + Future<Output = Result<Out::Response, Error>> {
214267
let (tx, rx) = oneshot::channel();
215268
let id = self.next_id.fetch_add(1, SeqCst);
216269
self.response_senders.lock().insert(id, (method, tx));
@@ -225,7 +278,10 @@ where
225278
{
226279
self.response_senders.lock().remove(&id);
227280
}
228-
async move { rx.await?.map_err(|e| anyhow!(e)) }
281+
async move {
282+
rx.await
283+
.map_err(|e| Error::internal_error().with_details(e.to_string()))?
284+
}
229285
}
230286

231287
async fn handle_io(
@@ -234,7 +290,7 @@ where
234290
response_senders: ResponseSenders<Out::Response>,
235291
mut outgoing_bytes: impl Unpin + AsyncWrite,
236292
incoming_bytes: impl Unpin + AsyncRead,
237-
) -> Result<()> {
293+
) -> Result<(), Error> {
238294
let mut output_reader = BufReader::new(incoming_bytes);
239295
let mut outgoing_line = Vec::new();
240296
let mut incoming_line = String::new();
@@ -243,7 +299,8 @@ where
243299
message = outgoing_rx.next() => {
244300
if let Some(message) = message {
245301
outgoing_line.clear();
246-
serde_json::to_writer(&mut outgoing_line, &message)?;
302+
serde_json::to_writer(&mut outgoing_line, &message).map_err(|e| Error::internal_error()
303+
.with_details(e.to_string()))?;
247304
log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
248305
outgoing_line.push(b'\n');
249306
outgoing_bytes.write_all(&outgoing_line).await.ok();
@@ -252,7 +309,7 @@ where
252309
}
253310
}
254311
bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
255-
if bytes_read? == 0 {
312+
if bytes_read.map_err(|e| Error::internal_error().with_details(e.to_string()))? == 0 {
256313
break
257314
}
258315
log::trace!("recv: {}", &incoming_line);
@@ -300,9 +357,7 @@ where
300357
fn handle_incoming(
301358
outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
302359
mut incoming_rx: UnboundedReceiver<(i32, In)>,
303-
incoming_handler: Box<
304-
dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<In::Response>>,
305-
>,
360+
incoming_handler: ResponseHandler<In, In::Response>,
306361
spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
307362
) {
308363
let spawn = Rc::new(spawn);
@@ -325,10 +380,8 @@ where
325380
outgoing_tx
326381
.unbounded_send(OutgoingMessage::ErrorResponse {
327382
id,
328-
error: Error {
329-
code: -32603,
330-
message: error.to_string(),
331-
},
383+
error: Error::internal_error()
384+
.with_details(error.to_string()),
332385
})
333386
.ok();
334387
}

0 commit comments

Comments
 (0)