Skip to content

Commit 0621aea

Browse files
OrKoNDevtools-frontend LUCI CQ
authored andcommitted
Consolidate function calling
This PR removes duplication in handling side effects and function calling by introducing a function calling emulation mode for StylingAgent. Bug: 360751542 Change-Id: I383ade655828e655700444a971a27de13ce291fa Reviewed-on: https://chromium-review.googlesource.com/c/devtools/devtools-frontend/+/6286165 Commit-Queue: Alex Rudenko <[email protected]> Reviewed-by: Nikolay Vitkov <[email protected]> Reviewed-by: Ergün Erdoğmuş <[email protected]>
1 parent 5efc7e9 commit 0621aea

File tree

5 files changed

+222
-293
lines changed

5 files changed

+222
-293
lines changed

front_end/panels/ai_assistance/agents/AiAgent.ts

Lines changed: 131 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -187,13 +187,12 @@ export interface FunctionDeclaration<Args extends Record<string, unknown>, Retur
187187
*/
188188
parameters: Host.AidaClient.FunctionObjectParam<keyof Args>;
189189
/**
190-
* Provided a way to give information back to
191-
* the UI before running the the handler
190+
* Provided a way to give information back to the UI.
192191
*/
193192
displayInfoFromArgs?: (
194193
args: Args,
195194
) => {
196-
title?: string, thought?: string, code?: string, suggestions?: [string, ...string[]],
195+
title?: string, thought?: string, action?: string, suggestions?: [string, ...string[]],
197196
};
198197
/**
199198
* Function implementation that the LLM will try to execute,
@@ -208,8 +207,25 @@ export interface FunctionDeclaration<Args extends Record<string, unknown>, Retur
208207
}) => Promise<FunctionCallHandlerResult<ReturnType>>;
209208
}
210209

211-
const OBSERVATION_PREFIX = 'OBSERVATION:';
210+
const OBSERVATION_PREFIX = 'OBSERVATION: ';
212211

212+
interface AidaFetchResult {
213+
text?: string;
214+
functionCall?: Host.AidaClient.AidaFunctionCallResponse;
215+
completed: boolean;
216+
rpcId?: Host.AidaClient.RpcGlobalId;
217+
}
218+
219+
/**
220+
* AiAgent is a base class for implementing an interaction with AIDA
221+
* that involves one or more requests being sent to AIDA optionally
222+
* utilizing function calling.
223+
*
224+
* TODO: missing a test that action code is yielded before the
225+
* confirmation dialog.
226+
* TODO: missing a test for an error if it took
227+
* more than MAX_STEPS iterations.
228+
*/
213229
export abstract class AiAgent<T> {
214230
/** Subclasses need to define these. */
215231
abstract readonly type: AgentType;
@@ -274,6 +290,7 @@ export abstract class AiAgent<T> {
274290
function validTemperature(temperature: number|undefined): number|undefined {
275291
return typeof temperature === 'number' && temperature >= 0 ? temperature : undefined;
276292
}
293+
const enableAidaFunctionCalling = declarations.length && !this.functionCallEmulationEnabled;
277294
const request: Host.AidaClient.AidaRequest = {
278295
client: Host.AidaClient.CLIENT_NAME,
279296

@@ -282,7 +299,7 @@ export abstract class AiAgent<T> {
282299

283300
historical_contexts: history.length ? history : undefined,
284301

285-
...(declarations.length ? {function_declarations: declarations} : {}),
302+
...(enableAidaFunctionCalling ? {function_declarations: declarations} : {}),
286303
options: {
287304
temperature: validTemperature(this.options.temperature),
288305
model_id: this.options.modelId,
@@ -294,8 +311,8 @@ export abstract class AiAgent<T> {
294311
client_version: Root.Runtime.getChromeVersion(),
295312
},
296313

297-
functionality_type: declarations.length ? Host.AidaClient.FunctionalityType.AGENTIC_CHAT :
298-
Host.AidaClient.FunctionalityType.CHAT,
314+
functionality_type: enableAidaFunctionCalling ? Host.AidaClient.FunctionalityType.AGENTIC_CHAT :
315+
Host.AidaClient.FunctionalityType.CHAT,
299316

300317
client_feature: this.clientFeature,
301318
};
@@ -314,13 +331,13 @@ export abstract class AiAgent<T> {
314331
return this.#origin;
315332
}
316333

317-
parseResponse(response: Host.AidaClient.AidaResponse): ParsedResponse {
318-
if (response.functionCalls && response.completed) {
319-
throw new Error('Function calling not supported yet');
320-
}
321-
return {
322-
answer: response.explanation,
323-
};
334+
/**
335+
* Parses a streaming text response into a
336+
* though/action/title/answer/suggestions component. This is only used
337+
* by StylingAgent.
338+
*/
339+
parseTextResponse(response: string): ParsedResponse {
340+
return {answer: response};
324341
}
325342

326343
/**
@@ -346,10 +363,14 @@ export abstract class AiAgent<T> {
346363
return answer;
347364
}
348365

349-
protected handleAction(action: string, options?: {signal?: AbortSignal}):
350-
AsyncGenerator<SideEffectResponse, ActionResponse, void>;
351-
protected handleAction(): never {
352-
throw new Error('Unexpected action found');
366+
/**
367+
* Special mode for StylingAgent that turns custom text output into a
368+
* function call.
369+
*/
370+
protected functionCallEmulationEnabled = false;
371+
protected emulateFunctionCall(_aidaResponse: Host.AidaClient.AidaResponse): Host.AidaClient.AidaFunctionCallResponse|
372+
'no-function-call'|'wait-for-completion' {
373+
throw new Error('Unexpected emulateFunctionCall. Only StylingAgent implements function call emulation');
353374
}
354375

355376
async *
@@ -391,23 +412,30 @@ export abstract class AiAgent<T> {
391412
};
392413

393414
let rpcId: Host.AidaClient.RpcGlobalId|undefined;
394-
let parsedResponse: ParsedResponse|undefined = undefined;
415+
let textResponse = '';
395416
let functionCall: Host.AidaClient.AidaFunctionCallResponse|undefined = undefined;
396417
try {
397418
for await (const fetchResult of this.#aidaFetch(request, {signal: options.signal})) {
398419
rpcId = fetchResult.rpcId;
399-
parsedResponse = fetchResult.parsedResponse;
420+
textResponse = fetchResult.text ?? '';
400421
functionCall = fetchResult.functionCall;
401422

402-
// Only yield partial responses here and do not add partial answers to the history.
403-
if (!fetchResult.completed && !fetchResult.functionCall && 'answer' in parsedResponse &&
404-
parsedResponse.answer) {
423+
if (!functionCall && !fetchResult.completed) {
424+
const parsed = this.parseTextResponse(textResponse);
425+
const partialAnswer = 'answer' in parsed ? parsed.answer : '';
426+
if (!partialAnswer) {
427+
continue;
428+
}
429+
// Only yield partial responses here and do not add partial answers to the history.
405430
yield {
406431
type: ResponseType.ANSWER,
407-
text: parsedResponse.answer,
432+
text: partialAnswer,
408433
complete: false,
409434
};
410435
}
436+
if (functionCall) {
437+
break;
438+
}
411439
}
412440
} catch (err) {
413441
debugLog('Error calling the AIDA API', err);
@@ -425,7 +453,11 @@ export abstract class AiAgent<T> {
425453

426454
this.#history.push(request.current_message);
427455

428-
if (parsedResponse && 'answer' in parsedResponse && Boolean(parsedResponse.answer)) {
456+
if (textResponse) {
457+
const parsedResponse = this.parseTextResponse(textResponse);
458+
if (!('answer' in parsedResponse)) {
459+
throw new Error('Expected a completed response to have an answer');
460+
}
429461
this.#history.push({
430462
parts: [{
431463
text: this.formatParsedAnswer(parsedResponse),
@@ -441,66 +473,24 @@ export abstract class AiAgent<T> {
441473
rpcId,
442474
};
443475
break;
444-
} else if (parsedResponse && !('answer' in parsedResponse)) {
445-
const {
446-
title,
447-
thought,
448-
action,
449-
} = parsedResponse;
450-
451-
if (title) {
452-
yield {
453-
type: ResponseType.TITLE,
454-
title,
455-
rpcId,
456-
};
457-
}
458-
459-
if (thought) {
460-
yield {
461-
type: ResponseType.THOUGHT,
462-
thought,
463-
rpcId,
464-
};
465-
}
466-
467-
this.#history.push({
468-
parts: [{
469-
text: this.#formatParsedStep(parsedResponse),
470-
}],
471-
role: Host.AidaClient.Role.MODEL,
472-
});
476+
}
473477

474-
if (action) {
475-
const result = yield* this.handleAction(action, {signal: options.signal});
476-
if (options?.signal?.aborted) {
478+
if (functionCall) {
479+
try {
480+
const result = yield* this.#callFunction(functionCall.name, functionCall.args, options);
481+
if (options.signal?.aborted) {
477482
yield this.#createErrorResponse(ErrorType.ABORT);
478483
break;
479484
}
480-
query = {text: `${OBSERVATION_PREFIX} ${result.output}`};
481-
// Capture history state for the next iteration query.
482-
request = this.buildRequest(query, Host.AidaClient.Role.USER);
483-
yield result;
484-
}
485-
} else if (functionCall) {
486-
try {
487-
const result = yield* this.#callFunction(functionCall.name, functionCall.args);
488-
489-
if (result.result) {
490-
yield {
491-
type: ResponseType.ACTION,
492-
output: JSON.stringify(result.result),
493-
canceled: false,
494-
};
495-
}
496-
497-
query = {
485+
query = this.functionCallEmulationEnabled ? {text: OBSERVATION_PREFIX + result.result} : {
498486
functionResponse: {
499487
name: functionCall.name,
500488
response: result,
501489
},
502490
};
503-
request = this.buildRequest(query, Host.AidaClient.Role.ROLE_UNSPECIFIED);
491+
request = this.buildRequest(
492+
query,
493+
this.functionCallEmulationEnabled ? Host.AidaClient.Role.USER : Host.AidaClient.Role.ROLE_UNSPECIFIED);
504494
} catch {
505495
yield this.#createErrorResponse(ErrorType.UNKNOWN);
506496
break;
@@ -524,18 +514,31 @@ export abstract class AiAgent<T> {
524514
if (!call) {
525515
throw new Error(`Function ${name} is not found.`);
526516
}
527-
this.#history.push({
528-
parts: [{
529-
functionCall: {
530-
name,
531-
args,
532-
},
533-
}],
534-
role: Host.AidaClient.Role.MODEL,
535-
});
517+
if (this.functionCallEmulationEnabled) {
518+
if (!call.displayInfoFromArgs) {
519+
throw new Error('functionCallEmulationEnabled requires all functions to provide displayInfoFromArgs');
520+
}
521+
// Emulated function calls are formatted as text.
522+
this.#history.push({
523+
parts: [{text: this.#formatParsedStep(call.displayInfoFromArgs(args))}],
524+
role: Host.AidaClient.Role.MODEL,
525+
});
526+
} else {
527+
this.#history.push({
528+
parts: [{
529+
functionCall: {
530+
name,
531+
args,
532+
},
533+
}],
534+
role: Host.AidaClient.Role.MODEL,
535+
});
536+
}
536537

538+
let code;
537539
if (call.displayInfoFromArgs) {
538-
const {title, thought, code, suggestions} = call.displayInfoFromArgs(args);
540+
const {title, thought, action: callCode} = call.displayInfoFromArgs(args);
541+
code = callCode;
539542
if (title) {
540543
yield {
541544
type: ResponseType.TITLE,
@@ -549,7 +552,11 @@ export abstract class AiAgent<T> {
549552
thought,
550553
};
551554
}
555+
}
552556

557+
let result = await call.handler(args, options) as FunctionCallHandlerResult<unknown>;
558+
559+
if ('requiresApproval' in result) {
553560
if (code) {
554561
yield {
555562
type: ResponseType.ACTION,
@@ -558,17 +565,6 @@ export abstract class AiAgent<T> {
558565
};
559566
}
560567

561-
if (suggestions) {
562-
yield {
563-
type: ResponseType.SUGGESTIONS,
564-
suggestions,
565-
};
566-
}
567-
}
568-
569-
let result = await call.handler(args, options);
570-
571-
if ('requiresApproval' in result) {
572568
const sideEffectConfirmationPromiseWithResolvers = this.confirmSideEffect<boolean>();
573569

574570
void sideEffectConfirmationPromiseWithResolvers.promise.then(result => {
@@ -597,7 +593,8 @@ export abstract class AiAgent<T> {
597593
if (!approvedRun) {
598594
yield {
599595
type: ResponseType.ACTION,
600-
code: '',
596+
code,
597+
output: 'Error: User denied code execution with side effects.',
601598
canceled: true,
602599
};
603600
return {
@@ -611,17 +608,30 @@ export abstract class AiAgent<T> {
611608
});
612609
}
613610

611+
if ('result' in result) {
612+
yield {
613+
type: ResponseType.ACTION,
614+
code,
615+
output: typeof result.result === 'string' ? result.result : JSON.stringify(result.result),
616+
canceled: false,
617+
};
618+
}
619+
620+
if ('error' in result) {
621+
yield {
622+
type: ResponseType.ACTION,
623+
code,
624+
output: result.error,
625+
canceled: false,
626+
};
627+
}
628+
614629
return result as {result: unknown};
615630
}
616631

617632
async *
618-
#aidaFetch(request: Host.AidaClient.AidaRequest, options?: {signal?: AbortSignal}): AsyncGenerator<
619-
{
620-
parsedResponse: ParsedResponse,
621-
functionCall?: Host.AidaClient.AidaFunctionCallResponse, completed: boolean,
622-
rpcId?: Host.AidaClient.RpcGlobalId,
623-
},
624-
void, void> {
633+
#aidaFetch(request: Host.AidaClient.AidaRequest, options?: {signal?: AbortSignal}):
634+
AsyncGenerator<AidaFetchResult, void, void> {
625635
let aidaResponse: Host.AidaClient.AidaResponse|undefined = undefined;
626636
let response = '';
627637
let rpcId: Host.AidaClient.RpcGlobalId|undefined;
@@ -631,19 +641,32 @@ export abstract class AiAgent<T> {
631641
debugLog('functionCalls.length', aidaResponse.functionCalls.length);
632642
yield {
633643
rpcId,
634-
parsedResponse: {answer: ''},
635644
functionCall: aidaResponse.functionCalls[0],
636645
completed: true,
637646
};
638647
break;
639648
}
640649

650+
if (this.functionCallEmulationEnabled) {
651+
const emulatedFunctionCall = this.emulateFunctionCall(aidaResponse);
652+
if (emulatedFunctionCall === 'wait-for-completion') {
653+
continue;
654+
}
655+
if (emulatedFunctionCall !== 'no-function-call') {
656+
yield {
657+
rpcId,
658+
functionCall: emulatedFunctionCall,
659+
completed: true,
660+
};
661+
break;
662+
}
663+
}
664+
641665
response = aidaResponse.explanation;
642666
rpcId = aidaResponse.metadata.rpcGlobalId ?? rpcId;
643-
const parsedResponse = this.parseResponse(aidaResponse);
644667
yield {
645668
rpcId,
646-
parsedResponse,
669+
text: aidaResponse.explanation,
647670
completed: aidaResponse.completed,
648671
};
649672
}

0 commit comments

Comments
 (0)