diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/common/conversation_complete.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/common/conversation_complete.ts
index e2ba6ce763554..222e3c5478196 100644
--- a/x-pack/platform/plugins/shared/observability_ai_assistant/common/conversation_complete.ts
+++ b/x-pack/platform/plugins/shared/observability_ai_assistant/common/conversation_complete.ts
@@ -112,7 +112,6 @@ export enum ChatCompletionErrorCode {
   InternalError = 'internalError',
   NotFoundError = 'notFoundError',
   TokenLimitReachedError = 'tokenLimitReachedError',
-  FunctionNotFoundError = 'functionNotFoundError',
   FunctionLimitExceededError = 'functionLimitExceededError',
 }
 
@@ -123,9 +122,6 @@ interface ErrorMetaAttributes {
     tokenLimit?: number;
     tokenCount?: number;
   };
-  [ChatCompletionErrorCode.FunctionNotFoundError]: {
-    name: string;
-  };
   [ChatCompletionErrorCode.FunctionLimitExceededError]: {};
 }
 
@@ -170,13 +166,6 @@ export function createInternalServerError(
   return new ChatCompletionError(ChatCompletionErrorCode.InternalError, originalErrorMessage);
 }
 
-export function createFunctionNotFoundError(name: string) {
-  return new ChatCompletionError(
-    ChatCompletionErrorCode.FunctionNotFoundError,
-    `Function "${name}" called but was not available`
-  );
-}
-
 export function createFunctionLimitExceededError() {
   return new ChatCompletionError(
     ChatCompletionErrorCode.FunctionLimitExceededError,
@@ -193,15 +182,6 @@ export function isTokenLimitReachedError(
   error: Error
 ): error is ChatCompletionError<ChatCompletionErrorCode.TokenLimitReachedError> {
   return (
     error instanceof ChatCompletionError &&
     error.code === ChatCompletionErrorCode.TokenLimitReachedError
   );
 }
 
-export function isFunctionNotFoundError(
-  error: Error
-): error is ChatCompletionError<ChatCompletionErrorCode.FunctionNotFoundError> {
-  return (
-    error instanceof ChatCompletionError &&
-    error.code === ChatCompletionErrorCode.FunctionNotFoundError
-  );
-}
-
 export function isChatCompletionError(error: Error): error is ChatCompletionError {
   return error instanceof ChatCompletionError;
 }
diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/common/index.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/common/index.ts
index 831415c1d6e64..453f155255b7e 100644
--- a/x-pack/platform/plugins/shared/observability_ai_assistant/common/index.ts
+++ b/x-pack/platform/plugins/shared/observability_ai_assistant/common/index.ts
@@ -39,7 +39,6 @@ export {
   createInternalServerError,
   isTokenLimitReachedError,
   isChatCompletionError,
-  createFunctionNotFoundError,
 } from './conversation_complete';
 
 export {
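Note: with ChatCompletionErrorCode.FunctionNotFoundError and its helpers deleted, call sites that matched on the plugin-local guard move to the shared guard from @kbn/inference-common. A minimal sketch of that consumer-side migration, assuming only what this diff shows (the guard name, and the tool name carried on error.meta.name):

    import { isToolNotFoundError } from '@kbn/inference-common';

    function describeToolError(error: Error): string {
      if (isToolNotFoundError(error)) {
        // The offending tool name rides along on the error metadata,
        // as read by the catch_function_not_found_error operator below.
        return `Tool not found: ${error.meta.name}`;
      }
      return error.message;
    }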
diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/chat_function_client/index.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/chat_function_client/index.ts
index 03edc5c74033a..943868bc7e0f3 100644
--- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/chat_function_client/index.ts
+++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/chat_function_client/index.ts
@@ -4,11 +4,11 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
 
-/* eslint-disable max-classes-per-file*/
-import Ajv, { type ErrorObject, type ValidateFunction } from 'ajv';
+import Ajv, { type ValidateFunction } from 'ajv';
 import { compact, keyBy } from 'lodash';
 import type { Logger } from '@kbn/logging';
+import { createToolValidationError } from '@kbn/inference-plugin/common/chat_complete/errors';
 import { type FunctionResponse } from '../../../common/functions/types';
 import type { Message, ObservabilityAIAssistantScreenContextRequest } from '../../../common/types';
 import { filterFunctionDefinitions } from '../../../common/utils/filter_function_definitions';
@@ -22,12 +22,6 @@ import type {
 } from '../types';
 import { registerGetDataOnScreenFunction } from '../../functions/get_data_on_screen';
 
-export class FunctionArgsValidationError extends Error {
-  constructor(public readonly errors: ErrorObject[]) {
-    super('Function arguments are invalid');
-  }
-}
-
 const ajv = new Ajv({
   strict: false,
 });
@@ -71,7 +65,12 @@ export class ChatFunctionClient {
       const result = validator(parameters);
 
       if (!result) {
-        throw new FunctionArgsValidationError(validator.errors!);
+        throw createToolValidationError(`Tool call arguments for ${name} were invalid`, {
+          name,
+          errorsText: validator.errors?.map((error) => error.message).join(', '),
+          arguments: JSON.stringify(parameters),
+          toolCalls: [],
+        });
       }
     }
 
diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts
index 475c92a22f7e7..d9ad40716e74f 100644
--- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts
+++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts
@@ -67,7 +67,6 @@ import type { ChatFunctionClient } from '../chat_function_client';
 import type { KnowledgeBaseService, RecalledEntry } from '../knowledge_base_service';
 import { getAccessQuery } from '../util/get_access_query';
 import { getSystemMessageFromInstructions } from '../util/get_system_message_from_instructions';
-import { failOnNonExistingFunctionCall } from './operators/fail_on_non_existing_function_call';
 import { getContextFunctionRequestIfNeeded } from './get_context_function_request_if_needed';
 import { continueConversation } from './operators/continue_conversation';
 import { convertInferenceEventsToStreamingEvents } from './operators/convert_inference_events_to_streaming_events';
@@ -578,7 +577,6 @@ export class ObservabilityAIAssistantClient {
         })
       ).pipe(
         convertInferenceEventsToStreamingEvents(),
-        failOnNonExistingFunctionCall({ functions }),
         tap((event) => {
           if (event.type === StreamingChatResponseEventType.ChatCompletionChunk) {
             this.dependencies.logger.trace(
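Note: the ChatFunctionClient now throws the inference plugin's validation error instead of the local FunctionArgsValidationError. A runnable sketch of the underlying ajv flow that produces the errorsText joined above (the schema here is illustrative, not one of the plugin's function schemas):

    import Ajv from 'ajv';

    const ajv = new Ajv({ strict: false });
    const validate = ajv.compile({
      type: 'object',
      properties: { query: { type: 'string' } },
      required: ['query'],
      additionalProperties: false,
    });

    if (!validate({ foo: 'bar' })) {
      // After a failed run, ajv fills `validate.errors` with ErrorObject[];
      // the client joins their `message` fields into `errorsText`.
      const errorsText = validate.errors?.map((error) => error.message).join(', ');
      console.log(errorsText); // e.g. "must have required property 'query', ..."
    }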
diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/catch_function_not_found_error.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/catch_function_not_found_error.ts
index 61bc81a26fe1f..774f42246c7e6 100644
--- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/catch_function_not_found_error.ts
+++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/catch_function_not_found_error.ts
@@ -7,16 +7,16 @@
 
 import type { OperatorFunction } from 'rxjs';
 import { catchError, filter, of, share, throwError } from 'rxjs';
+import { v4 } from 'uuid';
 import { i18n } from '@kbn/i18n';
+import { isToolNotFoundError } from '@kbn/inference-common';
 import { MessageRole } from '../../../../common';
 import type {
   ChatCompletionChunkEvent,
+  MessageAddEvent,
   MessageOrChatEvent,
 } from '../../../../common/conversation_complete';
-import {
-  isFunctionNotFoundError,
-  StreamingChatResponseEventType,
-} from '../../../../common/conversation_complete';
+import { StreamingChatResponseEventType } from '../../../../common/conversation_complete';
 import { emitWithConcatenatedMessage } from '../../../../common/utils/emit_with_concatenated_message';
 
 function appendFunctionLimitExceededErrorMessageToAssistantResponse(): OperatorFunction<
@@ -75,13 +75,30 @@ export function catchFunctionNotFoundError(
   return shared$.pipe(
     catchError((error) => {
-      if (isFunctionNotFoundError(error)) {
+      if (isToolNotFoundError(error)) {
         if (functionLimitExceeded) {
           return chunksWithoutErrors$.pipe(
             appendFunctionLimitExceededErrorMessageToAssistantResponse()
           );
         }
-        return chunksWithoutErrors$.pipe(emitWithConcatenatedMessage());
+        // Instead of throwing an error, return a message with the function name, to be handled by the function client
+        const simpleMessage: MessageAddEvent = {
+          type: StreamingChatResponseEventType.MessageAdd as const,
+          id: v4(),
+          message: {
+            '@timestamp': new Date().toISOString(),
+            message: {
+              content: '',
+              role: MessageRole.Assistant,
+              function_call: {
+                name: error.meta.name,
+                arguments: '',
+                trigger: MessageRole.Assistant,
+              },
+            },
+          },
+        };
+        return of(simpleMessage);
       }
       return throwError(() => error);
     })
   );
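Note: rather than failing the stream, the operator above converts a tool-not-found error into a synthetic assistant message whose function_call names the missing tool. Downstream, continueConversation (next file) sees that name, finds no matching registered function, and answers with a server-side function response error the LLM can read and retry from. A condensed sketch of that round trip, using only names that appear in this PR:

    // In the operator: error -> synthetic assistant message with
    //   function_call: { name: error.meta.name, arguments: '' }
    // In continueConversation: unknown name -> error payload for the LLM
    return of(
      createServerSideFunctionResponseError({
        name: functionCallName,
        error: createToolNotFoundError(functionCallName),
      })
    );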
diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/continue_conversation.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/continue_conversation.ts
index 2bb423ab6daa4..89bcc29ac1172 100644
--- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/continue_conversation.ts
+++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/continue_conversation.ts
@@ -23,16 +23,17 @@ import {
   throwError,
 } from 'rxjs';
 import { withExecuteToolSpan } from '@kbn/inference-tracing';
+import { createToolNotFoundError } from '@kbn/inference-plugin/common/chat_complete/errors';
 import type { AnalyticsServiceStart } from '@kbn/core/server';
 import type { Connector } from '@kbn/actions-plugin/server';
 import type { AssistantScope } from '@kbn/ai-assistant-common';
+import { isToolValidationError } from '@kbn/inference-common';
 import { getInferenceConnectorInfo } from '../../../../common/utils/get_inference_connector';
 import type { ToolCallEvent } from '../../../analytics/tool_call';
 import { toolCallEventType } from '../../../analytics/tool_call';
 import type { Message, CompatibleJSONSchema, MessageAddEvent } from '../../../../common';
 import {
   CONTEXT_FUNCTION_NAME,
-  createFunctionNotFoundError,
   MessageRole,
   StreamingChatResponseEventType,
 } from '../../../../common';
@@ -111,6 +112,15 @@ export function executeFunctionAndCatchError({
     catchError((error) => {
       span?.recordException(error);
       logger.error(`Encountered error running function ${name}: ${JSON.stringify(error)}`);
+
+      if (isToolValidationError(error)) {
+        return of(
+          createFunctionResponseMessage({
+            name,
+            content: { message: error.message, errors: error.meta },
+          })
+        );
+      }
       // We want to catch the error only when a promise occurs
       // if it occurs in the Observable, we cannot easily recover
       // from it because the function may have already emitted
@@ -321,7 +331,7 @@ export function continueConversation({
       return of(
         createServerSideFunctionResponseError({
           name: functionCallName,
-          error: createFunctionNotFoundError(functionCallName),
+          error: createToolNotFoundError(functionCallName),
         })
       );
     }
diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/fail_on_non_existing_function_call.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/fail_on_non_existing_function_call.ts
deleted file mode 100644
index 678c243b3cae5..0000000000000
--- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/fail_on_non_existing_function_call.ts
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-import type { Observable } from 'rxjs';
-import { ignoreElements, last, merge, shareReplay, tap } from 'rxjs';
-import type { FunctionDefinition } from '../../../../common';
-import { createFunctionNotFoundError } from '../../../../common';
-import type { ChatEvent } from '../../../../common/conversation_complete';
-import { concatenateChatCompletionChunks } from '../../../../common/utils/concatenate_chat_completion_chunks';
-
-export function failOnNonExistingFunctionCall({
-  functions,
-}: {
-  functions?: Array<Pick<FunctionDefinition, 'name'>>;
-}) {
-  return (source$: Observable<ChatEvent>) => {
-    const shared$ = source$.pipe(shareReplay());
-
-    return merge(
-      shared$,
-      shared$.pipe(
-        concatenateChatCompletionChunks(),
-        last(),
-        tap((event) => {
-          if (
-            event.message.function_call.name &&
-            functions?.find((fn) => fn.name === event.message.function_call.name) === undefined
-          ) {
-            throw createFunctionNotFoundError(event.message.function_call.name);
-          }
-        }),
-        ignoreElements()
-      )
-    );
-  };
-}
diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/util/catch_function_limit_exceeded_error.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/util/catch_function_limit_exceeded_error.ts
deleted file mode 100644
index ece268ec1737b..0000000000000
--- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/util/catch_function_limit_exceeded_error.ts
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-import { i18n } from '@kbn/i18n';
-import type { OperatorFunction } from 'rxjs';
-import { catchError, filter, of, shareReplay, throwError } from 'rxjs';
-import type { ChatCompletionChunkEvent } from '../../../common';
-import { MessageRole, StreamingChatResponseEventType } from '../../../common';
-import type { MessageOrChatEvent } from '../../../common/conversation_complete';
-import { isFunctionNotFoundError } from '../../../common/conversation_complete';
-import { emitWithConcatenatedMessage } from '../../../common/utils/emit_with_concatenated_message';
-
-export function catchFunctionLimitExceededError(): OperatorFunction<
-  MessageOrChatEvent,
-  MessageOrChatEvent
-> {
-  return (source$) => {
-    const shared$ = source$.pipe(shareReplay());
-    const chunksWithoutErrors$ = shared$.pipe(
-      catchError(() => of()),
-      shareReplay()
-    );
-
-    return shared$.pipe(
-      catchError((error) => {
-        if (isFunctionNotFoundError(error)) {
-          const withInjectedErrorMessage$ = chunksWithoutErrors$.pipe(
-            filter(
-              (msg): msg is ChatCompletionChunkEvent =>
-                msg.type === StreamingChatResponseEventType.ChatCompletionChunk
-            ),
-            emitWithConcatenatedMessage(async (concatenatedMessage) => {
-              return {
-                ...concatenatedMessage,
-                message: {
-                  ...concatenatedMessage.message,
-                  content: `${concatenatedMessage.message.content}\n\n${i18n.translate(
-                    'xpack.observabilityAiAssistant.functionCallLimitExceeded',
-                    {
-                      defaultMessage:
-                        '\n\nNote: the Assistant tried to call a function, even though the limit was exceeded',
-                    }
-                  )}`,
-                  function_call: {
-                    name: '',
-                    arguments: '',
-                    trigger: MessageRole.Assistant,
-                  },
-                },
-              };
-            })
-          );
-
-          return withInjectedErrorMessage$;
-        }
-        return throwError(() => error);
-      })
-    );
-  };
-}
diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/tsconfig.json b/x-pack/platform/plugins/shared/observability_ai_assistant/tsconfig.json
index cc425234fa43f..543a735137218 100644
--- a/x-pack/platform/plugins/shared/observability_ai_assistant/tsconfig.json
+++ b/x-pack/platform/plugins/shared/observability_ai_assistant/tsconfig.json
@@ -62,7 +62,7 @@
     "@kbn/product-doc-base-plugin",
     "@kbn/inference-endpoint-plugin",
     "@kbn/spaces-utils",
-    "@kbn/usage-collection-plugin"
+    "@kbn/usage-collection-plugin",
   ],
   "exclude": ["target/**/*"]
 }
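Note: the new coverage below drives the FTR LlmProxy harness. For orientation, this is how the spec uses it (calls exactly as they appear in the tests below; the helpers live in the test utils, not in the plugin itself):

    // Stand up a fake LLM and point a proxy action connector at it
    proxy = await createLlmProxy(log);
    connectorId = await observabilityAIAssistantAPIClient.createProxyActionConnector({
      port: proxy.getPort(),
    });

    // Queue a canned assistant reply...
    void proxy.interceptWithResponse('Hello from LLM Proxy');

    // ...or make the fake LLM issue a tool call
    void proxy.interceptWithFunctionRequest({
      name: 'unknown_tool',
      arguments: () => JSON.stringify({ foo: 'bar' }),
      when: () => true,
    });

    await proxy.waitForAllInterceptorsToHaveBeenCalled();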
diff --git a/x-pack/solutions/observability/test/api_integration_deployment_agnostic/apis/ai_assistant/complete/complete.spec.ts b/x-pack/solutions/observability/test/api_integration_deployment_agnostic/apis/ai_assistant/complete/complete.spec.ts
index 80758bf782ee6..548e7ff5cd677 100644
--- a/x-pack/solutions/observability/test/api_integration_deployment_agnostic/apis/ai_assistant/complete/complete.spec.ts
+++ b/x-pack/solutions/observability/test/api_integration_deployment_agnostic/apis/ai_assistant/complete/complete.spec.ts
@@ -31,6 +31,9 @@ import {
   clearConversations,
   decodeEvents,
   getConversationCreatedEvent,
+  invokeChatCompleteWithFunctionRequest,
+  getMessageAddedEvents,
+  chatComplete,
 } from '../utils/conversation';
 
 interface NonStreamingChatResponse {
@@ -567,6 +570,263 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext) {
         });
       });
 
+      describe('when calling a tool', () => {
+        describe('when calling a tool that is not available', () => {
+          before(async () => {
+            proxy.close();
+            proxy = await createLlmProxy(log);
+            connectorId = await observabilityAIAssistantAPIClient.createProxyActionConnector({
+              port: proxy.getPort(),
+            });
+          });
+          after(async () => {
+            proxy.close();
+            await observabilityAIAssistantAPIClient.deleteActionConnector({
+              actionId: connectorId,
+            });
+          });
+          describe('when invoking chat complete with the tool request', function () {
+            let events: MessageAddEvent[];
+
+            before(async () => {
+              void proxy.interceptWithResponse('Hello from LLM Proxy');
+
+              const responseBody = await invokeChatCompleteWithFunctionRequest({
+                connectorId,
+                observabilityAIAssistantAPIClient,
+                functionCall: {
+                  name: 'unknown_tool',
+                  trigger: MessageRole.User,
+                  arguments: JSON.stringify({
+                    foo: 'bar',
+                  }),
+                },
+              });
+
+              await proxy.waitForAllInterceptorsToHaveBeenCalled();
+
+              events = getMessageAddedEvents(responseBody);
+            });
+
+            it('returns 2 message add events', () => {
+              expect(events.length).to.be(2);
+            });
+
+            it('the first message add event has the tool name and an error', () => {
+              expect(events[0].message.message.name).to.be('unknown_tool');
+              expect(events[0].message.message.content).to.contain('toolNotFoundError');
+            });
+
+            it('the second message add event interacts with the LLM to fix the error', () => {
+              expect(events[1].message.message.content).to.be('Hello from LLM Proxy');
+            });
+          });
+
+          describe('when the LLM calls a tool that is not available', function () {
+            let messageAddedEvents: MessageAddEvent[];
+            let fullConversation: Conversation;
+            before(async () => {
+              void proxy.interceptTitle('LLM-generated title');
+
+              void proxy.interceptWithFunctionRequest({
+                name: 'unknown_tool',
+                arguments: () =>
+                  JSON.stringify({
+                    foo: 'bar',
+                  }),
+                when: () => true,
+              });
+
+              void proxy.interceptWithResponse('Hello from LLM Proxy, again!');
+
+              const { messageAddedEvents: messageAddedEventsResponse, conversationCreateEvent } =
+                await chatComplete({
+                  userPrompt: 'user prompt test spec',
+                  connectorId,
+                  persist: true,
+                  observabilityAIAssistantAPIClient,
+                });
+              messageAddedEvents = messageAddedEventsResponse;
+
+              await proxy.waitForAllInterceptorsToHaveBeenCalled();
+
+              const conversationId = conversationCreateEvent.conversation.id;
+              const conversationResponse = await observabilityAIAssistantAPIClient.editor({
+                endpoint: 'GET /internal/observability_ai_assistant/conversation/{conversationId}',
+                params: {
+                  path: {
+                    conversationId,
+                  },
+                },
+              });
+              expect(conversationResponse.status).to.be(200);
+              fullConversation = conversationResponse.body;
+            });
+
+            after(async () => {
+              await clearConversations(es);
+            });
+
+            it('makes 4 requests to the LLM', () => {
+              expect(proxy.interceptedRequests.length).to.be(4);
+            });
+
+            it('emits 5 messageAdded events', () => {
+              expect(messageAddedEvents.length).to.be(5);
+            });
+
+            it('conversation has the correct messages', () => {
+              expect(fullConversation.messages.length).to.be(6);
+              // user prompt
+              expect(fullConversation.messages[0].message.content).to.be('user prompt test spec');
+              // context function call
+              expect(fullConversation.messages[1].message.function_call?.name).to.be('context');
+              // context function response
+              expect(fullConversation.messages[2].message.name).to.be('context');
+              // unknown tool function call
+              expect(fullConversation.messages[3].message.function_call?.name).to.be('unknown_tool');
+              // unknown tool function response with error message
+              expect(fullConversation.messages[4].message.name).to.contain('unknown_tool');
+              expect(fullConversation.messages[4].message.content).to.contain('toolNotFoundError');
+              // interaction with the LLM to fix the error
+              expect(fullConversation.messages[5].message.content).to.be(
+                'Hello from LLM Proxy, again!'
+              );
+            });
+          });
+        });
+
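Note: the unknown-tool suite above and the invalid-arguments suite below assert the same six-message conversation shape; only the failing step differs. Summarized as a sketch (roles inferred from the plugin's convention that function responses are user messages; illustrative, not an API contract):

    const expectedConversation = [
      { role: 'user', content: 'user prompt test spec' },
      { role: 'assistant', function_call: 'context' },
      { role: 'user', name: 'context' }, // function response
      { role: 'assistant', function_call: 'unknown_tool' }, // or 'kibana' below
      { role: 'user', name: 'unknown_tool' }, // error payload the LLM can read
      { role: 'assistant', content: 'Hello from LLM Proxy, again!' },
    ];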
+        describe('when calling a tool with invalid arguments', () => {
+          before(async () => {
+            proxy.close();
+            proxy = await createLlmProxy(log);
+            connectorId = await observabilityAIAssistantAPIClient.createProxyActionConnector({
+              port: proxy.getPort(),
+            });
+          });
+          after(async () => {
+            proxy.close();
+            await observabilityAIAssistantAPIClient.deleteActionConnector({
+              actionId: connectorId,
+            });
+          });
+          describe('when invoking chat complete with a function request that has invalid arguments', function () {
+            let events: MessageAddEvent[];
+
+            before(async () => {
+              void proxy.interceptWithResponse('Hello from LLM Proxy');
+
+              const responseBody = await invokeChatCompleteWithFunctionRequest({
+                connectorId,
+                observabilityAIAssistantAPIClient,
+                functionCall: {
+                  name: 'kibana',
+                  trigger: MessageRole.User,
+                  arguments: JSON.stringify({
+                    foo: 'bar',
+                  }),
+                },
+              });
+
+              await proxy.waitForAllInterceptorsToHaveBeenCalled();
+
+              events = getMessageAddedEvents(responseBody);
+            });
+
+            it('returns 2 message add events', () => {
+              expect(events.length).to.be(2);
+            });
+
+            it('the first message add event has the tool name and an error', () => {
+              expect(events[0].message.message.name).to.be('kibana');
+              expect(events[0].message.message.content).to.contain(
+                'Tool call arguments for kibana were invalid'
+              );
+            });
+
+            it('the second message add event interacts with the LLM to fix the error', () => {
+              expect(events[1].message.message.content).to.be('Hello from LLM Proxy');
+            });
+          });
+
+          describe('when the LLM calls a tool with invalid arguments', function () {
+            let messageAddedEvents: MessageAddEvent[];
+            let fullConversation: Conversation;
+            before(async () => {
+              void proxy.interceptTitle('LLM-generated title');
+
+              void proxy.interceptWithFunctionRequest({
+                name: 'kibana',
+                arguments: () =>
+                  JSON.stringify({
+                    foo: 'bar',
+                  }),
+                when: () => true,
+              });
+
+              void proxy.interceptWithResponse('I will not call the kibana function!');
+              void proxy.interceptWithResponse('Hello from LLM Proxy, again!');
+
+              const { messageAddedEvents: messageAddedEventsResponse, conversationCreateEvent } =
+                await chatComplete({
+                  userPrompt: 'user prompt test spec',
+                  connectorId,
+                  persist: true,
+                  observabilityAIAssistantAPIClient,
+                });
+              messageAddedEvents = messageAddedEventsResponse;
+
+              await proxy.waitForAllInterceptorsToHaveBeenCalled();
+
+              const conversationId = conversationCreateEvent.conversation.id;
+              const conversationResponse = await observabilityAIAssistantAPIClient.editor({
+                endpoint: 'GET /internal/observability_ai_assistant/conversation/{conversationId}',
+                params: {
+                  path: {
+                    conversationId,
+                  },
+                },
+              });
+              expect(conversationResponse.status).to.be(200);
+              fullConversation = conversationResponse.body;
+            });
+
+            after(async () => {
+              await clearConversations(es);
+            });
+
+            it('makes 5 requests to the LLM', () => {
+              expect(proxy.interceptedRequests.length).to.be(5);
+            });
+
+            it('emits 5 messageAdded events', () => {
+              expect(messageAddedEvents.length).to.be(5);
+            });
+
+            it('conversation has the correct messages', () => {
+              expect(fullConversation.messages.length).to.be(6);
+              // user prompt
+              expect(fullConversation.messages[0].message.content).to.be('user prompt test spec');
+              // context function call
+              expect(fullConversation.messages[1].message.function_call?.name).to.be('context');
+              // context function response
+              expect(fullConversation.messages[2].message.name).to.be('context');
+              // kibana function call
+              expect(fullConversation.messages[3].message.function_call?.name).to.be('kibana');
+              // kibana function response with error message
+              expect(fullConversation.messages[4].message.name).to.contain('kibana');
+              expect(fullConversation.messages[4].message.content).to.contain(
+                'Tool call arguments for kibana were invalid'
+              );
+              // interaction with the LLM to fix the error
+              expect(fullConversation.messages[5].message.content).to.be(
+                'Hello from LLM Proxy, again!'
+              );
+            });
+          });
+        });
+      });
+
       describe('security roles and access privileges', () => {
         it('should deny access for users without the ai_assistant privilege', async () => {
           const { status } = await observabilityAIAssistantAPIClient.viewer({