Skip to content

Commit 494a4a6

Browse files
64bittinco
andauthored
feat: better streaming errors (#445)
* Unify StreamError and OpenAIError (#413) * unify StreamError and OpenAIError * format * clippy * use underlying reqwest_eventsource::Error * UnknownEvent * update exampels to test streaming errors * update responses-stream example --------- Co-authored-by: Tinco Andringa <[email protected]>
1 parent 3d0a137 commit 494a4a6

File tree

8 files changed

+92
-62
lines changed

8 files changed

+92
-62
lines changed

async-openai/src/client.rs

Lines changed: 68 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ use std::pin::Pin;
22

33
use bytes::Bytes;
44
use futures::{stream::StreamExt, Stream};
5-
use reqwest::multipart::Form;
6-
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
5+
use reqwest::{multipart::Form, Response};
6+
use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
77
use serde::{de::DeserializeOwned, Serialize};
88

99
use crate::{
1010
config::{Config, OpenAIConfig},
11-
error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
11+
error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
1212
file::Files,
1313
image::Images,
1414
moderation::Moderations,
@@ -335,52 +335,34 @@ impl<C: Config> Client<C> {
335335
.map_err(backoff::Error::Permanent)?;
336336

337337
let status = response.status();
338-
let bytes = response
339-
.bytes()
340-
.await
341-
.map_err(OpenAIError::Reqwest)
342-
.map_err(backoff::Error::Permanent)?;
343-
344-
if status.is_server_error() {
345-
// OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
346-
let message: String = String::from_utf8_lossy(&bytes).into_owned();
347-
tracing::warn!("Server error: {status} - {message}");
348-
return Err(backoff::Error::Transient {
349-
err: OpenAIError::ApiError(ApiError {
350-
message,
351-
r#type: None,
352-
param: None,
353-
code: None,
354-
}),
355-
retry_after: None,
356-
});
357-
}
358338

359-
// Deserialize response body from either error object or actual response object
360-
if !status.is_success() {
361-
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
362-
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
363-
.map_err(backoff::Error::Permanent)?;
364-
365-
if status.as_u16() == 429
366-
// API returns 429 also when:
367-
// "You exceeded your current quota, please check your plan and billing details."
368-
&& wrapped_error.error.r#type != Some("insufficient_quota".to_string())
369-
{
370-
// Rate limited retry...
371-
tracing::warn!("Rate limited: {}", wrapped_error.error.message);
372-
return Err(backoff::Error::Transient {
373-
err: OpenAIError::ApiError(wrapped_error.error),
374-
retry_after: None,
375-
});
376-
} else {
377-
return Err(backoff::Error::Permanent(OpenAIError::ApiError(
378-
wrapped_error.error,
379-
)));
339+
match read_response(response).await {
340+
Ok(bytes) => Ok(bytes),
341+
Err(e) => {
342+
match e {
343+
OpenAIError::ApiError(api_error) => {
344+
if status.is_server_error() {
345+
Err(backoff::Error::Transient {
346+
err: OpenAIError::ApiError(api_error),
347+
retry_after: None,
348+
})
349+
} else if status.as_u16() == 429
350+
&& api_error.r#type != Some("insufficient_quota".to_string())
351+
{
352+
// Rate limited retry...
353+
tracing::warn!("Rate limited: {}", api_error.message);
354+
Err(backoff::Error::Transient {
355+
err: OpenAIError::ApiError(api_error),
356+
retry_after: None,
357+
})
358+
} else {
359+
Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
360+
}
361+
}
362+
_ => Err(backoff::Error::Permanent(e)),
363+
}
380364
}
381365
}
382-
383-
Ok(bytes)
384366
})
385367
.await
386368
}
@@ -471,6 +453,44 @@ impl<C: Config> Client<C> {
471453
}
472454
}
473455

456+
async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
457+
let status = response.status();
458+
let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
459+
460+
if status.is_server_error() {
461+
// OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
462+
let message: String = String::from_utf8_lossy(&bytes).into_owned();
463+
tracing::warn!("Server error: {status} - {message}");
464+
return Err(OpenAIError::ApiError(ApiError {
465+
message,
466+
r#type: None,
467+
param: None,
468+
code: None,
469+
}));
470+
}
471+
472+
// Deserialize response body from either error object or actual response object
473+
if !status.is_success() {
474+
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
475+
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
476+
477+
return Err(OpenAIError::ApiError(wrapped_error.error));
478+
}
479+
480+
Ok(bytes)
481+
}
482+
483+
async fn map_stream_error(value: EventSourceError) -> OpenAIError {
484+
match value {
485+
EventSourceError::InvalidStatusCode(status_code, response) => {
486+
read_response(response).await.expect_err(&format!(
487+
"Unreachable because read_response returns err when status_code {status_code} is invalid"
488+
))
489+
}
490+
_ => OpenAIError::StreamError(StreamError::ReqwestEventSource(value.into())),
491+
}
492+
}
493+
474494
/// Request which responds with SSE.
475495
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
476496
pub(crate) async fn stream<O>(
@@ -485,7 +505,7 @@ where
485505
while let Some(ev) = event_source.next().await {
486506
match ev {
487507
Err(e) => {
488-
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
508+
if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
489509
// rx dropped
490510
break;
491511
}
@@ -530,7 +550,7 @@ where
530550
while let Some(ev) = event_source.next().await {
531551
match ev {
532552
Err(e) => {
533-
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
553+
if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
534554
// rx dropped
535555
break;
536556
}

async-openai/src/error.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
//! Errors originating from API calls, parsing responses, and reading-or-writing to the file system.
2+
use std::string::FromUtf8Error;
3+
4+
use reqwest::{header::HeaderValue, Response};
25
use serde::{Deserialize, Serialize};
36

47
#[derive(Debug, thiserror::Error)]
@@ -20,13 +23,23 @@ pub enum OpenAIError {
2023
FileReadError(String),
2124
/// Error on SSE streaming
2225
#[error("stream failed: {0}")]
23-
StreamError(String),
26+
StreamError(StreamError),
2427
/// Error from client side validation
2528
/// or when builder fails to build request before making API call
2629
#[error("invalid args: {0}")]
2730
InvalidArgument(String),
2831
}
2932

33+
#[derive(Debug, thiserror::Error)]
34+
pub enum StreamError {
35+
/// Underlying error from reqwest_eventsource library when reading the stream
36+
#[error("{0}")]
37+
ReqwestEventSource(#[from] reqwest_eventsource::Error),
38+
/// Error when a stream event does not match one of the expected values
39+
#[error("Unknown event: {0:#?}")]
40+
UnknownEvent(eventsource_stream::Event),
41+
}
42+
3043
/// OpenAI API returns error object on failure
3144
#[derive(Debug, Serialize, Deserialize, Clone)]
3245
pub struct ApiError {

async-openai/src/types/assistant_stream.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::pin::Pin;
33
use futures::Stream;
44
use serde::Deserialize;
55

6-
use crate::error::{map_deserialization_error, ApiError, OpenAIError};
6+
use crate::error::{map_deserialization_error, ApiError, OpenAIError, StreamError};
77

88
use super::{
99
MessageDeltaObject, MessageObject, RunObject, RunStepDeltaObject, RunStepObject, ThreadObject,
@@ -207,9 +207,7 @@ impl TryFrom<eventsource_stream::Event> for AssistantStreamEvent {
207207
.map(AssistantStreamEvent::ErrorEvent),
208208
"done" => Ok(AssistantStreamEvent::Done(value.data)),
209209

210-
_ => Err(OpenAIError::StreamError(
211-
"Unrecognized event: {value:?#}".into(),
212-
)),
210+
_ => Err(OpenAIError::StreamError(StreamError::UnknownEvent(value))),
213211
}
214212
}
215213
}

examples/chat-stream/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
4040
});
4141
}
4242
Err(err) => {
43-
writeln!(lock, "error: {err}").unwrap();
43+
writeln!(lock, "error: {err:?}").unwrap();
4444
}
4545
}
4646
stdout().flush()?;

examples/completions-stream/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
2020
Ok(ccr) => ccr.choices.iter().for_each(|c| {
2121
print!("{}", c.text);
2222
}),
23-
Err(e) => eprintln!("{}", e),
23+
Err(e) => eprintln!("{e:?}"),
2424
}
2525
}
2626

examples/function-call-stream/src/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
7474
}
7575
}
7676
Err(err) => {
77-
writeln!(lock, "error: {err}").unwrap();
77+
writeln!(lock, "error: {err:?}").unwrap();
7878
}
7979
}
8080
stdout().flush()?;
@@ -132,7 +132,7 @@ async fn call_fn(
132132
});
133133
}
134134
Err(err) => {
135-
writeln!(lock, "error: {err}").unwrap();
135+
writeln!(lock, "error: {err:?}").unwrap();
136136
}
137137
}
138138
stdout().flush()?;

examples/responses-stream/src/main.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
3636
| ResponseEvent::ResponseFailed(_) => {
3737
break;
3838
}
39-
_ => { println!("{response_event:#?}"); }
39+
_ => {
40+
println!("{response_event:#?}");
41+
}
4042
},
4143
Err(e) => {
4244
eprintln!("{e:#?}");
43-
// When a stream ends, it returns Err(OpenAIError::StreamError("Stream ended"))
44-
// Without this, the stream will never end
45-
break;
4645
}
4746
}
4847
}

examples/tool-call-stream/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
206206
}
207207
Err(err) => {
208208
let mut lock = stdout().lock();
209-
writeln!(lock, "error: {err}").unwrap();
209+
writeln!(lock, "error: {err:?}").unwrap();
210210
}
211211
}
212212
stdout()

0 commit comments

Comments
 (0)