@@ -8,10 +8,11 @@ use future::Future;
88use futures:: stream:: Filter ;
99use futures:: { Stream , stream:: StreamExt } ;
1010use pin_project:: pin_project;
11- use reqwest:: multipart:: Form ;
11+ use reqwest:: { Response , multipart:: Form } ;
1212use reqwest_eventsource:: { Event , EventSource , RequestBuilderExt } ;
1313use serde:: { Serialize , de:: DeserializeOwned } ;
1414
15+ use crate :: error:: { ApiError , StreamError } ;
1516use crate :: {
1617 Assistants , Audio , AuditLogs , Batches , Chat , Completions , Embeddings , FineTuning , Invites ,
1718 Models , Projects , Responses , Threads , Uploads , Users , VectorStores ,
@@ -315,27 +316,21 @@ impl<C: Config> Client<C> {
315316 . map_err ( OpenAIError :: Reqwest ) ?;
316317
317318 let status = response. status ( ) ;
318- let bytes = response. bytes ( ) . await . map_err ( OpenAIError :: Reqwest ) ?;
319-
320- // Deserialize response body from either error object or actual response object
321- if !status. is_success ( ) {
322- let wrapped_error: WrappedError = serde_json:: from_slice ( bytes. as_ref ( ) )
323- . map_err ( |e| map_deserialization_error ( e, bytes. as_ref ( ) ) ) ?;
324-
325- if status. as_u16 ( ) == 429
326- // API returns 429 also when:
327- // "You exceeded your current quota, please check your plan and billing details."
328- && wrapped_error. error . r#type != Some ( "insufficient_quota" . to_string ( ) )
329- {
330- // Rate limited retry...
331- tracing:: warn!( "Rate limited: {}" , wrapped_error. error. message) ;
332- return Err ( OpenAIError :: ApiError ( wrapped_error. error ) ) ;
333- } else {
334- return Err ( OpenAIError :: ApiError ( wrapped_error. error ) ) ;
335- }
319+ match read_response ( response) . await {
320+ Ok ( bytes) => Ok ( bytes) ,
321+ Err ( e) => match e {
322+ OpenAIError :: ApiError ( api_error) => {
323+ if status. as_u16 ( ) == 429
324+ && api_error. r#type != Some ( "insufficient_quota" . to_string ( ) )
325+ {
326+ // Rate limited retry...
327+ tracing:: warn!( "Rate limited: {}" , api_error. message) ;
328+ }
329+ Err ( OpenAIError :: ApiError ( api_error) )
330+ }
331+ _ => Err ( e) ,
332+ } ,
336333 }
337-
338- Ok ( bytes)
339334 }
340335
341336 /// Execute a HTTP request
@@ -514,7 +509,9 @@ where
514509 } ,
515510 Err ( e) => {
516511 * this. done = true ;
517- Poll :: Ready ( Some ( Err ( OpenAIError :: StreamError ( e. to_string ( ) ) ) ) )
512+ Poll :: Ready ( Some ( Err ( OpenAIError :: StreamError (
513+ StreamError :: ReqwestEventSource ( e) ,
514+ ) ) ) )
518515 }
519516 } ,
520517 }
@@ -524,50 +521,29 @@ where
524521 }
525522}
526523
527- // pub(crate) async fn stream_mapped_raw_events<O>(
528- // mut event_source: EventSource,
529- // event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
530- // ) -> Pin<Box<dyn Stream<Item=Result<O, OpenAIError>> + Send>>
531- // where
532- // O: DeserializeOwned + std::marker::Send + 'static,
533- // {
534- // let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
535- //
536- // tokio::spawn(async move {
537- // while let Some(ev) = event_source.next().await {
538- // match ev {
539- // Err(e) => {
540- // if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
541- // // rx dropped
542- // break;
543- // }
544- // }
545- // Ok(event) => match event {
546- // Event::Message(message) => {
547- // let mut done = false;
548- //
549- // if message.data == "[DONE]" {
550- // done = true;
551- // }
552- //
553- // let response = event_mapper(message);
554- //
555- // if let Err(_e) = tx.send(response) {
556- // // rx dropped
557- // break;
558- // }
559- //
560- // if done {
561- // break;
562- // }
563- // }
564- // Event::Open => continue,
565- // },
566- // }
567- // }
568- //
569- // event_source.close();
570- // });
571- //
572- // Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
573- // }
524+ async fn read_response ( response : Response ) -> Result < Bytes , OpenAIError > {
525+ let status = response. status ( ) ;
526+ let bytes = response. bytes ( ) . await . map_err ( OpenAIError :: Reqwest ) ?;
527+
528+ if status. is_server_error ( ) {
529+ // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
530+ let message: String = String :: from_utf8_lossy ( & bytes) . into_owned ( ) ;
531+ tracing:: warn!( "Server error: {status} - {message}" ) ;
532+ return Err ( OpenAIError :: ApiError ( ApiError {
533+ message,
534+ r#type : None ,
535+ param : None ,
536+ code : None ,
537+ } ) ) ;
538+ }
539+
540+ // Deserialize response body from either error object or actual response object
541+ if !status. is_success ( ) {
542+ let wrapped_error: WrappedError = serde_json:: from_slice ( bytes. as_ref ( ) )
543+ . map_err ( |e| map_deserialization_error ( e, bytes. as_ref ( ) ) ) ?;
544+
545+ return Err ( OpenAIError :: ApiError ( wrapped_error. error ) ) ;
546+ }
547+
548+ Ok ( bytes)
549+ }
0 commit comments