Skip to content

Commit 232de71

Browse files
64bitifsheldon
authored andcommitted
feat: better streaming errors (64bit#445)
* Unify StreamError and OpenAIError (64bit#413) * unify StreamError and OpenAIError * format * clippy * use underlying reqwest_eventsource::Error * UnknownEvent * update exampels to test streaming errors * update responses-stream example --------- Co-authored-by: Tinco Andringa <[email protected]> (cherry picked from commit 494a4a6) # Conflicts: # async-openai-wasm/src/types/assistant_stream.rs # async-openai/src/client.rs # examples/chat-stream/src/main.rs # examples/completions-stream/src/main.rs # examples/function-call-stream/src/main.rs # examples/responses-stream/src/main.rs # examples/tool-call-stream/src/main.rs
1 parent b47a57b commit 232de71

File tree

3 files changed

+58
-74
lines changed

3 files changed

+58
-74
lines changed

async-openai-wasm/src/client.rs

Lines changed: 45 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@ use future::Future;
88
use futures::stream::Filter;
99
use futures::{Stream, stream::StreamExt};
1010
use pin_project::pin_project;
11-
use reqwest::multipart::Form;
11+
use reqwest::{Response, multipart::Form};
1212
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
1313
use serde::{Serialize, de::DeserializeOwned};
1414

15+
use crate::error::{ApiError, StreamError};
1516
use crate::{
1617
Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
1718
Models, Projects, Responses, Threads, Uploads, Users, VectorStores,
@@ -315,27 +316,21 @@ impl<C: Config> Client<C> {
315316
.map_err(OpenAIError::Reqwest)?;
316317

317318
let status = response.status();
318-
let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
319-
320-
// Deserialize response body from either error object or actual response object
321-
if !status.is_success() {
322-
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
323-
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
324-
325-
if status.as_u16() == 429
326-
// API returns 429 also when:
327-
// "You exceeded your current quota, please check your plan and billing details."
328-
&& wrapped_error.error.r#type != Some("insufficient_quota".to_string())
329-
{
330-
// Rate limited retry...
331-
tracing::warn!("Rate limited: {}", wrapped_error.error.message);
332-
return Err(OpenAIError::ApiError(wrapped_error.error));
333-
} else {
334-
return Err(OpenAIError::ApiError(wrapped_error.error));
335-
}
319+
match read_response(response).await {
320+
Ok(bytes) => Ok(bytes),
321+
Err(e) => match e {
322+
OpenAIError::ApiError(api_error) => {
323+
if status.as_u16() == 429
324+
&& api_error.r#type != Some("insufficient_quota".to_string())
325+
{
326+
// Rate limited retry...
327+
tracing::warn!("Rate limited: {}", api_error.message);
328+
}
329+
Err(OpenAIError::ApiError(api_error))
330+
}
331+
_ => Err(e),
332+
},
336333
}
337-
338-
Ok(bytes)
339334
}
340335

341336
/// Execute a HTTP request
@@ -514,7 +509,9 @@ where
514509
},
515510
Err(e) => {
516511
*this.done = true;
517-
Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
512+
Poll::Ready(Some(Err(OpenAIError::StreamError(
513+
StreamError::ReqwestEventSource(e),
514+
))))
518515
}
519516
},
520517
}
@@ -524,50 +521,29 @@ where
524521
}
525522
}
526523

527-
// pub(crate) async fn stream_mapped_raw_events<O>(
528-
// mut event_source: EventSource,
529-
// event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
530-
// ) -> Pin<Box<dyn Stream<Item=Result<O, OpenAIError>> + Send>>
531-
// where
532-
// O: DeserializeOwned + std::marker::Send + 'static,
533-
// {
534-
// let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
535-
//
536-
// tokio::spawn(async move {
537-
// while let Some(ev) = event_source.next().await {
538-
// match ev {
539-
// Err(e) => {
540-
// if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
541-
// // rx dropped
542-
// break;
543-
// }
544-
// }
545-
// Ok(event) => match event {
546-
// Event::Message(message) => {
547-
// let mut done = false;
548-
//
549-
// if message.data == "[DONE]" {
550-
// done = true;
551-
// }
552-
//
553-
// let response = event_mapper(message);
554-
//
555-
// if let Err(_e) = tx.send(response) {
556-
// // rx dropped
557-
// break;
558-
// }
559-
//
560-
// if done {
561-
// break;
562-
// }
563-
// }
564-
// Event::Open => continue,
565-
// },
566-
// }
567-
// }
568-
//
569-
// event_source.close();
570-
// });
571-
//
572-
// Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
573-
// }
524+
async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
525+
let status = response.status();
526+
let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
527+
528+
if status.is_server_error() {
529+
// OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
530+
let message: String = String::from_utf8_lossy(&bytes).into_owned();
531+
tracing::warn!("Server error: {status} - {message}");
532+
return Err(OpenAIError::ApiError(ApiError {
533+
message,
534+
r#type: None,
535+
param: None,
536+
code: None,
537+
}));
538+
}
539+
540+
// Deserialize response body from either error object or actual response object
541+
if !status.is_success() {
542+
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
543+
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
544+
545+
return Err(OpenAIError::ApiError(wrapped_error.error));
546+
}
547+
548+
Ok(bytes)
549+
}

async-openai-wasm/src/error.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,23 @@ pub enum OpenAIError {
2020
FileReadError(String),
2121
/// Error on SSE streaming
2222
#[error("stream failed: {0}")]
23-
StreamError(String),
23+
StreamError(StreamError),
2424
/// Error from client side validation
2525
/// or when builder fails to build request before making API call
2626
#[error("invalid args: {0}")]
2727
InvalidArgument(String),
2828
}
2929

30+
#[derive(Debug, thiserror::Error)]
31+
pub enum StreamError {
32+
/// Underlying error from reqwest_eventsource library when reading the stream
33+
#[error("{0}")]
34+
ReqwestEventSource(#[from] reqwest_eventsource::Error),
35+
/// Error when a stream event does not match one of the expected values
36+
#[error("Unknown event: {0:#?}")]
37+
UnknownEvent(eventsource_stream::Event),
38+
}
39+
3040
/// OpenAI API returns error object on failure
3141
#[derive(Debug, Serialize, Deserialize, Clone)]
3242
pub struct ApiError {

async-openai-wasm/src/types/assistant_stream.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use serde::Deserialize;
22

33
use crate::client::OpenAIEventStream;
4-
use crate::error::{ApiError, OpenAIError, map_deserialization_error};
4+
use crate::error::{ApiError, OpenAIError, StreamError, map_deserialization_error};
55

66
use super::{
77
MessageDeltaObject, MessageObject, RunObject, RunStepDeltaObject, RunStepObject, ThreadObject,
@@ -203,9 +203,7 @@ impl TryFrom<eventsource_stream::Event> for AssistantStreamEvent {
203203
.map(AssistantStreamEvent::ErrorEvent),
204204
"done" => Ok(AssistantStreamEvent::Done(value.data)),
205205

206-
_ => Err(OpenAIError::StreamError(
207-
"Unrecognized event: {value:?#}".into(),
208-
)),
206+
_ => Err(OpenAIError::StreamError(StreamError::UnknownEvent(value))),
209207
}
210208
}
211209
}

0 commit comments

Comments
 (0)