Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 61 additions & 149 deletions async-openai/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,35 @@ impl<C: Config> Client<C> {
&self.config
}

/// Helper function to build a request builder with common configuration
fn build_request_builder(
&self,
method: reqwest::Method,
path: &str,
request_options: &RequestOptions,
) -> reqwest::RequestBuilder {
let mut request_builder = if let Some(path) = request_options.path() {
self.http_client
.request(method, self.config.url(path.as_str()))
} else {
self.http_client.request(method, self.config.url(path))
};

request_builder = request_builder
.query(&self.config.query())
.headers(self.config.headers());

if let Some(headers) = request_options.headers() {
request_builder = request_builder.headers(headers.clone());
}

if !request_options.query().is_empty() {
request_builder = request_builder.query(request_options.query());
}

request_builder
}

/// Make a GET request to {path} and deserialize the response body
pub(crate) async fn get<O>(
&self,
Expand All @@ -207,21 +236,9 @@ impl<C: Config> Client<C> {
O: DeserializeOwned,
{
let request_maker = || async {
let mut request_builder = self
.http_client
.get(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers());

if let Some(headers) = request_options.headers() {
request_builder = request_builder.headers(headers.clone());
}

if !request_options.query().is_empty() {
request_builder = request_builder.query(request_options.query());
}

Ok(request_builder.build()?)
Ok(self
.build_request_builder(reqwest::Method::GET, path, request_options)
.build()?)
};

self.execute(request_maker).await
Expand All @@ -237,21 +254,9 @@ impl<C: Config> Client<C> {
O: DeserializeOwned,
{
let request_maker = || async {
let mut request_builder = self
.http_client
.delete(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers());

if let Some(headers) = request_options.headers() {
request_builder = request_builder.headers(headers.clone());
}

if !request_options.query().is_empty() {
request_builder = request_builder.query(request_options.query());
}

Ok(request_builder.build()?)
Ok(self
.build_request_builder(reqwest::Method::DELETE, path, request_options)
.build()?)
};

self.execute(request_maker).await
Expand All @@ -264,21 +269,9 @@ impl<C: Config> Client<C> {
request_options: &RequestOptions,
) -> Result<(Bytes, HeaderMap), OpenAIError> {
let request_maker = || async {
let mut request_builder = self
.http_client
.get(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers());

if let Some(headers) = request_options.headers() {
request_builder = request_builder.headers(headers.clone());
}

if !request_options.query().is_empty() {
request_builder = request_builder.query(request_options.query());
}

Ok(request_builder.build()?)
Ok(self
.build_request_builder(reqwest::Method::GET, path, request_options)
.build()?)
};

self.execute_raw(request_maker).await
Expand All @@ -295,22 +288,10 @@ impl<C: Config> Client<C> {
I: Serialize,
{
let request_maker = || async {
let mut request_builder = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request);

if let Some(headers) = request_options.headers() {
request_builder = request_builder.headers(headers.clone());
}

if !request_options.query().is_empty() {
request_builder = request_builder.query(request_options.query());
}

Ok(request_builder.build()?)
Ok(self
.build_request_builder(reqwest::Method::POST, path, request_options)
.json(&request)
.build()?)
};

self.execute_raw(request_maker).await
Expand All @@ -328,22 +309,10 @@ impl<C: Config> Client<C> {
O: DeserializeOwned,
{
let request_maker = || async {
let mut request_builder = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request);

if let Some(headers) = request_options.headers() {
request_builder = request_builder.headers(headers.clone());
}

if !request_options.query().is_empty() {
request_builder = request_builder.query(request_options.query());
}

Ok(request_builder.build()?)
Ok(self
.build_request_builder(reqwest::Method::POST, path, request_options)
.json(&request)
.build()?)
};

self.execute(request_maker).await
Expand All @@ -361,22 +330,10 @@ impl<C: Config> Client<C> {
F: Clone,
{
let request_maker = || async {
let mut request_builder = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);

if let Some(headers) = request_options.headers() {
request_builder = request_builder.headers(headers.clone());
}

if !request_options.query().is_empty() {
request_builder = request_builder.query(request_options.query());
}

Ok(request_builder.build()?)
Ok(self
.build_request_builder(reqwest::Method::POST, path, request_options)
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
.build()?)
};

self.execute_raw(request_maker).await
Expand All @@ -395,22 +352,10 @@ impl<C: Config> Client<C> {
F: Clone,
{
let request_maker = || async {
let mut request_builder = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);

if let Some(headers) = request_options.headers() {
request_builder = request_builder.headers(headers.clone());
}

if !request_options.query().is_empty() {
request_builder = request_builder.query(request_options.query());
}

Ok(request_builder.build()?)
Ok(self
.build_request_builder(reqwest::Method::POST, path, request_options)
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
.build()?)
};

self.execute(request_maker).await
Expand All @@ -429,20 +374,9 @@ impl<C: Config> Client<C> {
{
// Build and execute request manually since multipart::Form is not Clone
// and .eventsource() requires cloneability
let mut request_builder = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
.headers(self.config.headers());

if let Some(headers) = request_options.headers() {
request_builder = request_builder.headers(headers.clone());
}

if !request_options.query().is_empty() {
request_builder = request_builder.query(request_options.query());
}
let request_builder = self
.build_request_builder(reqwest::Method::POST, path, request_options)
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);

let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;

Expand Down Expand Up @@ -580,21 +514,10 @@ impl<C: Config> Client<C> {
I: Serialize,
O: DeserializeOwned + std::marker::Send + 'static,
{
let mut request_builder = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
let request_builder = self
.build_request_builder(reqwest::Method::POST, path, request_options)
.json(&request);

if let Some(headers) = request_options.headers() {
request_builder = request_builder.headers(headers.clone());
}

if !request_options.query().is_empty() {
request_builder = request_builder.query(request_options.query());
}

let event_source = request_builder.eventsource().unwrap();

stream(event_source).await
Expand All @@ -611,21 +534,10 @@ impl<C: Config> Client<C> {
I: Serialize,
O: DeserializeOwned + std::marker::Send + 'static,
{
let mut request_builder = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
let request_builder = self
.build_request_builder(reqwest::Method::POST, path, request_options)
.json(&request);

if let Some(headers) = request_options.headers() {
request_builder = request_builder.headers(headers.clone());
}

if !request_options.query().is_empty() {
request_builder = request_builder.query(request_options.query());
}

let event_source = request_builder.eventsource().unwrap();

stream_mapped_raw_events(event_source, event_mapper).await
Expand Down
16 changes: 16 additions & 0 deletions async-openai/src/request_options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,28 @@ use crate::{config::OPENAI_API_BASE, error::OpenAIError};
pub struct RequestOptions {
query: Option<Vec<(String, String)>>,
headers: Option<HeaderMap>,
path: Option<String>,
}

impl RequestOptions {
pub(crate) fn new() -> Self {
Self {
query: None,
headers: None,
path: None,
}
}

pub(crate) fn with_path(&mut self, path: &str) -> Result<(), OpenAIError> {
if path.is_empty() {
return Err(OpenAIError::InvalidArgument(
"Path cannot be empty".to_string(),
));
}
self.path = Some(path.to_string());
Ok(())
}

pub(crate) fn with_headers(&mut self, headers: HeaderMap) {
// merge with existing headers or update with new headers
if let Some(existing_headers) = &mut self.headers {
Expand Down Expand Up @@ -81,4 +93,8 @@ impl RequestOptions {
pub(crate) fn headers(&self) -> Option<&HeaderMap> {
self.headers.as_ref()
}

pub(crate) fn path(&self) -> Option<&String> {
self.path.as_ref()
}
}
6 changes: 6 additions & 0 deletions async-openai/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,10 @@ pub trait RequestOptionsBuilder: Sized {
self.options_mut().with_query(query)?;
Ok(self)
}

/// Add a path to RequestOptions
fn path<P: Into<String>>(mut self, path: P) -> Result<Self, OpenAIError> {
self.options_mut().with_path(path.into().as_str())?;
Ok(self)
}
}
Loading