Skip to content

Commit 2f2e11e

Browse files
authored
[Agent Builder] improve agent overall resilience (#241607)
## Summary Fix elastic/search-team#11720 Fix elastic/search-team#11700 Improve the agent execution's resilience by catching all potential errors returned from LLM calls and identifying the ones that we think we can recover from. At the moment, the "recoverable" errors we identify are: - calling a tool which is not available - calling a tool with invalid parameters - empty text responses For each type of error, we have specific logic to "represent" it to the agent so that it can try to work around it. For example, for non-available tools, we return a tool response with an error stating that the tool isn't available. We retry a maximum of 3 times before giving up. Note that a successful call will reset that counter (so that if we face an error again in a later cycle, we get our error budget back). This PR also introduces a new type of error in the inference plugin, `ContextLengthExceededError`, which is thrown when the LLM call fails because of context length / too many tokens errors. I did that in the inference plugin because it makes sense to have that centralized and available to everyone. Note that I didn't implement connectivity-related recovery (e.g. timeout, token budget and so on), because this is already implemented in the `inference` plugin.
1 parent 91aa8e0 commit 2f2e11e

File tree

27 files changed

+1087
-177
lines changed

27 files changed

+1087
-177
lines changed

x-pack/platform/packages/shared/ai-infra/inference-common/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ export {
5555
type ChatCompletionToolValidationError,
5656
type ChatCompletionTokenLimitReachedError,
5757
isToolValidationError,
58-
isTokenLimitReachedError,
58+
isOutputTokenLimitReachedError,
5959
isToolNotFoundError,
6060
type ChatCompleteMetadata,
6161
type ConnectorTelemetryMetadata,

x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/errors.ts

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,26 @@ import type { UnvalidatedToolCall } from './tools';
1212
* List of error codes that are specific to the {@link ChatCompleteAPI}
1313
*/
1414
export enum ChatCompletionErrorCode {
15-
TokenLimitReachedError = 'tokenLimitReachedError',
15+
ContextLengthExceededError = 'contextLengthExceededError',
16+
OutputTokenLimitReachedError = 'outputTokenLimitReachedError',
1617
ToolNotFoundError = 'toolNotFoundError',
1718
ToolValidationError = 'toolValidationError',
1819
}
1920

2021
/**
21-
* Error thrown if the completion call fails because of a token limit
22-
* error, e.g. when the context window is higher than the limit
22+
* Error thrown if the completion call fails because of a context length error,
23+
* e.g. when too many input tokens or tool definitions are sent.
24+
*/
25+
export type ChatCompletionContextLengthExceededError = InferenceTaskError<
26+
ChatCompletionErrorCode.ContextLengthExceededError,
27+
{}
28+
>;
29+
30+
/**
31+
* Error thrown if the completion call fails because of an output token limit error
2332
*/
2433
export type ChatCompletionTokenLimitReachedError = InferenceTaskError<
25-
ChatCompletionErrorCode.TokenLimitReachedError,
34+
ChatCompletionErrorCode.OutputTokenLimitReachedError,
2635
{
2736
tokenLimit?: number;
2837
tokenCount?: number;
@@ -38,6 +47,8 @@ export type ChatCompletionToolNotFoundError = InferenceTaskError<
3847
{
3948
/** The name of the tool that got called */
4049
name: string;
50+
/** (unparsed) arguments the tool was called with */
51+
arguments: string;
4152
}
4253
>;
4354

@@ -58,6 +69,18 @@ export type ChatCompletionToolValidationError = InferenceTaskError<
5869
}
5970
>;
6071

72+
/**
73+
* Check if an error is a {@link ChatCompletionContextLengthExceededError}
74+
*/
75+
export function isContextLengthExceededError(
76+
error: Error
77+
): error is ChatCompletionContextLengthExceededError {
78+
return (
79+
error instanceof InferenceTaskError &&
80+
error.code === ChatCompletionErrorCode.ContextLengthExceededError
81+
);
82+
}
83+
6184
/**
6285
* Check if an error is a {@link ChatCompletionToolValidationError}
6386
*/
@@ -71,12 +94,12 @@ export function isToolValidationError(error?: Error): error is ChatCompletionToo
7194
/**
7295
* Check if an error is a {@link ChatCompletionTokenLimitReachedError}
7396
*/
74-
export function isTokenLimitReachedError(
97+
export function isOutputTokenLimitReachedError(
7598
error: Error
7699
): error is ChatCompletionTokenLimitReachedError {
77100
return (
78101
error instanceof InferenceTaskError &&
79-
error.code === ChatCompletionErrorCode.TokenLimitReachedError
102+
error.code === ChatCompletionErrorCode.OutputTokenLimitReachedError
80103
);
81104
}
82105

x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ export {
7777
type ChatCompletionToolValidationError,
7878
type ChatCompletionTokenLimitReachedError,
7979
isToolValidationError,
80-
isTokenLimitReachedError,
80+
isOutputTokenLimitReachedError,
8181
isToolNotFoundError,
8282
} from './errors';
8383

x-pack/platform/packages/shared/ai-infra/inference-langchain/src/chat_model/inference_chat_model.test.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -689,16 +689,15 @@ describe('InferenceChatModel', () => {
689689
});
690690
chatComplete.mockReturnValue(response);
691691

692-
const output = await chatModel.stream('Some question');
693-
694692
const allChunks: AIMessageChunk[] = [];
695693
await expect(async () => {
694+
const output = await chatModel.stream('Some question');
696695
for await (const chunk of output) {
697696
allChunks.push(chunk);
698697
}
699698
}).rejects.toThrowErrorMatchingInlineSnapshot(`"something went wrong"`);
700699

701-
expect(allChunks.length).toBe(2);
700+
expect(allChunks.length).toBe(0);
702701
});
703702
});
704703

x-pack/platform/packages/shared/ai-infra/inference-langchain/src/chat_model/utils/observable_to_generator.test.ts

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,30 @@ describe('toAsyncIterator', () => {
3939
}
4040
}).rejects.toThrowErrorMatchingInlineSnapshot(`"something went wrong"`);
4141

42-
expect(output).toEqual([1, 2, 3]);
42+
// Fail-fast behavior: queued values are discarded when error occurs
43+
expect(output).toEqual([]);
44+
});
45+
46+
it('throws an error when the source observable errors while iterator is waiting', async () => {
47+
const obs$ = new Observable<number>((subscriber) => {
48+
subscriber.next(1);
49+
subscriber.next(2);
50+
51+
// Delay before erroring, so the iterator will be waiting for the next value
52+
setTimeout(() => {
53+
subscriber.error(new Error('delayed error'));
54+
}, 10);
55+
});
56+
57+
const output: number[] = [];
58+
const iterator = toAsyncIterator(obs$);
59+
60+
await expect(async () => {
61+
for await (const event of iterator) {
62+
output.push(event);
63+
}
64+
}).rejects.toThrowErrorMatchingInlineSnapshot(`"delayed error"`);
65+
66+
expect(output).toEqual([1, 2]);
4367
});
4468
});

x-pack/platform/packages/shared/ai-infra/inference-langchain/src/chat_model/utils/observable_to_generator.ts

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ export function toAsyncIterator<T>(observable: Observable<T>): AsyncIterableIter
1717

1818
const queue: Array<IteratorResult<T>> = [];
1919
let done = false;
20+
let error: any = null;
2021

2122
const subscription = observable.subscribe({
2223
next(value) {
@@ -28,11 +29,14 @@ export function toAsyncIterator<T>(observable: Observable<T>): AsyncIterableIter
2829
}
2930
},
3031
error(err) {
32+
done = true;
33+
error = err;
34+
// Clear any queued values - we fail fast
35+
queue.length = 0;
3136
if (reject) {
3237
reject(err);
3338
reject = null;
34-
} else {
35-
queue.push(Promise.reject(err) as any); // Queue an error
39+
resolve = null;
3640
}
3741
},
3842
complete() {
@@ -49,6 +53,11 @@ export function toAsyncIterator<T>(observable: Observable<T>): AsyncIterableIter
4953
return this;
5054
},
5155
next() {
56+
// Check for error first - fail fast
57+
if (error !== null) {
58+
return Promise.reject(error);
59+
}
60+
5261
if (queue.length > 0) {
5362
return Promise.resolve(queue.shift()!);
5463
}
@@ -66,9 +75,9 @@ export function toAsyncIterator<T>(observable: Observable<T>): AsyncIterableIter
6675
subscription.unsubscribe();
6776
return Promise.resolve({ value: undefined, done: true });
6877
},
69-
throw(error?: any) {
78+
throw(err?: any) {
7079
subscription.unsubscribe();
71-
return Promise.reject(error);
80+
return Promise.reject(err);
7281
},
7382
};
7483
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
export enum AgentExecutionErrorCode {
9+
/** too many input tokens */
10+
contextLengthExceeded = 'context_length_exceeded',
11+
/** agent called a tool not currently available */
12+
toolNotFound = 'tool_not_found',
13+
/** agent called a tool with invalid arguments */
14+
toolValidationError = 'tool_validation_error',
15+
/** agent replied with an empty response */
16+
emptyResponse = 'empty_response',
17+
/** any uncategorized error */
18+
unknownError = 'unknown_error',
19+
/** invalid workflow state - should never be surfaced */
20+
invalidState = 'invalid_state',
21+
}
22+
23+
export interface ToolNotFoundErrorMeta {
24+
/** name of the tool which was called */
25+
toolName: string;
26+
/** arguments the tool was called with */
27+
toolArgs: string | Record<string, any>;
28+
}
29+
30+
export interface TooValidationErrorMeta {
31+
/** name of the tool which was called */
32+
toolName: string;
33+
/** arguments the tool was called with */
34+
toolArgs: string | Record<string, any>;
35+
/** schema validation error, if any */
36+
validationError?: string;
37+
}
38+
39+
interface ExecutionErrorMetaMap {
40+
[AgentExecutionErrorCode.toolNotFound]: ToolNotFoundErrorMeta;
41+
[AgentExecutionErrorCode.toolValidationError]: TooValidationErrorMeta;
42+
[AgentExecutionErrorCode.contextLengthExceeded]: {};
43+
[AgentExecutionErrorCode.unknownError]: {};
44+
[AgentExecutionErrorCode.invalidState]: {};
45+
[AgentExecutionErrorCode.emptyResponse]: {};
46+
}
47+
48+
export type ExecutionErrorMetaOf<ErrCode extends AgentExecutionErrorCode> =
49+
ExecutionErrorMetaMap[ErrCode];

x-pack/platform/packages/shared/onechat/onechat-common/agents/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ export {
1919
type ResolvedAgentCapabilities,
2020
getKibanaDefaultAgentCapabilities,
2121
} from './capabilities';
22+
export { AgentExecutionErrorCode } from './execution_errors';

x-pack/platform/packages/shared/onechat/onechat-common/base/errors.ts

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
*/
77

88
import { ServerSentEventError } from '@kbn/sse-utils';
9+
import type { AgentExecutionErrorCode, ExecutionErrorMetaOf } from '../agents/execution_errors';
910

1011
/**
1112
* Code to identify onechat errors
@@ -16,6 +17,7 @@ export enum OnechatErrorCode {
1617
toolNotFound = 'toolNotFound',
1718
agentNotFound = 'agentNotFound',
1819
conversationNotFound = 'conversationNotFound',
20+
agentExecutionError = 'agentExecutionError',
1921
requestAborted = 'requestAborted',
2022
}
2123

@@ -186,6 +188,34 @@ export const createRequestAbortedError = (
186188
return new OnechatError(OnechatErrorCode.requestAborted, message, meta ?? {});
187189
};
188190

191+
/**
192+
* Represents an error related to agent execution
193+
*/
194+
export type OnechatAgentExecutionError<
195+
ErrCode extends AgentExecutionErrorCode = AgentExecutionErrorCode
196+
> = OnechatError<
197+
OnechatErrorCode.agentExecutionError,
198+
{ errCode: ErrCode } & ExecutionErrorMetaOf<ErrCode>
199+
>;
200+
201+
/**
202+
* Checks if the given error is a {@link OnechatAgentExecutionError}
203+
*/
204+
export const isAgentExecutionError = (err: unknown): err is OnechatAgentExecutionError => {
205+
return isOnechatError(err) && err.code === OnechatErrorCode.agentExecutionError;
206+
};
207+
208+
export const createAgentExecutionError = <ErrCode extends AgentExecutionErrorCode>(
209+
message: string,
210+
code: ErrCode,
211+
meta: ExecutionErrorMetaOf<ErrCode>
212+
): OnechatAgentExecutionError<ErrCode> => {
213+
return new OnechatError(OnechatErrorCode.agentExecutionError, message, {
214+
...meta,
215+
errCode: code,
216+
});
217+
};
218+
189219
/**
190220
* Global utility exposing all error utilities from a single export.
191221
*/
@@ -195,8 +225,10 @@ export const OnechatErrorUtils = {
195225
isToolNotFoundError,
196226
isAgentNotFoundError,
197227
isConversationNotFoundError,
228+
isAgentExecutionError,
198229
createInternalError,
199230
createToolNotFoundError,
200231
createAgentNotFoundError,
201232
createConversationNotFoundError,
233+
createAgentExecutionError,
202234
};

x-pack/platform/packages/shared/onechat/onechat-genai-utils/langchain/messages.ts

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,17 @@ export const createToolResultMessage = ({
111111
});
112112
};
113113

114-
export const createToolCallMessage = (toolCall: ToolCall, message?: string): AIMessage => {
114+
export const createToolCallMessage = (
115+
toolCallOrCalls: ToolCall | ToolCall[],
116+
message?: string
117+
): AIMessage => {
118+
const toolCalls = isArray(toolCallOrCalls) ? toolCallOrCalls : [toolCallOrCalls];
115119
return new AIMessage({
116120
content: message ?? '',
117-
tool_calls: [
118-
{
119-
id: toolCall.toolCallId,
120-
name: toolCall.toolName,
121-
args: toolCall.args,
122-
},
123-
],
121+
tool_calls: toolCalls.map((toolCall) => ({
122+
id: toolCall.toolCallId,
123+
name: toolCall.toolName,
124+
args: toolCall.args,
125+
})),
124126
});
125127
};

0 commit comments

Comments
 (0)