Skip to content

Commit cd5c76b

Browse files
committed
refactor to avoid repetition
1 parent 340aa07 commit cd5c76b

File tree

1 file changed

+56
-149
lines changed

1 file changed

+56
-149
lines changed

async-openai/src/client.rs

Lines changed: 56 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,30 @@ impl<C: Config> Client<C> {
197197
&self.config
198198
}
199199

200+
/// Helper function to build a request builder with common configuration
201+
fn build_request_builder(
202+
&self,
203+
method: reqwest::Method,
204+
path: &str,
205+
request_options: &RequestOptions,
206+
) -> reqwest::RequestBuilder {
207+
let mut request_builder = self
208+
.http_client
209+
.request(method, self.config.url(path))
210+
.query(&self.config.query())
211+
.headers(self.config.headers());
212+
213+
if let Some(headers) = request_options.headers() {
214+
request_builder = request_builder.headers(headers.clone());
215+
}
216+
217+
if !request_options.query().is_empty() {
218+
request_builder = request_builder.query(request_options.query());
219+
}
220+
221+
request_builder
222+
}
223+
200224
/// Make a GET request to {path} and deserialize the response body
201225
pub(crate) async fn get<O>(
202226
&self,
@@ -207,21 +231,9 @@ impl<C: Config> Client<C> {
207231
O: DeserializeOwned,
208232
{
209233
let request_maker = || async {
210-
let mut request_builder = self
211-
.http_client
212-
.get(self.config.url(path))
213-
.query(&self.config.query())
214-
.headers(self.config.headers());
215-
216-
if let Some(headers) = request_options.headers() {
217-
request_builder = request_builder.headers(headers.clone());
218-
}
219-
220-
if !request_options.query().is_empty() {
221-
request_builder = request_builder.query(request_options.query());
222-
}
223-
224-
Ok(request_builder.build()?)
234+
Ok(self
235+
.build_request_builder(reqwest::Method::GET, path, request_options)
236+
.build()?)
225237
};
226238

227239
self.execute(request_maker).await
@@ -237,21 +249,9 @@ impl<C: Config> Client<C> {
237249
O: DeserializeOwned,
238250
{
239251
let request_maker = || async {
240-
let mut request_builder = self
241-
.http_client
242-
.delete(self.config.url(path))
243-
.query(&self.config.query())
244-
.headers(self.config.headers());
245-
246-
if let Some(headers) = request_options.headers() {
247-
request_builder = request_builder.headers(headers.clone());
248-
}
249-
250-
if !request_options.query().is_empty() {
251-
request_builder = request_builder.query(request_options.query());
252-
}
253-
254-
Ok(request_builder.build()?)
252+
Ok(self
253+
.build_request_builder(reqwest::Method::DELETE, path, request_options)
254+
.build()?)
255255
};
256256

257257
self.execute(request_maker).await
@@ -264,21 +264,9 @@ impl<C: Config> Client<C> {
264264
request_options: &RequestOptions,
265265
) -> Result<(Bytes, HeaderMap), OpenAIError> {
266266
let request_maker = || async {
267-
let mut request_builder = self
268-
.http_client
269-
.get(self.config.url(path))
270-
.query(&self.config.query())
271-
.headers(self.config.headers());
272-
273-
if let Some(headers) = request_options.headers() {
274-
request_builder = request_builder.headers(headers.clone());
275-
}
276-
277-
if !request_options.query().is_empty() {
278-
request_builder = request_builder.query(request_options.query());
279-
}
280-
281-
Ok(request_builder.build()?)
267+
Ok(self
268+
.build_request_builder(reqwest::Method::GET, path, request_options)
269+
.build()?)
282270
};
283271

284272
self.execute_raw(request_maker).await
@@ -295,22 +283,10 @@ impl<C: Config> Client<C> {
295283
I: Serialize,
296284
{
297285
let request_maker = || async {
298-
let mut request_builder = self
299-
.http_client
300-
.post(self.config.url(path))
301-
.query(&self.config.query())
302-
.headers(self.config.headers())
303-
.json(&request);
304-
305-
if let Some(headers) = request_options.headers() {
306-
request_builder = request_builder.headers(headers.clone());
307-
}
308-
309-
if !request_options.query().is_empty() {
310-
request_builder = request_builder.query(request_options.query());
311-
}
312-
313-
Ok(request_builder.build()?)
286+
Ok(self
287+
.build_request_builder(reqwest::Method::POST, path, request_options)
288+
.json(&request)
289+
.build()?)
314290
};
315291

316292
self.execute_raw(request_maker).await
@@ -328,22 +304,10 @@ impl<C: Config> Client<C> {
328304
O: DeserializeOwned,
329305
{
330306
let request_maker = || async {
331-
let mut request_builder = self
332-
.http_client
333-
.post(self.config.url(path))
334-
.query(&self.config.query())
335-
.headers(self.config.headers())
336-
.json(&request);
337-
338-
if let Some(headers) = request_options.headers() {
339-
request_builder = request_builder.headers(headers.clone());
340-
}
341-
342-
if !request_options.query().is_empty() {
343-
request_builder = request_builder.query(request_options.query());
344-
}
345-
346-
Ok(request_builder.build()?)
307+
Ok(self
308+
.build_request_builder(reqwest::Method::POST, path, request_options)
309+
.json(&request)
310+
.build()?)
347311
};
348312

349313
self.execute(request_maker).await
@@ -361,22 +325,10 @@ impl<C: Config> Client<C> {
361325
F: Clone,
362326
{
363327
let request_maker = || async {
364-
let mut request_builder = self
365-
.http_client
366-
.post(self.config.url(path))
367-
.query(&self.config.query())
368-
.headers(self.config.headers())
369-
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
370-
371-
if let Some(headers) = request_options.headers() {
372-
request_builder = request_builder.headers(headers.clone());
373-
}
374-
375-
if !request_options.query().is_empty() {
376-
request_builder = request_builder.query(request_options.query());
377-
}
378-
379-
Ok(request_builder.build()?)
328+
Ok(self
329+
.build_request_builder(reqwest::Method::POST, path, request_options)
330+
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
331+
.build()?)
380332
};
381333

382334
self.execute_raw(request_maker).await
@@ -395,22 +347,10 @@ impl<C: Config> Client<C> {
395347
F: Clone,
396348
{
397349
let request_maker = || async {
398-
let mut request_builder = self
399-
.http_client
400-
.post(self.config.url(path))
401-
.query(&self.config.query())
402-
.headers(self.config.headers())
403-
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
404-
405-
if let Some(headers) = request_options.headers() {
406-
request_builder = request_builder.headers(headers.clone());
407-
}
408-
409-
if !request_options.query().is_empty() {
410-
request_builder = request_builder.query(request_options.query());
411-
}
412-
413-
Ok(request_builder.build()?)
350+
Ok(self
351+
.build_request_builder(reqwest::Method::POST, path, request_options)
352+
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
353+
.build()?)
414354
};
415355

416356
self.execute(request_maker).await
@@ -429,20 +369,9 @@ impl<C: Config> Client<C> {
429369
{
430370
// Build and execute request manually since multipart::Form is not Clone
431371
// and .eventsource() requires cloneability
432-
let mut request_builder = self
433-
.http_client
434-
.post(self.config.url(path))
435-
.query(&self.config.query())
436-
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
437-
.headers(self.config.headers());
438-
439-
if let Some(headers) = request_options.headers() {
440-
request_builder = request_builder.headers(headers.clone());
441-
}
442-
443-
if !request_options.query().is_empty() {
444-
request_builder = request_builder.query(request_options.query());
445-
}
372+
let request_builder = self
373+
.build_request_builder(reqwest::Method::POST, path, request_options)
374+
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
446375

447376
let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
448377

@@ -580,21 +509,10 @@ impl<C: Config> Client<C> {
580509
I: Serialize,
581510
O: DeserializeOwned + std::marker::Send + 'static,
582511
{
583-
let mut request_builder = self
584-
.http_client
585-
.post(self.config.url(path))
586-
.query(&self.config.query())
587-
.headers(self.config.headers())
512+
let request_builder = self
513+
.build_request_builder(reqwest::Method::POST, path, request_options)
588514
.json(&request);
589515

590-
if let Some(headers) = request_options.headers() {
591-
request_builder = request_builder.headers(headers.clone());
592-
}
593-
594-
if !request_options.query().is_empty() {
595-
request_builder = request_builder.query(request_options.query());
596-
}
597-
598516
let event_source = request_builder.eventsource().unwrap();
599517

600518
stream(event_source).await
@@ -611,21 +529,10 @@ impl<C: Config> Client<C> {
611529
I: Serialize,
612530
O: DeserializeOwned + std::marker::Send + 'static,
613531
{
614-
let mut request_builder = self
615-
.http_client
616-
.post(self.config.url(path))
617-
.query(&self.config.query())
618-
.headers(self.config.headers())
532+
let request_builder = self
533+
.build_request_builder(reqwest::Method::POST, path, request_options)
619534
.json(&request);
620535

621-
if let Some(headers) = request_options.headers() {
622-
request_builder = request_builder.headers(headers.clone());
623-
}
624-
625-
if !request_options.query().is_empty() {
626-
request_builder = request_builder.query(request_options.query());
627-
}
628-
629536
let event_source = request_builder.eventsource().unwrap();
630537

631538
stream_mapped_raw_events(event_source, event_mapper).await

0 commit comments

Comments
 (0)