 import {
   generateErrorResponse,
-  generateInvalidProviderResponseError
+  generateInvalidProviderResponseError,
+  getMimeType
 } from '../utils';
 import { GOOGLE } from '@/dev/data/models';
+import type { ToolCall, ToolChoice } from 'types/pipe';
 import type {
   ChatCompletionResponse,
   ContentType,
   ErrorResponse,
+  MessageRole,
   ModelParams,
   ProviderConfig,
   ProviderMessage
@@ -32,6 +35,76 @@ const transformGenerationConfig = (params: ModelParams) => {
   return generationConfig;
 };

+export type GoogleMessageRole = 'user' | 'model' | 'function';
+
+interface GoogleFunctionCallMessagePart {
+  functionCall: GoogleGenerateFunctionCall;
+}
+
+interface GoogleFunctionResponseMessagePart {
+  functionResponse: {
+    name: string;
+    response: {
+      name?: string;
+      content: string;
+    };
+  };
+}
+
+type GoogleMessagePart =
+  | GoogleFunctionCallMessagePart
+  | GoogleFunctionResponseMessagePart
+  | { text: string };
+
+export interface GoogleMessage {
+  role: GoogleMessageRole;
+  parts: GoogleMessagePart[];
+}
+
+export interface GoogleToolConfig {
+  function_calling_config: {
+    mode: GoogleToolChoiceType | undefined;
+    allowed_function_names?: string[];
+  };
+}
+
+export const transformOpenAIRoleToGoogleRole = (
+  role: MessageRole
+): GoogleMessageRole => {
+  switch (role) {
+    case 'assistant':
+      return 'model';
+    case 'tool':
+      return 'function';
+    // Not all gemini models support system role
+    case 'system':
+      return 'user';
+    // user is the default role
+    default:
+      return role;
+  }
+};
+
+type GoogleToolChoiceType = 'AUTO' | 'ANY' | 'NONE';
+
+export const transformToolChoiceForGemini = (
+  tool_choice: ToolChoice
+): GoogleToolChoiceType | undefined => {
+  if (typeof tool_choice === 'object' && tool_choice.type === 'function')
+    return 'ANY';
+  if (typeof tool_choice === 'string') {
+    switch (tool_choice) {
+      case 'auto':
+        return 'AUTO';
+      case 'none':
+        return 'NONE';
+      case 'required':
+        return 'ANY';
+    }
+  }
+  return undefined;
+};
+
 export const GoogleChatCompleteConfig: ProviderConfig = {
   model: {
     param: 'model',
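Editor's note (illustration, not part of the diff): the two helpers added above are pure value mappings, so their behavior can be shown directly. A minimal sketch, assuming the MessageRole and ToolChoice types imported from 'types/pipe'; 'get_weather' is a placeholder name:

  transformOpenAIRoleToGoogleRole('assistant'); // 'model'
  transformOpenAIRoleToGoogleRole('tool');      // 'function'
  transformOpenAIRoleToGoogleRole('system');    // 'user' (system role folded into user)
  transformToolChoiceForGemini('auto');         // 'AUTO'
  transformToolChoiceForGemini('none');         // 'NONE'
  transformToolChoiceForGemini('required');     // 'ANY'
  transformToolChoiceForGemini({ type: 'function', function: { name: 'get_weather' } }); // 'ANY'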
@@ -42,36 +115,100 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
     param: 'contents',
     default: '',
     transform: (params: ModelParams) => {
-      const messages: { role: string; parts: { text: string }[] }[] = [];
+      const messages: GoogleMessage[] = [];
+      let lastRole: GoogleMessageRole | undefined;

       params.messages?.forEach((message: ProviderMessage) => {
-        const role = message.role === 'assistant' ? 'model' : 'user';
+        const role = transformOpenAIRoleToGoogleRole(message.role);
         let parts = [];
-        if (typeof message.content === 'string') {
+
+        if (message.role === 'assistant' && message.tool_calls) {
+          message.tool_calls.forEach((tool_call: ToolCall) => {
+            parts.push({
+              functionCall: {
+                name: tool_call.function.name,
+                args: JSON.parse(tool_call.function.arguments)
+              }
+            });
+          });
+        } else if (
+          message.role === 'tool' &&
+          typeof message.content === 'string'
+        ) {
           parts.push({
-            text: message.content
+            functionResponse: {
+              name: message.name ?? 'lb-random-tool-name',
+              response: {
+                content: message.content
+              }
+            }
           });
-        }
-
-        if (message.content && typeof message.content === 'object') {
+        } else if (
+          message.content &&
+          typeof message.content === 'object'
+        ) {
           message.content.forEach((c: ContentType) => {
             if (c.type === 'text') {
               parts.push({
                 text: c.text
               });
             }
             if (c.type === 'image_url') {
-              parts.push({
-                inlineData: {
-                  mimeType: 'image/jpeg',
-                  data: c.image_url?.url
-                }
-              });
+              const { url } = c.image_url || {};
+              if (!url) return;
+
+              // Handle different types of image URLs
+              if (url.startsWith('data:')) {
+                const [mimeTypeWithPrefix, base64Image] =
+                  url.split(';base64,');
+                const mimeType =
+                  mimeTypeWithPrefix.split(':')[1];
+
+                parts.push({
+                  inlineData: {
+                    mimeType: mimeType,
+                    data: base64Image
+                  }
+                });
+              } else if (
+                url.startsWith('gs://') ||
+                url.startsWith('https://') ||
+                url.startsWith('http://')
+              ) {
+                parts.push({
+                  fileData: {
+                    mimeType: getMimeType(url),
+                    fileUri: url
+                  }
+                });
+              } else {
+                parts.push({
+                  inlineData: {
+                    mimeType: 'image/jpeg',
+                    data: c.image_url?.url
+                  }
+                });
+              }
             }
           });
+        } else if (typeof message.content === 'string') {
+          parts.push({
+            text: message.content
+          });
         }

-        messages.push({ role, parts });
+        // Combine consecutive messages if they are from the same role
+        // This takes care of the "Please ensure that multiturn requests alternate between user and model."
+        // Also possible fix for "Please ensure that function call turn comes immediately after a user turn or after a function response turn." in parallel tool calls
+        const shouldCombineMessages =
+          lastRole === role && !params.model?.includes('vision');
+
+        if (shouldCombineMessages) {
+          messages[messages.length - 1].parts.push(...parts);
+        } else {
+          messages.push({ role, parts });
+        }
+        lastRole = role;
       });
       return messages;
     }
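Editor's note (illustration, not part of the diff): the contents transform above maps OpenAI-style messages to Google contents and merges consecutive same-role turns unless the model name contains 'vision'. Expected behavior for a hypothetical conversation (names and values are illustrative only):

  // Input messages:
  //   { role: 'system', content: 'Be brief.' }
  //   { role: 'user', content: 'What time is it?' }
  //   { role: 'assistant', tool_calls: [{ id: 't1', type: 'function',
  //       function: { name: 'get_time', arguments: '{"tz":"UTC"}' } }] }
  //   { role: 'tool', name: 'get_time', content: '12:00' }
  //
  // Resulting contents (system + user collapse into one 'user' turn):
  //   { role: 'user', parts: [{ text: 'Be brief.' }, { text: 'What time is it?' }] }
  //   { role: 'model', parts: [{ functionCall: { name: 'get_time', args: { tz: 'UTC' } } }] }
  //   { role: 'function', parts: [{ functionResponse: { name: 'get_time', response: { content: '12:00' } } }] }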
@@ -108,6 +245,36 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
       });
       return [{ functionDeclarations }];
     }
+  },
+  tool_choice: {
+    param: 'tool_config',
+    default: '',
+    transform: (params: ModelParams) => {
+      if (params.tool_choice) {
+        const allowedFunctionNames: string[] = [];
+        // If tool_choice is an object and type is function, add the function name to allowedFunctionNames
+        if (
+          typeof params.tool_choice === 'object' &&
+          params.tool_choice.type === 'function'
+        ) {
+          allowedFunctionNames.push(params.tool_choice.function.name);
+        }
+        const toolConfig: GoogleToolConfig = {
+          function_calling_config: {
+            mode: transformToolChoiceForGemini(params.tool_choice)
+          }
+        };
+        // TODO: @msaaddev I think we can't have more than one function in tool_choice
+        // but this will also handle the case if we have more than one function in tool_choice
+
+        // If tool_choice has functions, add the function names to allowedFunctionNames
+        if (allowedFunctionNames.length > 0) {
+          toolConfig.function_calling_config.allowed_function_names =
+            allowedFunctionNames;
+        }
+        return toolConfig;
+      }
+    }
   }
 };

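Editor's note (illustration, not part of the diff): the new tool_choice entry rewrites OpenAI's tool_choice into Google's tool_config parameter. Expected mappings, with 'get_weather' as a placeholder function name:

  // 'auto'     -> { function_calling_config: { mode: 'AUTO' } }
  // 'none'     -> { function_calling_config: { mode: 'NONE' } }
  // 'required' -> { function_calling_config: { mode: 'ANY' } }
  // { type: 'function', function: { name: 'get_weather' } }
  //            -> { function_calling_config: { mode: 'ANY', allowed_function_names: ['get_weather'] } }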
@@ -146,6 +313,11 @@ interface GoogleGenerateContentResponse {
       probability: string;
     }[];
   };
+  usageMetadata: {
+    promptTokenCount: number;
+    candidatesTokenCount: number;
+    totalTokenCount: number;
+  };
 }

 export const GoogleChatCompleteResponseTransform: (
@@ -170,7 +342,6 @@ export const GoogleChatCompleteResponseTransform: (
       GOOGLE
     );
   }
-
   if ('candidates' in response) {
     return {
       id: crypto.randomUUID(),
@@ -179,7 +350,7 @@ export const GoogleChatCompleteResponseTransform: (
       model: 'Unknown',
       provider: GOOGLE,
       choices:
-        response.candidates?.map((generation, index) => {
+        response.candidates?.map(generation => {
           // In blocking mode: Google AI does not return content if response > max output tokens param
           // Test it by asking a big response while keeping maxtokens low ~ 50
           if (
@@ -203,28 +374,34 @@ export const GoogleChatCompleteResponseTransform: (
         } else if (generation.content?.parts[0]?.functionCall) {
           message = {
             role: 'assistant',
-            tool_calls: [
-              {
-                id: crypto.randomUUID(),
-                type: 'function',
-                function: {
-                  name: generation.content.parts[0]
-                    ?.functionCall.name,
-                  arguments: JSON.stringify(
-                    generation.content.parts[0]
-                      ?.functionCall.args
-                  )
-                }
+            content: null,
+            tool_calls: generation.content.parts.map(part => {
+              if (part.functionCall) {
+                return {
+                  id: crypto.randomUUID(),
+                  type: 'function',
+                  function: {
+                    name: part.functionCall.name,
+                    arguments: JSON.stringify(
+                      part.functionCall.args
+                    )
+                  }
+                };
               }
-            ]
+            })
           };
         }
         return {
           message: message,
           index: generation.index,
           finish_reason: generation.finishReason
         };
-      }) ?? []
+      }) ?? [],
+      usage: {
+        prompt_tokens: response.usageMetadata.promptTokenCount,
+        completion_tokens: response.usageMetadata.candidatesTokenCount,
+        total_tokens: response.usageMetadata.totalTokenCount
+      }
     };
   }

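Editor's note (illustration, not part of the diff): the response transform now also surfaces token usage by renaming Google's usageMetadata fields into the OpenAI-style usage block. With illustrative counts:

  // usageMetadata: { promptTokenCount: 12, candidatesTokenCount: 34, totalTokenCount: 46 }
  // becomes
  // usage: { prompt_tokens: 12, completion_tokens: 34, total_tokens: 46 }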
@@ -262,7 +439,7 @@ export const GoogleChatCompleteStreamChunkTransform: (
       model: '',
       provider: 'google',
       choices:
-        parsedChunk.candidates?.map((generation, index) => {
+        parsedChunk.candidates?.map(generation => {
           let message: ProviderMessage = {
             role: 'assistant',
             content: ''
@@ -275,21 +452,23 @@ export const GoogleChatCompleteStreamChunkTransform: (
           } else if (generation.content.parts[0]?.functionCall) {
             message = {
               role: 'assistant',
-              tool_calls: [
-                {
-                  id: crypto.randomUUID(),
-                  type: 'function',
-                  index: 0,
-                  function: {
-                    name: generation.content.parts[0]
-                      ?.functionCall.name,
-                    arguments: JSON.stringify(
-                      generation.content.parts[0]
-                        ?.functionCall.args
-                    )
+              tool_calls: generation.content.parts.map(
+                (part, idx) => {
+                  if (part.functionCall) {
+                    return {
+                      index: idx,
+                      id: crypto.randomUUID(),
+                      type: 'function',
+                      function: {
+                        name: part.functionCall.name,
+                        arguments: JSON.stringify(
+                          part.functionCall.args
+                        )
+                      }
+                    };
                   }
                 }
-              ]
+              )
             };
           }
           return {
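Editor's note (illustration, not part of the diff): in the streaming transform each mapped tool call now carries its part position as index, which is how OpenAI-style streaming deltas distinguish parallel tool calls. A hypothetical chunk with two function-call parts would therefore yield (names are placeholders):

  // delta.tool_calls: [
  //   { index: 0, id: '<uuid>', type: 'function', function: { name: 'fn_a', arguments: '{}' } },
  //   { index: 1, id: '<uuid>', type: 'function', function: { name: 'fn_b', arguments: '{}' } }
  // ]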