11use async_trait:: async_trait;
22use crate :: llm:: { LlmGenerationClient , LlmSpec , LlmGenerateRequest , LlmGenerateResponse , ToJsonSchemaOptions , OutputFormat } ;
3- use anyhow:: { Result , anyhow} ;
4- use serde_json;
5- use reqwest:: Client as HttpClient ;
3+ use anyhow:: { Result , bail} ;
64use serde_json:: Value ;
5+ use crate :: api_bail;
6+ use urlencoding:: encode;
77
88pub struct Client {
99 model : String ,
10+ api_key : String ,
11+ client : reqwest:: Client ,
1012}
1113
1214impl Client {
1315 pub async fn new ( spec : LlmSpec ) -> Result < Self > {
14- if std:: env:: var ( "GEMINI_API_KEY" ) . is_err ( ) {
15- anyhow:: bail!( "GEMINI_API_KEY environment variable must be set" ) ;
16- }
16+ let api_key = match std:: env:: var ( "GEMINI_API_KEY" ) {
17+ Ok ( val) => val,
18+ Err ( _) => api_bail ! ( "GEMINI_API_KEY environment variable must be set" ) ,
19+ } ;
1720 Ok ( Self {
1821 model : spec. model ,
22+ api_key,
23+ client : reqwest:: Client :: new ( ) ,
1924 } )
2025 }
2126}
@@ -51,12 +56,11 @@ impl LlmGenerationClient for Client {
5156 } ) ] ;
5257
5358 // Optionally add system prompt
54- let mut system_instruction = None ;
55- if let Some ( system) = request. system_prompt {
56- system_instruction = Some ( serde_json:: json!( {
57- "parts" : [ { "text" : system } ]
58- } ) ) ;
59- }
59+ let system_instruction = request. system_prompt . map ( |system|
60+ serde_json:: json!( {
61+ "parts" : [ { "text" : system } ]
62+ } )
63+ ) ;
6064
6165 // Prepare payload
6266 let mut payload = serde_json:: json!( { "contents" : contents } ) ;
@@ -74,29 +78,33 @@ impl LlmGenerationClient for Client {
7478 } ) ;
7579 }
7680
77- let api_key = std:: env:: var ( "GEMINI_API_KEY" )
78- . map_err ( |_| anyhow ! ( "GEMINI_API_KEY environment variable must be set" ) ) ?;
81+ let api_key = & self . api_key ;
7982 let url = format ! (
8083 "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}" ,
81- self . model, api_key
84+ encode ( & self . model) , encode ( api_key)
8285 ) ;
8386
84- let client = HttpClient :: new ( ) ;
85- let resp = client. post ( & url)
87+ let resp = match self . client . post ( & url)
8688 . json ( & payload)
8789 . send ( )
88- . await
89- . map_err ( |e| anyhow ! ( "HTTP error: {e}" ) ) ?;
90+ . await {
91+ Ok ( resp) => resp,
92+ Err ( e) => api_bail ! ( "HTTP error: {e}" ) ,
93+ } ;
9094
91- let resp_json: Value = resp. json ( ) . await . map_err ( |e| anyhow ! ( "Invalid JSON: {e}" ) ) ?;
95+ let resp_json: Value = match resp. json ( ) . await {
96+ Ok ( json) => json,
97+ Err ( e) => api_bail ! ( "Invalid JSON: {e}" ) ,
98+ } ;
9299
93100 if let Some ( error) = resp_json. get ( "error" ) {
94- return Err ( anyhow ! ( "Gemini API error: {:?}" , error) ) ;
101+ bail ! ( "Gemini API error: {:?}" , error) ;
95102 }
96- let text = resp_json[ "candidates" ] [ 0 ] [ "content" ] [ "parts" ] [ 0 ] [ "text" ]
97- . as_str ( )
98- . unwrap_or ( "" )
99- . to_string ( ) ;
103+ let mut resp_json = resp_json;
104+ let text = match & mut resp_json[ "candidates" ] [ 0 ] [ "content" ] [ "parts" ] [ 0 ] [ "text" ] {
105+ Value :: String ( s) => std:: mem:: take ( s) ,
106+ _ => bail ! ( "No text in response" ) ,
107+ } ;
100108
101109 Ok ( LlmGenerateResponse { text } )
102110 }
0 commit comments