@@ -54,19 +54,6 @@ func newGeminiClient(opts providerClientOptions) GeminiClient {
5454
5555func (g * geminiClient ) convertMessages (messages []message.Message ) []* genai.Content {
5656 var history []* genai.Content
57-
58- // Add system message first
59- history = append (history , & genai.Content {
60- Parts : []genai.Part {genai .Text (g .providerOptions .systemMessage )},
61- Role : "user" ,
62- })
63-
64- // Add a system response to acknowledge the system message
65- history = append (history , & genai.Content {
66- Parts : []genai.Part {genai .Text ("I'll help you with that." )},
67- Role : "model" ,
68- })
69-
7057 for _ , msg := range messages {
7158 switch msg .Role {
7259 case message .User :
@@ -154,14 +141,11 @@ func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
154141}
155142
156143func (g * geminiClient ) finishReason (reason genai.FinishReason ) message.FinishReason {
157- reasonStr := reason .String ()
158144 switch {
159- case reasonStr == "STOP" :
145+ case reason == genai . FinishReasonStop :
160146 return message .FinishReasonEndTurn
161- case reasonStr == "MAX_TOKENS" :
147+ case reason == genai . FinishReasonMaxTokens :
162148 return message .FinishReasonMaxTokens
163- case strings .Contains (reasonStr , "FUNCTION" ) || strings .Contains (reasonStr , "TOOL" ):
164- return message .FinishReasonToolUse
165149 default :
166150 return message .FinishReasonUnknown
167151 }
@@ -170,7 +154,11 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
170154func (g * geminiClient ) send (ctx context.Context , messages []message.Message , tools []tools.BaseTool ) (* ProviderResponse , error ) {
171155 model := g .client .GenerativeModel (g .providerOptions .model .APIModel )
172156 model .SetMaxOutputTokens (int32 (g .providerOptions .maxTokens ))
173-
157+ model .SystemInstruction = & genai.Content {
158+ Parts : []genai.Part {
159+ genai .Text (g .providerOptions .systemMessage ),
160+ },
161+ }
174162 // Convert tools
175163 if len (tools ) > 0 {
176164 model .Tools = g .convertTools (tools )
@@ -188,19 +176,13 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
188176 attempts := 0
189177 for {
190178 attempts ++
179+ var toolCalls []message.ToolCall
191180 chat := model .StartChat ()
192181 chat .History = geminiMessages [:len (geminiMessages )- 1 ] // All but last message
193182
194183 lastMsg := geminiMessages [len (geminiMessages )- 1 ]
195- var lastText string
196- for _ , part := range lastMsg .Parts {
197- if text , ok := part .(genai.Text ); ok {
198- lastText = string (text )
199- break
200- }
201- }
202184
203- resp , err := chat .SendMessage (ctx , genai . Text ( lastText ) )
185+ resp , err := chat .SendMessage (ctx , lastMsg . Parts ... )
204186 // If there is an error we are going to see if we can retry the call
205187 if err != nil {
206188 retry , after , retryErr := g .shouldRetry (attempts , err )
@@ -220,7 +202,6 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
220202 }
221203
222204 content := ""
223- var toolCalls []message.ToolCall
224205
225206 if len (resp .Candidates ) > 0 && resp .Candidates [0 ].Content != nil {
226207 for _ , part := range resp .Candidates [0 ].Content .Parts {
@@ -231,28 +212,37 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
231212 id := "call_" + uuid .New ().String ()
232213 args , _ := json .Marshal (p .Args )
233214 toolCalls = append (toolCalls , message.ToolCall {
234- ID : id ,
235- Name : p .Name ,
236- Input : string (args ),
237- Type : "function" ,
215+ ID : id ,
216+ Name : p .Name ,
217+ Input : string (args ),
218+ Type : "function" ,
219+ Finished : true ,
238220 })
239221 }
240222 }
241223 }
224+ finishReason := g .finishReason (resp .Candidates [0 ].FinishReason )
225+ if len (toolCalls ) > 0 {
226+ finishReason = message .FinishReasonToolUse
227+ }
242228
243229 return & ProviderResponse {
244230 Content : content ,
245231 ToolCalls : toolCalls ,
246232 Usage : g .usage (resp ),
247- FinishReason : g . finishReason ( resp . Candidates [ 0 ]. FinishReason ) ,
233+ FinishReason : finishReason ,
248234 }, nil
249235 }
250236}
251237
252238func (g * geminiClient ) stream (ctx context.Context , messages []message.Message , tools []tools.BaseTool ) <- chan ProviderEvent {
253239 model := g .client .GenerativeModel (g .providerOptions .model .APIModel )
254240 model .SetMaxOutputTokens (int32 (g .providerOptions .maxTokens ))
255-
241+ model .SystemInstruction = & genai.Content {
242+ Parts : []genai.Part {
243+ genai .Text (g .providerOptions .systemMessage ),
244+ },
245+ }
256246 // Convert tools
257247 if len (tools ) > 0 {
258248 model .Tools = g .convertTools (tools )
@@ -276,18 +266,10 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
276266 for {
277267 attempts ++
278268 chat := model .StartChat ()
279- chat .History = geminiMessages [:len (geminiMessages )- 1 ] // All but last message
280-
269+ chat .History = geminiMessages [:len (geminiMessages )- 1 ]
281270 lastMsg := geminiMessages [len (geminiMessages )- 1 ]
282- var lastText string
283- for _ , part := range lastMsg .Parts {
284- if text , ok := part .(genai.Text ); ok {
285- lastText = string (text )
286- break
287- }
288- }
289271
290- iter := chat .SendMessageStream (ctx , genai . Text ( lastText ) )
272+ iter := chat .SendMessageStream (ctx , lastMsg . Parts ... )
291273
292274 currentContent := ""
293275 toolCalls := []message.ToolCall {}
@@ -330,23 +312,23 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
330312 for _ , part := range resp .Candidates [0 ].Content .Parts {
331313 switch p := part .(type ) {
332314 case genai.Text :
333- newText := string (p )
334- delta := newText [len (currentContent ):]
315+ delta := string (p )
335316 if delta != "" {
336317 eventChan <- ProviderEvent {
337318 Type : EventContentDelta ,
338319 Content : delta ,
339320 }
340- currentContent = newText
321+ currentContent += delta
341322 }
342323 case genai.FunctionCall :
343324 id := "call_" + uuid .New ().String ()
344325 args , _ := json .Marshal (p .Args )
345326 newCall := message.ToolCall {
346- ID : id ,
347- Name : p .Name ,
348- Input : string (args ),
349- Type : "function" ,
327+ ID : id ,
328+ Name : p .Name ,
329+ Input : string (args ),
330+ Type : "function" ,
331+ Finished : true ,
350332 }
351333
352334 isNew := true
@@ -368,37 +350,22 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
368350 eventChan <- ProviderEvent {Type : EventContentStop }
369351
370352 if finalResp != nil {
353+ finishReason := g .finishReason (finalResp .Candidates [0 ].FinishReason )
354+ if len (toolCalls ) > 0 {
355+ finishReason = message .FinishReasonToolUse
356+ }
371357 eventChan <- ProviderEvent {
372358 Type : EventComplete ,
373359 Response : & ProviderResponse {
374360 Content : currentContent ,
375361 ToolCalls : toolCalls ,
376362 Usage : g .usage (finalResp ),
377- FinishReason : g . finishReason ( finalResp . Candidates [ 0 ]. FinishReason ) ,
363+ FinishReason : finishReason ,
378364 },
379365 }
380366 return
381367 }
382368
383- // If we get here, we need to retry
384- if attempts > maxRetries {
385- eventChan <- ProviderEvent {
386- Type : EventError ,
387- Error : fmt .Errorf ("maximum retry attempts reached: %d retries" , maxRetries ),
388- }
389- return
390- }
391-
392- // Wait before retrying
393- select {
394- case <- ctx .Done ():
395- if ctx .Err () != nil {
396- eventChan <- ProviderEvent {Type : EventError , Error : ctx .Err ()}
397- }
398- return
399- case <- time .After (time .Duration (2000 * (1 << (attempts - 1 ))) * time .Millisecond ):
400- continue
401- }
402369 }
403370 }()
404371
0 commit comments