Skip to content

Commit b4e5ac0

Browse files
authored
📦 NEW: Google tool call support (#53)
1 parent cba3804 commit b4e5ac0

File tree

5 files changed

+267
-47
lines changed

5 files changed

+267
-47
lines changed

‎packages/baseai/src/dev/llms/call-google.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { GOOGLE } from '../data/models';
55
import { applyJsonModeIfEnabledForGoogle, handleLlmError } from './utils';
66
import type { ModelParams } from 'types/providers';
77
import type { Message } from 'types/pipe';
8+
import { addToolsToParams } from '../utils/add-tools-to-params';
89

910
export async function callGoogle({
1011
pipe,
@@ -19,6 +20,7 @@ export async function callGoogle({
1920
}) {
2021
try {
2122
const modelParams = buildModelParams(pipe, stream, messages);
23+
addToolsToParams(modelParams, pipe);
2224

2325
// Transform params according to provider's format
2426
const transformedRequestParams = transformToProviderRequest({

‎packages/baseai/src/dev/providers/google/chatComplete.ts

Lines changed: 224 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import {
22
generateErrorResponse,
3-
generateInvalidProviderResponseError
3+
generateInvalidProviderResponseError,
4+
getMimeType
45
} from '../utils';
56
import { GOOGLE } from '@/dev/data/models';
7+
import type { ToolCall, ToolChoice } from 'types/pipe';
68
import type {
79
ChatCompletionResponse,
810
ContentType,
911
ErrorResponse,
12+
MessageRole,
1013
ModelParams,
1114
ProviderConfig,
1215
ProviderMessage
@@ -32,6 +35,76 @@ const transformGenerationConfig = (params: ModelParams) => {
3235
return generationConfig;
3336
};
3437

38+
// Role names accepted by the Google Generative AI API; OpenAI-style roles
// are mapped onto these before a request is built.
export type GoogleMessageRole = 'user' | 'model' | 'function';

// Message part carrying a model-issued function (tool) call.
interface GoogleFunctionCallMessagePart {
	functionCall: GoogleGenerateFunctionCall;
}

// Message part carrying the result of a previously issued function call,
// sent back to the model on the next turn.
interface GoogleFunctionResponseMessagePart {
	functionResponse: {
		// Name of the function whose result this is.
		name: string;
		response: {
			name?: string;
			// Stringified tool output returned to the model.
			content: string;
		};
	};
}

// A single content part of a Gemini message: plain text, a function call,
// or a function response.
type GoogleMessagePart =
	| GoogleFunctionCallMessagePart
	| GoogleFunctionResponseMessagePart
	| { text: string };

// One turn in Gemini's `contents` request array.
export interface GoogleMessage {
	role: GoogleMessageRole;
	parts: GoogleMessagePart[];
}

// Shape of Gemini's `tool_config` request field.
export interface GoogleToolConfig {
	function_calling_config: {
		// AUTO / ANY / NONE; undefined when tool_choice could not be mapped.
		mode: GoogleToolChoiceType | undefined;
		// Restricts which functions may be called when tool_choice names one.
		allowed_function_names?: string[];
	};
}
70+
71+
export const transformOpenAIRoleToGoogleRole = (
72+
role: MessageRole
73+
): GoogleMessageRole => {
74+
switch (role) {
75+
case 'assistant':
76+
return 'model';
77+
case 'tool':
78+
return 'function';
79+
// Not all gemini models support system role
80+
case 'system':
81+
return 'user';
82+
// user is the default role
83+
default:
84+
return role;
85+
}
86+
};
87+
88+
type GoogleToolChoiceType = 'AUTO' | 'ANY' | 'NONE';
89+
90+
export const transformToolChoiceForGemini = (
91+
tool_choice: ToolChoice
92+
): GoogleToolChoiceType | undefined => {
93+
if (typeof tool_choice === 'object' && tool_choice.type === 'function')
94+
return 'ANY';
95+
if (typeof tool_choice === 'string') {
96+
switch (tool_choice) {
97+
case 'auto':
98+
return 'AUTO';
99+
case 'none':
100+
return 'NONE';
101+
case 'required':
102+
return 'ANY';
103+
}
104+
}
105+
return undefined;
106+
};
107+
35108
export const GoogleChatCompleteConfig: ProviderConfig = {
36109
model: {
37110
param: 'model',
@@ -42,36 +115,100 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
42115
param: 'contents',
43116
default: '',
44117
transform: (params: ModelParams) => {
45-
const messages: { role: string; parts: { text: string }[] }[] = [];
118+
const messages: GoogleMessage[] = [];
119+
let lastRole: GoogleMessageRole | undefined;
46120

47121
params.messages?.forEach((message: ProviderMessage) => {
48-
const role = message.role === 'assistant' ? 'model' : 'user';
122+
const role = transformOpenAIRoleToGoogleRole(message.role);
49123
let parts = [];
50-
if (typeof message.content === 'string') {
124+
125+
if (message.role === 'assistant' && message.tool_calls) {
126+
message.tool_calls.forEach((tool_call: ToolCall) => {
127+
parts.push({
128+
functionCall: {
129+
name: tool_call.function.name,
130+
args: JSON.parse(tool_call.function.arguments)
131+
}
132+
});
133+
});
134+
} else if (
135+
message.role === 'tool' &&
136+
typeof message.content === 'string'
137+
) {
51138
parts.push({
52-
text: message.content
139+
functionResponse: {
140+
name: message.name ?? 'lb-random-tool-name',
141+
response: {
142+
content: message.content
143+
}
144+
}
53145
});
54-
}
55-
56-
if (message.content && typeof message.content === 'object') {
146+
} else if (
147+
message.content &&
148+
typeof message.content === 'object'
149+
) {
57150
message.content.forEach((c: ContentType) => {
58151
if (c.type === 'text') {
59152
parts.push({
60153
text: c.text
61154
});
62155
}
63156
if (c.type === 'image_url') {
64-
parts.push({
65-
inlineData: {
66-
mimeType: 'image/jpeg',
67-
data: c.image_url?.url
68-
}
69-
});
157+
const { url } = c.image_url || {};
158+
if (!url) return;
159+
160+
// Handle different types of image URLs
161+
if (url.startsWith('data:')) {
162+
const [mimeTypeWithPrefix, base64Image] =
163+
url.split(';base64,');
164+
const mimeType =
165+
mimeTypeWithPrefix.split(':')[1];
166+
167+
parts.push({
168+
inlineData: {
169+
mimeType: mimeType,
170+
data: base64Image
171+
}
172+
});
173+
} else if (
174+
url.startsWith('gs://') ||
175+
url.startsWith('https://') ||
176+
url.startsWith('http://')
177+
) {
178+
parts.push({
179+
fileData: {
180+
mimeType: getMimeType(url),
181+
fileUri: url
182+
}
183+
});
184+
} else {
185+
parts.push({
186+
inlineData: {
187+
mimeType: 'image/jpeg',
188+
data: c.image_url?.url
189+
}
190+
});
191+
}
70192
}
71193
});
194+
} else if (typeof message.content === 'string') {
195+
parts.push({
196+
text: message.content
197+
});
72198
}
73199

74-
messages.push({ role, parts });
200+
// Combine consecutive messages if they are from the same role
201+
// This takes care of the "Please ensure that multiturn requests alternate between user and model.
202+
// Also possible fix for "Please ensure that function call turn comes immediately after a user turn or after a function response turn." in parallel tool calls
203+
const shouldCombineMessages =
204+
lastRole === role && !params.model?.includes('vision');
205+
206+
if (shouldCombineMessages) {
207+
messages[messages.length - 1].parts.push(...parts);
208+
} else {
209+
messages.push({ role, parts });
210+
}
211+
lastRole = role;
75212
});
76213
return messages;
77214
}
@@ -108,6 +245,36 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
108245
});
109246
return [{ functionDeclarations }];
110247
}
248+
},
249+
tool_choice: {
250+
param: 'tool_config',
251+
default: '',
252+
transform: (params: ModelParams) => {
253+
if (params.tool_choice) {
254+
const allowedFunctionNames: string[] = [];
255+
// If tool_choice is an object and type is function, add the function name to allowedFunctionNames
256+
if (
257+
typeof params.tool_choice === 'object' &&
258+
params.tool_choice.type === 'function'
259+
) {
260+
allowedFunctionNames.push(params.tool_choice.function.name);
261+
}
262+
const toolConfig: GoogleToolConfig = {
263+
function_calling_config: {
264+
mode: transformToolChoiceForGemini(params.tool_choice)
265+
}
266+
};
267+
// TODO: @msaaddev I think we can't have more than one function in tool_choice
268+
// but this will also handle the case if we have more than one function in tool_choice
269+
270+
// If tool_choice has functions, add the function names to allowedFunctionNames
271+
if (allowedFunctionNames.length > 0) {
272+
toolConfig.function_calling_config.allowed_function_names =
273+
allowedFunctionNames;
274+
}
275+
return toolConfig;
276+
}
277+
}
111278
}
112279
};
113280

@@ -146,6 +313,11 @@ interface GoogleGenerateContentResponse {
146313
probability: string;
147314
}[];
148315
};
316+
usageMetadata: {
317+
promptTokenCount: number;
318+
candidatesTokenCount: number;
319+
totalTokenCount: number;
320+
};
149321
}
150322

151323
export const GoogleChatCompleteResponseTransform: (
@@ -170,7 +342,6 @@ export const GoogleChatCompleteResponseTransform: (
170342
GOOGLE
171343
);
172344
}
173-
174345
if ('candidates' in response) {
175346
return {
176347
id: crypto.randomUUID(),
@@ -179,7 +350,7 @@ export const GoogleChatCompleteResponseTransform: (
179350
model: 'Unknown',
180351
provider: GOOGLE,
181352
choices:
182-
response.candidates?.map((generation, index) => {
353+
response.candidates?.map(generation => {
183354
// In blocking mode: Google AI does not return content if response > max output tokens param
184355
// Test it by asking a big response while keeping maxtokens low ~ 50
185356
if (
@@ -203,28 +374,34 @@ export const GoogleChatCompleteResponseTransform: (
203374
} else if (generation.content?.parts[0]?.functionCall) {
204375
message = {
205376
role: 'assistant',
206-
tool_calls: [
207-
{
208-
id: crypto.randomUUID(),
209-
type: 'function',
210-
function: {
211-
name: generation.content.parts[0]
212-
?.functionCall.name,
213-
arguments: JSON.stringify(
214-
generation.content.parts[0]
215-
?.functionCall.args
216-
)
217-
}
377+
content: null,
378+
tool_calls: generation.content.parts.map(part => {
379+
if (part.functionCall) {
380+
return {
381+
id: crypto.randomUUID(),
382+
type: 'function',
383+
function: {
384+
name: part.functionCall.name,
385+
arguments: JSON.stringify(
386+
part.functionCall.args
387+
)
388+
}
389+
};
218390
}
219-
]
391+
})
220392
};
221393
}
222394
return {
223395
message: message,
224396
index: generation.index,
225397
finish_reason: generation.finishReason
226398
};
227-
}) ?? []
399+
}) ?? [],
400+
usage: {
401+
prompt_tokens: response.usageMetadata.promptTokenCount,
402+
completion_tokens: response.usageMetadata.candidatesTokenCount,
403+
total_tokens: response.usageMetadata.totalTokenCount
404+
}
228405
};
229406
}
230407

@@ -262,7 +439,7 @@ export const GoogleChatCompleteStreamChunkTransform: (
262439
model: '',
263440
provider: 'google',
264441
choices:
265-
parsedChunk.candidates?.map((generation, index) => {
442+
parsedChunk.candidates?.map(generation => {
266443
let message: ProviderMessage = {
267444
role: 'assistant',
268445
content: ''
@@ -275,21 +452,23 @@ export const GoogleChatCompleteStreamChunkTransform: (
275452
} else if (generation.content.parts[0]?.functionCall) {
276453
message = {
277454
role: 'assistant',
278-
tool_calls: [
279-
{
280-
id: crypto.randomUUID(),
281-
type: 'function',
282-
index: 0,
283-
function: {
284-
name: generation.content.parts[0]
285-
?.functionCall.name,
286-
arguments: JSON.stringify(
287-
generation.content.parts[0]
288-
?.functionCall.args
289-
)
455+
tool_calls: generation.content.parts.map(
456+
(part, idx) => {
457+
if (part.functionCall) {
458+
return {
459+
index: idx,
460+
id: crypto.randomUUID(),
461+
type: 'function',
462+
function: {
463+
name: part.functionCall.name,
464+
arguments: JSON.stringify(
465+
part.functionCall.args
466+
)
467+
}
468+
};
290469
}
291470
}
292-
]
471+
)
293472
};
294473
}
295474
return {

0 commit comments

Comments
 (0)