diff --git a/packages/baseai/src/data/models.ts b/packages/baseai/src/data/models.ts
index 41759877..dc5334a6 100644
--- a/packages/baseai/src/data/models.ts
+++ b/packages/baseai/src/data/models.ts
@@ -421,43 +421,71 @@ export const modelsByProvider: ModelsByProviderInclCosts = {
 			id: 'llama-3.1-70b-versatile',
 			provider: GROQ,
 			promptCost: 0.59,
-			completionCost: 0.79
+			completionCost: 0.79,
+			toolSupport: {
+				toolChoice: true,
+				parallelToolCalls: true
+			}
 		},
 		{
 			id: 'llama-3.1-8b-instant',
 			provider: GROQ,
 			promptCost: 0.59,
-			completionCost: 0.79
+			completionCost: 0.79,
+			toolSupport: {
+				toolChoice: true,
+				parallelToolCalls: true
+			}
 		},
 		{
 			id: 'llama3-70b-8192',
 			provider: GROQ,
 			promptCost: 0.59,
-			completionCost: 0.79
+			completionCost: 0.79,
+			toolSupport: {
+				toolChoice: true,
+				parallelToolCalls: true
+			}
 		},
 		{
 			id: 'llama3-8b-8192',
 			provider: GROQ,
 			promptCost: 0.05,
-			completionCost: 0.1
+			completionCost: 0.1,
+			toolSupport: {
+				toolChoice: true,
+				parallelToolCalls: true
+			}
 		},
 		{
 			id: 'mixtral-8x7b-32768',
 			provider: GROQ,
 			promptCost: 0.27,
-			completionCost: 0.27
+			completionCost: 0.27,
+			toolSupport: {
+				toolChoice: true,
+				parallelToolCalls: false
+			}
 		},
 		{
 			id: 'gemma2-9b-it',
 			provider: GROQ,
 			promptCost: 0.2,
-			completionCost: 0.2
+			completionCost: 0.2,
+			toolSupport: {
+				toolChoice: true,
+				parallelToolCalls: false
+			}
 		},
 		{
 			id: 'gemma-7b-it',
 			provider: GROQ,
 			promptCost: 0.07,
-			completionCost: 0.07
+			completionCost: 0.07,
+			toolSupport: {
+				toolChoice: true,
+				parallelToolCalls: false
+			}
 		}
 	],
 	[GOOGLE]: [
diff --git a/packages/baseai/src/dev/llms/call-groq.ts b/packages/baseai/src/dev/llms/call-groq.ts
index 78089721..92baf4e1 100644
--- a/packages/baseai/src/dev/llms/call-groq.ts
+++ b/packages/baseai/src/dev/llms/call-groq.ts
@@ -5,6 +5,7 @@ import transformToProviderRequest from '../utils/provider-handlers/transfrom-to-
 import { applyJsonModeIfEnabled, handleLlmError } from './utils';
 import type { ModelParams } from 'types/providers';
 import type { Message } from 'types/pipe';
+import { addToolsToParams } from '../utils/add-tools-to-params';
 
 export async function callGroq({
 	pipe,
@@ -24,6 +25,7 @@ export async function callGroq({
 		baseURL: 'https://api.groq.com/openai/v1'
 	});
 	applyJsonModeIfEnabled(modelParams, pipe);
+	addToolsToParams(modelParams, pipe);
 
 	// Transform params according to provider's format
 	const transformedRequestParams = transformToProviderRequest({
diff --git a/packages/baseai/src/dev/providers/groq/chatComplete.ts b/packages/baseai/src/dev/providers/groq/chatComplete.ts
index f5d125d3..fa52f2ca 100644
--- a/packages/baseai/src/dev/providers/groq/chatComplete.ts
+++ b/packages/baseai/src/dev/providers/groq/chatComplete.ts
@@ -38,5 +38,17 @@ export const GroqChatCompleteConfig: ProviderConfig = {
 		default: 1,
 		max: 1,
 		min: 1
+	},
+	parallel_tool_calls: {
+		param: 'parallel_tool_calls',
+		default: false
+	},
+	tool_choice: {
+		param: 'tool_choice',
+		default: 'none'
+	},
+	tools: {
+		param: 'tools',
+		default: []
 	}
 };
diff --git a/packages/baseai/src/dev/utils/add-tools-to-params.ts b/packages/baseai/src/dev/utils/add-tools-to-params.ts
index 6cbef242..e4d144e8 100644
--- a/packages/baseai/src/dev/utils/add-tools-to-params.ts
+++ b/packages/baseai/src/dev/utils/add-tools-to-params.ts
@@ -1,22 +1,19 @@
-import { getSupportedToolSettings, hasToolSupport } from './has-tool-support';
+import { hasModelToolSupport } from './has-tool-support';
 import type { ModelParams } from 'types/providers';
 
 export function addToolsToParams(modelParams: ModelParams, pipe: any) {
 	if (!pipe.functions.length) return;
 
 	// Check if the model supports tool calls
-	const hasToolCallSupport = hasToolSupport({
-		modelName: pipe.model.name,
-		provider: pipe.model.provider
-	});
+	const { hasToolChoiceSupport, hasParallelToolCallSupport } =
+		hasModelToolSupport({
+			modelName: pipe.model.name,
+			provider: pipe.model.provider
+		});
 
-	if (hasToolCallSupport) {
-		const { hasParallelToolCallSupport, hasToolChoiceSupport } =
-			getSupportedToolSettings({
-				modelName: pipe.model.name,
-				provider: pipe.model.provider
-			});
+	const hasToolSupport = hasToolChoiceSupport || hasParallelToolCallSupport;
 
+	if (hasToolSupport) {
 		if (hasParallelToolCallSupport) {
 			modelParams.parallel_tool_calls = pipe.model.parallel_tool_calls;
 		}
diff --git a/packages/baseai/src/dev/utils/has-tool-support.ts b/packages/baseai/src/dev/utils/has-tool-support.ts
index b7537a56..04c0f0b9 100644
--- a/packages/baseai/src/dev/utils/has-tool-support.ts
+++ b/packages/baseai/src/dev/utils/has-tool-support.ts
@@ -1,6 +1,6 @@
 import { modelsByProvider } from '@/data/models';
 
-export function hasToolSupport({
+export function hasModelToolSupport({
 	provider,
 	modelName
 }: {
@@ -10,23 +10,7 @@ export function hasToolSupport({
 	const toolSupportedModels = modelsByProvider[provider].filter(
 		model => model.toolSupport
 	);
-	const hasToolCallSupport = toolSupportedModels
-		.flatMap(model => model.id)
-		.includes(modelName);
-	return hasToolCallSupport;
-}
-
-export function getSupportedToolSettings({
-	provider,
-	modelName
-}: {
-	modelName: string;
-	provider: string;
-}) {
-	const toolSupportedModels = modelsByProvider[provider].filter(
-		model => model.toolSupport
-	);
 	const providerModel = toolSupportedModels.find(
 		model => model.id === modelName
 	);
 
diff --git a/packages/baseai/src/dev/utils/provider-handlers/provider-response-handler.ts b/packages/baseai/src/dev/utils/provider-handlers/provider-response-handler.ts
index 0a2a7dc0..8f0edb63 100644
--- a/packages/baseai/src/dev/utils/provider-handlers/provider-response-handler.ts
+++ b/packages/baseai/src/dev/utils/provider-handlers/provider-response-handler.ts
@@ -1,10 +1,10 @@
-import type { ModelParams } from '@/types/providers';
 import {
 	handleNonStreamingMode,
 	handleStreamingMode
 } from './response-handler-utils';
 import Providers from '@/dev/providers';
 import { dlog } from '../dlog';
+import type { ModelParams } from 'types/providers';
 
 /**
  * Handles various types of responses based on the specified parameters
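
Reviewer-facing sketch (not part of the patch): how the renamed hasModelToolSupport helper is expected to behave once the Groq entries in models.ts carry toolSupport metadata. The return shape is inferred from the destructuring in add-tools-to-params.ts, and the 'groq' string is an assumed runtime value of the GROQ constant; treat this as an illustration under those assumptions, not code from the diff.

import { hasModelToolSupport } from './has-tool-support';

// One call now answers both questions that the old hasToolSupport() /
// getSupportedToolSettings() pair answered separately.
const { hasToolChoiceSupport, hasParallelToolCallSupport } =
	hasModelToolSupport({
		provider: 'groq', // assumed value of the GROQ constant
		modelName: 'mixtral-8x7b-32768'
	});

// Expected per the models.ts entries above: mixtral-8x7b-32768 supports
// tool_choice but not parallel tool calls.
console.log(hasToolChoiceSupport); // true
console.log(hasParallelToolCallSupport); // false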