Skip to content

Commit 0ff37d5

Browse files
committed
Refactor Gemini client
1 parent 05de041 commit 0ff37d5

File tree

1 file changed

+33
-25
lines changed

1 file changed

+33
-25
lines changed

src/llm/gemini.rs

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
use async_trait::async_trait;
22
use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat};
3-
use anyhow::{Result, anyhow};
4-
use serde_json;
5-
use reqwest::Client as HttpClient;
3+
use anyhow::{Result, bail};
64
use serde_json::Value;
5+
use crate::api_bail;
6+
use urlencoding::encode;
77

88
pub struct Client {
99
model: String,
10+
api_key: String,
11+
client: reqwest::Client,
1012
}
1113

1214
impl Client {
1315
pub async fn new(spec: LlmSpec) -> Result<Self> {
14-
if std::env::var("GEMINI_API_KEY").is_err() {
15-
anyhow::bail!("GEMINI_API_KEY environment variable must be set");
16-
}
16+
let api_key = match std::env::var("GEMINI_API_KEY") {
17+
Ok(val) => val,
18+
Err(_) => api_bail!("GEMINI_API_KEY environment variable must be set"),
19+
};
1720
Ok(Self {
1821
model: spec.model,
22+
api_key,
23+
client: reqwest::Client::new(),
1924
})
2025
}
2126
}
@@ -51,12 +56,11 @@ impl LlmGenerationClient for Client {
5156
})];
5257

5358
// Optionally add system prompt
54-
let mut system_instruction = None;
55-
if let Some(system) = request.system_prompt {
56-
system_instruction = Some(serde_json::json!({
57-
"parts": [{ "text": system }]
58-
}));
59-
}
59+
let system_instruction = request.system_prompt.map(|system|
60+
serde_json::json!({
61+
"parts": [ { "text": system } ]
62+
})
63+
);
6064

6165
// Prepare payload
6266
let mut payload = serde_json::json!({ "contents": contents });
@@ -74,29 +78,33 @@ impl LlmGenerationClient for Client {
7478
});
7579
}
7680

77-
let api_key = std::env::var("GEMINI_API_KEY")
78-
.map_err(|_| anyhow!("GEMINI_API_KEY environment variable must be set"))?;
81+
let api_key = &self.api_key;
7982
let url = format!(
8083
"https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
81-
self.model, api_key
84+
encode(&self.model), encode(api_key)
8285
);
8386

84-
let client = HttpClient::new();
85-
let resp = client.post(&url)
87+
let resp = match self.client.post(&url)
8688
.json(&payload)
8789
.send()
88-
.await
89-
.map_err(|e| anyhow!("HTTP error: {e}"))?;
90+
.await {
91+
Ok(resp) => resp,
92+
Err(e) => api_bail!("HTTP error: {e}"),
93+
};
9094

91-
let resp_json: Value = resp.json().await.map_err(|e| anyhow!("Invalid JSON: {e}"))?;
95+
let resp_json: Value = match resp.json().await {
96+
Ok(json) => json,
97+
Err(e) => api_bail!("Invalid JSON: {e}"),
98+
};
9299

93100
if let Some(error) = resp_json.get("error") {
94-
return Err(anyhow!("Gemini API error: {:?}", error));
101+
bail!("Gemini API error: {:?}", error);
95102
}
96-
let text = resp_json["candidates"][0]["content"]["parts"][0]["text"]
97-
.as_str()
98-
.unwrap_or("")
99-
.to_string();
103+
let mut resp_json = resp_json;
104+
let text = match &mut resp_json["candidates"][0]["content"]["parts"][0]["text"] {
105+
Value::String(s) => std::mem::take(s),
106+
_ => bail!("No text in response"),
107+
};
100108

101109
Ok(LlmGenerateResponse { text })
102110
}

0 commit comments

Comments
 (0)