1+ use async_trait:: async_trait;
2+ use crate :: llm:: { LlmGenerationClient , LlmSpec , LlmGenerateRequest , LlmGenerateResponse , ToJsonSchemaOptions , OutputFormat } ;
3+ use anyhow:: { Result , bail, Context } ;
4+ use serde_json:: Value ;
5+ use crate :: api_bail;
6+ use urlencoding:: encode;
7+
8+ pub struct Client {
9+ model : String ,
10+ api_key : String ,
11+ client : reqwest:: Client ,
12+ }
13+
14+ impl Client {
15+ pub async fn new ( spec : LlmSpec ) -> Result < Self > {
16+ let api_key = match std:: env:: var ( "GEMINI_API_KEY" ) {
17+ Ok ( val) => val,
18+ Err ( _) => api_bail ! ( "GEMINI_API_KEY environment variable must be set" ) ,
19+ } ;
20+ Ok ( Self {
21+ model : spec. model ,
22+ api_key,
23+ client : reqwest:: Client :: new ( ) ,
24+ } )
25+ }
26+ }
27+
28+ // Recursively remove all `additionalProperties` fields from a JSON value
29+ fn remove_additional_properties ( value : & mut Value ) {
30+ match value {
31+ Value :: Object ( map) => {
32+ map. remove ( "additionalProperties" ) ;
33+ for v in map. values_mut ( ) {
34+ remove_additional_properties ( v) ;
35+ }
36+ }
37+ Value :: Array ( arr) => {
38+ for v in arr {
39+ remove_additional_properties ( v) ;
40+ }
41+ }
42+ _ => { }
43+ }
44+ }
45+
46+ #[ async_trait]
47+ impl LlmGenerationClient for Client {
48+ async fn generate < ' req > (
49+ & self ,
50+ request : LlmGenerateRequest < ' req > ,
51+ ) -> Result < LlmGenerateResponse > {
52+ // Compose the prompt/messages
53+ let contents = vec ! [ serde_json:: json!( {
54+ "role" : "user" ,
55+ "parts" : [ { "text" : request. user_prompt } ]
56+ } ) ] ;
57+
58+ // Prepare payload
59+ let mut payload = serde_json:: json!( { "contents" : contents } ) ;
60+ if let Some ( system) = request. system_prompt {
61+ payload[ "systemInstruction" ] = serde_json:: json!( {
62+ "parts" : [ { "text" : system } ]
63+ } ) ;
64+ }
65+
66+ // If structured output is requested, add schema and responseMimeType
67+ if let Some ( OutputFormat :: JsonSchema { schema, .. } ) = & request. output_format {
68+ let mut schema_json = serde_json:: to_value ( schema) ?;
69+ remove_additional_properties ( & mut schema_json) ;
70+ payload[ "generationConfig" ] = serde_json:: json!( {
71+ "responseMimeType" : "application/json" ,
72+ "responseSchema" : schema_json
73+ } ) ;
74+ }
75+
76+ let api_key = & self . api_key ;
77+ let url = format ! (
78+ "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}" ,
79+ encode( & self . model) , encode( api_key)
80+ ) ;
81+
82+ let resp = self . client . post ( & url)
83+ . json ( & payload)
84+ . send ( )
85+ . await
86+ . context ( "HTTP error" ) ?;
87+
88+ let resp_json: Value = resp. json ( ) . await . context ( "Invalid JSON" ) ?;
89+
90+ if let Some ( error) = resp_json. get ( "error" ) {
91+ bail ! ( "Gemini API error: {:?}" , error) ;
92+ }
93+ let mut resp_json = resp_json;
94+ let text = match & mut resp_json[ "candidates" ] [ 0 ] [ "content" ] [ "parts" ] [ 0 ] [ "text" ] {
95+ Value :: String ( s) => std:: mem:: take ( s) ,
96+ _ => bail ! ( "No text in response" ) ,
97+ } ;
98+
99+ Ok ( LlmGenerateResponse { text } )
100+ }
101+
102+ fn json_schema_options ( & self ) -> ToJsonSchemaOptions {
103+ ToJsonSchemaOptions {
104+ fields_always_required : false ,
105+ supports_format : false ,
106+ extract_descriptions : false ,
107+ top_level_must_be_object : true ,
108+ }
109+ }
110+ }
0 commit comments