diff --git a/async-openai/Cargo.toml b/async-openai/Cargo.toml index 99d8da8b..c08626cf 100644 --- a/async-openai/Cargo.toml +++ b/async-openai/Cargo.toml @@ -28,6 +28,7 @@ byot = [] [dependencies] async-openai-macros = { path = "../async-openai-macros", version = "0.1.0" } +async-trait = "0.1" backoff = { version = "0.4.0", features = ["tokio"] } base64 = "0.22.1" futures = "0.3.31" @@ -40,7 +41,9 @@ reqwest = { version = "0.12.12", features = [ reqwest-eventsource = "0.6.0" serde = { version = "1.0.217", features = ["derive", "rc"] } serde_json = "1.0.135" +serde_urlencoded = "0.7" thiserror = "2.0.11" +uuid = { version = "1.11", features = ["v4"] } tokio = { version = "1.43.0", features = ["fs", "macros"] } tokio-stream = "0.1.17" tokio-util = { version = "0.7.13", features = ["codec", "io-util"] } diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs index fe2ed232..38bdabba 100644 --- a/async-openai/src/client.rs +++ b/async-openai/src/client.rs @@ -1,8 +1,10 @@ use std::pin::Pin; +use std::sync::Arc; use bytes::Bytes; use futures::{stream::StreamExt, Stream}; use reqwest::multipart::Form; +use reqwest::{Method, Url}; use reqwest_eventsource::{Event, EventSource, RequestBuilderExt}; use serde::{de::DeserializeOwned, Serialize}; @@ -10,6 +12,7 @@ use crate::{ config::{Config, OpenAIConfig}, error::{map_deserialization_error, ApiError, OpenAIError, WrappedError}, file::Files, + http_client::{BoxedHttpClient, HttpClient, MultipartForm, SseEvent}, image::Images, moderation::Moderations, traits::AsyncTryFrom, @@ -17,11 +20,11 @@ use crate::{ Models, Projects, Responses, Threads, Uploads, Users, VectorStores, }; -#[derive(Debug, Clone, Default)] +#[derive(Clone)] /// Client is a container for config, backoff and http_client /// used to make API calls. pub struct Client { - http_client: reqwest::Client, + http_client: BoxedHttpClient, config: C, backoff: backoff::ExponentialBackoff, } @@ -29,19 +32,19 @@ pub struct Client { impl Client { /// Client with default [OpenAIConfig] pub fn new() -> Self { - Self::default() + Self::with_config(OpenAIConfig::default()) } } impl Client { /// Create client with a custom HTTP client, OpenAI config, and backoff. pub fn build( - http_client: reqwest::Client, + http_client: impl HttpClient + 'static, config: C, backoff: backoff::ExponentialBackoff, ) -> Self { Self { - http_client, + http_client: Arc::new(http_client), config, backoff, } @@ -50,17 +53,16 @@ impl Client { /// Create client with [OpenAIConfig] or [crate::config::AzureConfig] pub fn with_config(config: C) -> Self { Self { - http_client: reqwest::Client::new(), + http_client: Arc::new(reqwest::Client::new()), config, backoff: Default::default(), } } - /// Provide your own [client] to make HTTP requests with. - /// - /// [client]: reqwest::Client - pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self { - self.http_client = http_client; + /// Provide your own HTTP client implementation to make requests with. + /// This can be reqwest::Client, ClientWithMiddleware, or any custom implementation. + pub fn with_http_client(mut self, http_client: impl HttpClient + 'static) -> Self { + self.http_client = Arc::new(http_client); self } @@ -176,16 +178,10 @@ impl Client { where O: DeserializeOwned, { - let request_maker = || async { - Ok(self - .http_client - .get(self.config.url(path)) - .query(&self.config.query()) - .headers(self.config.headers()) - .build()?) - }; - - self.execute(request_maker).await + let bytes = self.execute_with_body(Method::GET, path, None).await?; + let response: O = serde_json::from_slice(bytes.as_ref()) + .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?; + Ok(response) } /// Make a GET request to {path} with given Query and deserialize the response body @@ -194,17 +190,24 @@ impl Client { O: DeserializeOwned, Q: Serialize + ?Sized, { - let request_maker = || async { - Ok(self - .http_client - .get(self.config.url(path)) - .query(&self.config.query()) - .query(query) - .headers(self.config.headers()) - .build()?) + // Build path with additional query parameters + let query_string = serde_urlencoded::to_string(query).map_err(|e| { + OpenAIError::InvalidArgument(format!("Failed to serialize query: {}", e)) + })?; + let path_with_query = if query_string.is_empty() { + path.to_string() + } else if path.contains('?') { + format!("{}&{}", path, query_string) + } else { + format!("{}?{}", path, query_string) }; - self.execute(request_maker).await + let bytes = self + .execute_with_body(Method::GET, &path_with_query, None) + .await?; + let response: O = serde_json::from_slice(bytes.as_ref()) + .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?; + Ok(response) } /// Make a DELETE request to {path} and deserialize the response body @@ -212,30 +215,15 @@ impl Client { where O: DeserializeOwned, { - let request_maker = || async { - Ok(self - .http_client - .delete(self.config.url(path)) - .query(&self.config.query()) - .headers(self.config.headers()) - .build()?) - }; - - self.execute(request_maker).await + let bytes = self.execute_with_body(Method::DELETE, path, None).await?; + let response: O = serde_json::from_slice(bytes.as_ref()) + .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?; + Ok(response) } /// Make a GET request to {path} and return the response body pub(crate) async fn get_raw(&self, path: &str) -> Result { - let request_maker = || async { - Ok(self - .http_client - .get(self.config.url(path)) - .query(&self.config.query()) - .headers(self.config.headers()) - .build()?) - }; - - self.execute_raw(request_maker).await + self.execute_with_body(Method::GET, path, None).await } /// Make a POST request to {path} and return the response body @@ -243,17 +231,11 @@ impl Client { where I: Serialize, { - let request_maker = || async { - Ok(self - .http_client - .post(self.config.url(path)) - .query(&self.config.query()) - .headers(self.config.headers()) - .json(&request) - .build()?) - }; - - self.execute_raw(request_maker).await + let body = serde_json::to_vec(&request).map_err(|e| { + OpenAIError::InvalidArgument(format!("Failed to serialize request: {}", e)) + })?; + self.execute_with_body(Method::POST, path, Some(body.into())) + .await } /// Make a POST request to {path} and deserialize the response body @@ -262,17 +244,18 @@ impl Client { I: Serialize, O: DeserializeOwned, { - let request_maker = || async { - Ok(self - .http_client - .post(self.config.url(path)) - .query(&self.config.query()) - .headers(self.config.headers()) - .json(&request) - .build()?) - }; + let body = serde_json::to_vec(&request).map_err(|e| { + OpenAIError::InvalidArgument(format!("Failed to serialize request: {}", e)) + })?; + + let bytes = self + .execute_with_body(Method::POST, path, Some(body.into())) + .await?; + + let response: O = serde_json::from_slice(bytes.as_ref()) + .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?; - self.execute(request_maker).await + Ok(response) } /// POST a form at {path} and return the response body @@ -281,17 +264,76 @@ impl Client { Form: AsyncTryFrom, F: Clone, { - let request_maker = || async { - Ok(self - .http_client - .post(self.config.url(path)) - .query(&self.config.query()) - .headers(self.config.headers()) - .multipart(
>::try_from(form.clone()).await?) - .build()?) - }; + // Convert the form to our MultipartForm + let reqwest_form = >::try_from(form).await?; + let multipart = MultipartForm::from_reqwest_form(reqwest_form) + .await + .map_err(|e| OpenAIError::HttpClient(e.to_string()))?; + + // Build URL with query parameters + let url = self.config.url(path); + let mut parsed_url = Url::parse(&url) + .map_err(|e| OpenAIError::InvalidArgument(format!("Invalid URL: {}", e)))?; + for (key, value) in self.config.query() { + parsed_url.query_pairs_mut().append_pair(key, value); + } + + let client = self.http_client.clone(); + let headers = self.config.headers(); + + // Execute with backoff retry + backoff::future::retry(self.backoff.clone(), || async { + let response = client + .request_multipart( + Method::POST, + parsed_url.clone(), + headers.clone(), + multipart.clone(), + ) + .await + .map_err(|e| OpenAIError::HttpClient(e.to_string())) + .map_err(backoff::Error::Permanent)?; + + let status = response.status; + let bytes = response.body; - self.execute_raw(request_maker).await + if status.is_server_error() { + let message: String = String::from_utf8_lossy(&bytes).into_owned(); + tracing::warn!("Server error: {status} - {message}"); + return Err(backoff::Error::Transient { + err: OpenAIError::ApiError(ApiError { + message, + r#type: None, + param: None, + code: None, + }), + retry_after: None, + }); + } + + if !status.is_success() { + let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref()) + .map_err(|e| map_deserialization_error(e, bytes.as_ref())) + .map_err(backoff::Error::Permanent)?; + + if status.as_u16() == 429 + && wrapped_error.error.r#type != Some("insufficient_quota".to_string()) + { + tracing::warn!("Rate limited: {}", wrapped_error.error.message); + return Err(backoff::Error::Transient { + err: OpenAIError::ApiError(wrapped_error.error), + retry_after: None, + }); + } else { + return Err(backoff::Error::Permanent(OpenAIError::ApiError( + wrapped_error.error, + ))); + } + } + + Ok(bytes) + }) + .await } /// POST a form at {path} and deserialize the response body @@ -301,45 +343,44 @@ impl Client { Form: AsyncTryFrom, F: Clone, { - let request_maker = || async { - Ok(self - .http_client - .post(self.config.url(path)) - .query(&self.config.query()) - .headers(self.config.headers()) - .multipart(>::try_from(form.clone()).await?) - .build()?) - }; - - self.execute(request_maker).await + let bytes = self.post_form_raw(path, form).await?; + let response: O = serde_json::from_slice(bytes.as_ref()) + .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?; + Ok(response) } - /// Execute a HTTP request and retry on rate limit - /// - /// request_maker serves one purpose: to be able to create request again - /// to retry API call after getting rate limited. request_maker is async because - /// reqwest::multipart::Form is created by async calls to read files for uploads. - async fn execute_raw(&self, request_maker: M) -> Result - where - M: Fn() -> Fut, - Fut: core::future::Future>, - { + /// Execute an HTTP request with the HttpClient trait + async fn execute_with_body( + &self, + method: Method, + path: &str, + body: Option, + ) -> Result { let client = self.http_client.clone(); + let url = self.config.url(path); + let headers = self.config.headers(); + + // Build URL with query parameters + let mut parsed_url = Url::parse(&url) + .map_err(|e| OpenAIError::InvalidArgument(format!("Invalid URL: {}", e)))?; + for (key, value) in self.config.query() { + parsed_url.query_pairs_mut().append_pair(key, value); + } backoff::future::retry(self.backoff.clone(), || async { - let request = request_maker().await.map_err(backoff::Error::Permanent)?; let response = client - .execute(request) + .request( + method.clone(), + parsed_url.clone(), + headers.clone(), + body.clone(), + ) .await - .map_err(OpenAIError::Reqwest) + .map_err(|e| OpenAIError::HttpClient(e.to_string())) .map_err(backoff::Error::Permanent)?; - let status = response.status(); - let bytes = response - .bytes() - .await - .map_err(OpenAIError::Reqwest) - .map_err(backoff::Error::Permanent)?; + let status = response.status; + let bytes = response.body; if status.is_server_error() { // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them. @@ -385,25 +426,6 @@ impl Client { .await } - /// Execute a HTTP request and retry on rate limit - /// - /// request_maker serves one purpose: to be able to create request again - /// to retry API call after getting rate limited. request_maker is async because - /// reqwest::multipart::Form is created by async calls to read files for uploads. - async fn execute(&self, request_maker: M) -> Result - where - O: DeserializeOwned, - M: Fn() -> Fut, - Fut: core::future::Future>, - { - let bytes = self.execute_raw(request_maker).await?; - - let response: O = serde_json::from_slice(bytes.as_ref()) - .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?; - - Ok(response) - } - /// Make HTTP POST request to receive SSE pub(crate) async fn post_stream( &self, @@ -414,16 +436,35 @@ impl Client { I: Serialize, O: DeserializeOwned + std::marker::Send + 'static, { - let event_source = self + // Build URL with query parameters + let url = self.config.url(path); + let mut parsed_url = Url::parse(&url) + .map_err(|e| OpenAIError::InvalidArgument(format!("Invalid URL: {}", e))) + .unwrap(); // TODO: handle error properly + for (key, value) in self.config.query() { + parsed_url.query_pairs_mut().append_pair(key, value); + } + + // Serialize request body + let body = serde_json::to_vec(&request) + .map_err(|e| { + OpenAIError::InvalidArgument(format!("Failed to serialize request: {}", e)) + }) + .unwrap(); // TODO: handle error properly + + let event_stream = self .http_client - .post(self.config.url(path)) - .query(&self.config.query()) - .headers(self.config.headers()) - .json(&request) - .eventsource() - .unwrap(); + .request_stream( + Method::POST, + parsed_url, + self.config.headers(), + Some(body.into()), + ) + .await + .map_err(|e| OpenAIError::HttpClient(e.to_string())) + .unwrap(); // TODO: handle error properly - stream(event_source).await + stream_from_sse(event_stream).await } pub(crate) async fn post_stream_mapped_raw_events( @@ -436,8 +477,10 @@ impl Client { I: Serialize, O: DeserializeOwned + std::marker::Send + 'static, { - let event_source = self - .http_client + // For now, keep using reqwest for the mapped events since it needs eventsource_stream::Event + // TODO: Update event_mapper to use our SseEvent type + let client = reqwest::Client::new(); + let event_source = client .post(self.config.url(path)) .query(&self.config.query()) .headers(self.config.headers()) @@ -458,60 +501,89 @@ impl Client { Q: Serialize + ?Sized, O: DeserializeOwned + std::marker::Send + 'static, { - let event_source = self + // Build URL with query parameters + let url = self.config.url(path); + let mut parsed_url = Url::parse(&url) + .map_err(|e| OpenAIError::InvalidArgument(format!("Invalid URL: {}", e))) + .unwrap(); // TODO: handle error properly + + // Add custom query + let query_string = serde_urlencoded::to_string(query) + .map_err(|e| OpenAIError::InvalidArgument(format!("Failed to serialize query: {}", e))) + .unwrap(); // TODO: handle error properly + if !query_string.is_empty() { + parsed_url.set_query(Some(&query_string)); + } + + // Add config query + for (key, value) in self.config.query() { + parsed_url.query_pairs_mut().append_pair(key, value); + } + + let event_stream = self .http_client - .get(self.config.url(path)) - .query(query) - .query(&self.config.query()) - .headers(self.config.headers()) - .eventsource() - .unwrap(); + .request_stream(Method::GET, parsed_url, self.config.headers(), None) + .await + .map_err(|e| OpenAIError::HttpClient(e.to_string())) + .unwrap(); // TODO: handle error properly - stream(event_source).await + stream_from_sse(event_stream).await } } -/// Request which responds with SSE. -/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format) -pub(crate) async fn stream( - mut event_source: EventSource, +/// Convert our SSE stream to OpenAI response stream +pub(crate) async fn stream_from_sse( + mut event_stream: Pin< + Box> + Send>, + >, ) -> Pin> + Send>> where O: DeserializeOwned + std::marker::Send + 'static, { + use futures::StreamExt; + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); tokio::spawn(async move { - while let Some(ev) = event_source.next().await { - match ev { + while let Some(event_result) = event_stream.next().await { + match event_result { Err(e) => { - if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) { + if let Err(_e) = tx.send(Err(OpenAIError::HttpClient(e.to_string()))) { // rx dropped break; } } - Ok(event) => match event { - Event::Message(message) => { - if message.data == "[DONE]" { - break; - } + Ok(event) => { + // Check for [DONE] message + if event.data == "[DONE]" { + break; + } - let response = match serde_json::from_str::(&message.data) { - Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())), - Ok(output) => Ok(output), - }; + // Skip open events + if event.event.as_deref() == Some("open") { + continue; + } - if let Err(_e) = tx.send(response) { - // rx dropped - break; + // Try to parse the data as JSON + if !event.data.is_empty() { + match serde_json::from_str::(&event.data) { + Ok(obj) => { + if let Err(_e) = tx.send(Ok(obj)) { + // rx dropped + break; + } + } + Err(e) => { + if let Err(_e) = tx.send(Err(OpenAIError::JSONDeserialize(e))) { + // rx dropped + break; + } + } } } - Event::Open => continue, - }, + } } } - - event_source.close(); }); Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)) diff --git a/async-openai/src/error.rs b/async-openai/src/error.rs index a1139c9f..3d25ccd4 100644 --- a/async-openai/src/error.rs +++ b/async-openai/src/error.rs @@ -6,6 +6,9 @@ pub enum OpenAIError { /// Underlying error from reqwest library after an API call was made #[error("http error: {0}")] Reqwest(#[from] reqwest::Error), + /// Error from HttpClient trait implementation + #[error("http client error: {0}")] + HttpClient(String), /// OpenAI returns error object with details of API call failure #[error("{0}")] ApiError(ApiError), diff --git a/async-openai/src/http_client.rs b/async-openai/src/http_client.rs new file mode 100644 index 00000000..33dfc0a3 --- /dev/null +++ b/async-openai/src/http_client.rs @@ -0,0 +1,241 @@ +/// HTTP client abstraction trait for async-openai +/// This allows using any HTTP client implementation, including those with middleware +use async_trait::async_trait; +use bytes::Bytes; +use futures::Stream; +use reqwest::{header::HeaderMap, Method, StatusCode, Url}; +use std::error::Error as StdError; +use std::fmt; +use std::pin::Pin; +use std::sync::Arc; + +/// Error type for HTTP operations +#[derive(Debug)] +pub struct HttpError { + pub message: String, + pub status: Option, +} + +impl fmt::Display for HttpError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.status { + Some(status) => write!(f, "HTTP {}: {}", status, self.message), + None => write!(f, "{}", self.message), + } + } +} + +impl StdError for HttpError {} + +impl From for HttpError { + fn from(err: reqwest::Error) -> Self { + HttpError { + message: err.to_string(), + status: err.status(), + } + } +} + +/// Response from HTTP client +pub struct HttpResponse { + pub status: StatusCode, + pub headers: HeaderMap, + pub body: Bytes, +} + +/// Multipart form data for file uploads +#[derive(Clone)] +pub struct MultipartForm { + // Store the form as bytes after encoding + pub boundary: String, + pub body: Bytes, +} + +impl MultipartForm { + /// Convert a reqwest multipart form to our MultipartForm + /// This is a temporary helper until we have a better abstraction + pub async fn from_reqwest_form(form: reqwest::multipart::Form) -> Result { + use uuid::Uuid; + + // Generate a unique boundary + let boundary = format!("----FormBoundary{}", Uuid::new_v4().simple()); + + // Create a client to serialize the form + // This is a hack but reqwest doesn't expose form serialization directly + let client = reqwest::Client::new(); + let request = client + .post("http://localhost/dummy") // Dummy URL, we won't send this + .multipart(form) + .build() + .map_err(|e| HttpError { + message: format!("Failed to build multipart request: {}", e), + status: None, + })?; + + // Extract the body bytes + let body = request + .body() + .and_then(|b| b.as_bytes()) + .ok_or_else(|| HttpError { + message: "Failed to get multipart body bytes".to_string(), + status: None, + })?; + + Ok(MultipartForm { + boundary, + body: Bytes::copy_from_slice(body), + }) + } +} + +/// Server-sent event for streaming +#[derive(Debug, Clone)] +pub struct SseEvent { + pub data: String, + pub event: Option, + pub id: Option, + pub retry: Option, +} + +/// Trait for HTTP clients +/// This abstraction allows using reqwest::Client, ClientWithMiddleware, or any custom implementation +#[async_trait] +pub trait HttpClient: Send + Sync { + /// Send an HTTP request + async fn request( + &self, + method: Method, + url: Url, + headers: HeaderMap, + body: Option, + ) -> Result; + + /// Send a multipart form request + async fn request_multipart( + &self, + method: Method, + url: Url, + headers: HeaderMap, + form: MultipartForm, + ) -> Result; + + /// Send a request and receive Server-Sent Events stream + async fn request_stream( + &self, + method: Method, + url: Url, + headers: HeaderMap, + body: Option, + ) -> Result> + Send>>, HttpError>; +} + +/// Type alias for boxed HTTP client +pub type BoxedHttpClient = Arc; + +/// Implementation for standard reqwest::Client +#[async_trait] +impl HttpClient for reqwest::Client { + async fn request( + &self, + method: Method, + url: Url, + headers: HeaderMap, + body: Option, + ) -> Result { + let mut request = self.request(method, url).headers(headers); + + if let Some(body) = body { + request = request.body(body); + } + + let response = request.send().await?; + + let status = response.status(); + let headers = response.headers().clone(); + let body = response.bytes().await?; + + Ok(HttpResponse { + status, + headers, + body, + }) + } + + async fn request_multipart( + &self, + method: Method, + url: Url, + mut headers: HeaderMap, + form: MultipartForm, + ) -> Result { + use reqwest::header::{HeaderValue, CONTENT_TYPE}; + + // Set the multipart boundary in content-type header + let content_type = format!("multipart/form-data; boundary={}", form.boundary); + headers.insert( + CONTENT_TYPE, + HeaderValue::from_str(&content_type).map_err(|e| HttpError { + message: format!("Invalid content type: {}", e), + status: None, + })?, + ); + + let request = self.request(method, url).headers(headers).body(form.body); + + let response = request.send().await?; + + let status = response.status(); + let headers = response.headers().clone(); + let body = response.bytes().await?; + + Ok(HttpResponse { + status, + headers, + body, + }) + } + + async fn request_stream( + &self, + method: Method, + url: Url, + headers: HeaderMap, + body: Option, + ) -> Result> + Send>>, HttpError> { + use futures::StreamExt; + use reqwest_eventsource::{Event, RequestBuilderExt}; + + let mut request = self.request(method, url).headers(headers); + + if let Some(body) = body { + request = request.body(body); + } + + let event_source = request.eventsource().map_err(|e| HttpError { + message: format!("Failed to create event source: {}", e), + status: None, + })?; + + // Convert reqwest EventSource to our SseEvent stream + let stream = event_source.map(move |event| match event { + Ok(Event::Message(msg)) => Ok(SseEvent { + data: msg.data, + event: Some(msg.event), + id: Some(msg.id), + retry: msg.retry.map(|d| d.as_millis() as u64), + }), + Ok(Event::Open) => Ok(SseEvent { + data: String::new(), + event: Some("open".to_string()), + id: None, + retry: None, + }), + Err(e) => Err(HttpError { + message: format!("Stream error: {}", e), + status: None, + }), + }); + + Ok(Box::pin(stream)) + } +} diff --git a/async-openai/src/lib.rs b/async-openai/src/lib.rs index c94bc495..a9c03e86 100644 --- a/async-openai/src/lib.rs +++ b/async-openai/src/lib.rs @@ -153,6 +153,7 @@ mod embedding; pub mod error; mod file; mod fine_tuning; +pub mod http_client; mod image; mod invites; mod messages; diff --git a/examples/responses-stream/src/main.rs b/examples/responses-stream/src/main.rs index 5b565cd8..dafb2945 100644 --- a/examples/responses-stream/src/main.rs +++ b/examples/responses-stream/src/main.rs @@ -36,7 +36,9 @@ async fn main() -> Result<(), Box> { | ResponseEvent::ResponseFailed(_) => { break; } - _ => { println!("{response_event:#?}"); } + _ => { + println!("{response_event:#?}"); + } }, Err(e) => { eprintln!("{e:#?}");