Skip to content
This repository was archived by the owner on Jul 16, 2025. It is now read-only.

Commit e3a0c0a

Browse files
authored
Tweaks (#42)
- make error handling report response body - support custom models (so when new ones are out, they can be used) - allow setting response_schema - safety_ratings is optional in response - upgrade reqwest and make default features optional
2 parents 5faab67 + d3a8616 commit e3a0c0a

File tree

5 files changed

+59
-17
lines changed

5 files changed

+59
-17
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ env_logger = { version = "0.11" }
1818
futures = { version = "0.3" }
1919
gcp_auth = { version = "0.12" }
2020
log = { version = "0.4.20" }
21-
reqwest = { version = "0.11", features = ["json"] }
22-
reqwest-streams = { version = "0.5.1", features = ["json"] }
21+
reqwest = { version = "0.12", default-features = false, features = ["json"] }
22+
reqwest-streams = { version = "0.8.2", default-features = false, features = ["json"] }
2323
serde = { version = "1.0", features = ["derive"] }
2424
serde_json = { version = "1.0" }
2525
tokio = { version = "1.35", features = ["full"] }

examples/text_request_json.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
6262
max_output_tokens: None,
6363
stop_sequences: None,
6464
response_mime_type: Some("application/json".to_string()),
65+
response_schema: None,
6566
}),
6667

6768
system_instruction: None,

src/v1/api.rs

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Manages the interaction with the REST API for the Gemini API.
22
use futures::prelude::*;
33
use futures::stream::StreamExt;
4+
use reqwest::StatusCode;
45
use reqwest_streams::error::StreamBodyError;
56
use reqwest_streams::*;
67
use serde_json;
@@ -14,7 +15,7 @@ use crate::v1::gemini::request::Request;
1415
use crate::v1::gemini::response::GeminiResponse;
1516
use crate::v1::gemini::Model;
1617

17-
use super::gemini::response::{StreamedGeminiResponse, TokenCount};
18+
use super::gemini::response::{GeminiErrorResponse, StreamedGeminiResponse, TokenCount};
1819
use super::gemini::{ModelInformation, ModelInformationList, ResponseType};
1920

2021
#[cfg(feature = "beta")]
@@ -155,20 +156,31 @@ impl Client {
155156
.get_post_response(client, api_request, token_option)
156157
.await;
157158

158-
match result {
159-
Ok(response) => match response.status() {
160-
reqwest::StatusCode::OK => Ok(response.json::<GeminiResponse>().await.map_err(|e|GoogleAPIError {
161-
message: format!(
162-
"Failed to deserialize API response into v1::gemini::response::GeminiResponse: {}",
163-
e
164-
),
165-
code: None,
166-
})?),
167-
_ => Err(self.new_error_from_status_code(response.status())),
168-
},
169-
Err(e) => Err(self.new_error_from_reqwest_error(e)),
159+
if let Ok(result) = result {
160+
match result.status() {
161+
reqwest::StatusCode::OK => {
162+
Ok(result.json::<GeminiResponse>().await.map_err(|e|GoogleAPIError {
163+
message: format!(
164+
"Failed to deserialize API response into v1::gemini::response::GeminiResponse: {}",
165+
e
166+
),
167+
code: None,
168+
})?)
169+
},
170+
_ => {
171+
let status = result.status();
172+
173+
match result.json::<GeminiErrorResponse>().await {
174+
Ok(GeminiErrorResponse::Error { message, .. }) => Err(self.new_error_from_api_message(status, message)),
175+
Err(_) => Err(self.new_error_from_status_code(status)),
176+
}
177+
},
178+
}
179+
} else {
180+
Err(self.new_error_from_reqwest_error(result.unwrap_err()))
170181
}
171182
}
183+
172184
// Define the function that accepts the stream and the consumer
173185
/// A streamed post request
174186
async fn get_streamed_post_result(
@@ -410,6 +422,17 @@ impl Client {
410422
code: Some(code),
411423
}
412424
}
425+
426+
/// Creates a new error from a status code.
427+
fn new_error_from_api_message(&self, code: StatusCode, message: String) -> GoogleAPIError {
428+
let message = format!("API message: {message}.");
429+
430+
GoogleAPIError {
431+
message,
432+
code: Some(code),
433+
}
434+
}
435+
413436
/// Creates a new error from a reqwest error.
414437
fn new_error_from_reqwest_error(&self, mut e: reqwest::Error) -> GoogleAPIError {
415438
if let Some(url) = e.url_mut() {

src/v1/errors.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ impl fmt::Display for GoogleAPIError {
1111
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1212
write!(
1313
f,
14-
"GoogleAPIError - code: {} error: {:?}",
15-
self.message, self.code
14+
"GoogleAPIError - code: {:?} error: {}",
15+
self.code, self.message
1616
)
1717
}
1818
}

src/v1/gemini.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ pub enum Model {
8181
#[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
8282
Gemini1_5Pro,
8383
GeminiProVision,
84+
Custom(String),
8485
// TODO Embedding001
8586
}
8687
impl fmt::Display for Model {
@@ -92,6 +93,8 @@ impl fmt::Display for Model {
9293
Model::Gemini1_5Pro => write!(f, "gemini-1.5-pro-latest"),
9394

9495
Model::GeminiProVision => write!(f, "gemini-pro-vision"),
96+
97+
Model::Custom(name) => write!(f, "{}", name),
9598
// TODO Model::Embedding001 => write!(f, "embedding-001"),
9699
}
97100
}
@@ -319,6 +322,10 @@ pub mod request {
319322
#[cfg(feature = "beta")]
320323
#[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
321324
pub response_mime_type: Option<String>,
325+
326+
#[cfg(feature = "beta")]
327+
#[cfg_attr(docsrs, doc(cfg(feature = "beta")))]
328+
pub response_schema: Option<serde_json::Value>,
322329
}
323330

324331
#[cfg(feature = "beta")]
@@ -425,6 +432,16 @@ pub mod response {
425432
pub prompt_feedback: Option<PromptFeedback>,
426433
pub usage_metadata: Option<UsageMetadata>,
427434
}
435+
#[derive(Debug, Clone, Deserialize)]
436+
#[serde(rename_all = "camelCase")]
437+
pub enum GeminiErrorResponse {
438+
Error {
439+
code: u16,
440+
message: String,
441+
status: String,
442+
},
443+
}
444+
428445
impl GeminiResponse {
429446
/// Returns the total character count of the response as per the Gemini API.
430447
pub fn get_response_character_count(&self) -> usize {
@@ -448,6 +465,7 @@ pub mod response {
448465
pub content: Content,
449466
pub finish_reason: Option<String>,
450467
pub index: Option<i32>,
468+
#[serde(default)]
451469
pub safety_ratings: Vec<SafetyRating>,
452470
}
453471
#[derive(Debug, Clone, Deserialize)]

0 commit comments

Comments
 (0)