Skip to content

Commit 8998200

Browse files
Edilmomboshernitsan
authored andcommitted
feat(hooks): Hook Agent Lifecycle Integration (#9105)
1 parent 151bf9d commit 8998200

File tree

12 files changed

+631
-3
lines changed

12 files changed

+631
-3
lines changed

packages/a2a-server/src/utils/testing_utils.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ import {
1414
DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES,
1515
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
1616
GeminiClient,
17+
HookSystem,
1718
} from '@google/gemini-cli-core';
19+
import { createMockMessageBus } from '@google/gemini-cli-core/src/test-utils/mock-message-bus.js';
1820
import type { Config, Storage } from '@google/gemini-cli-core';
1921
import { expect, vi } from 'vitest';
2022

@@ -54,8 +56,13 @@ export function createMockConfig(
5456
getMessageBus: vi.fn(),
5557
getPolicyEngine: vi.fn(),
5658
getEnableExtensionReloading: vi.fn().mockReturnValue(false),
59+
getEnableHooks: vi.fn().mockReturnValue(false),
5760
...overrides,
5861
} as unknown as Config;
62+
mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus());
63+
mockConfig.getHookSystem = vi
64+
.fn()
65+
.mockReturnValue(new HookSystem(mockConfig));
5966

6067
mockConfig.getGeminiClient = vi
6168
.fn()

packages/cli/src/ui/hooks/useToolScheduler.test.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ import {
3232
ToolConfirmationOutcome,
3333
ApprovalMode,
3434
MockTool,
35+
HookSystem,
3536
} from '@google/gemini-cli-core';
37+
import { createMockMessageBus } from '@google/gemini-cli-core/src/test-utils/mock-message-bus.js';
3638
import { ToolCallStatus } from '../types.js';
3739

3840
// Mocks
@@ -81,7 +83,10 @@ const mockConfig = {
8183
getPolicyEngine: () => null,
8284
isInteractive: () => false,
8385
getExperiments: () => {},
86+
getEnableHooks: () => false,
8487
} as unknown as Config;
88+
mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus());
89+
mockConfig.getHookSystem = vi.fn().mockReturnValue(new HookSystem(mockConfig));
8590

8691
const mockTool = new MockTool({
8792
name: 'mockTool',

packages/core/src/config/config.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ import type { EventEmitter } from 'node:events';
7676
import { MessageBus } from '../confirmation-bus/message-bus.js';
7777
import { PolicyEngine } from '../policy/policy-engine.js';
7878
import type { PolicyEngineConfig } from '../policy/types.js';
79+
import { HookSystem } from '../hooks/index.js';
7980
import type { UserTierId } from '../code_assist/types.js';
8081
import { getCodeAssistServer } from '../code_assist/codeAssist.js';
8182
import type { Experiments } from '../code_assist/experiments/experiments.js';
@@ -416,6 +417,7 @@ export class Config {
416417
| undefined;
417418
private experiments: Experiments | undefined;
418419
private experimentsPromise: Promise<void> | undefined;
420+
private hookSystem?: HookSystem;
419421

420422
private previewModelFallbackMode = false;
421423
private previewModelBypassMode = false;
@@ -629,6 +631,12 @@ export class Config {
629631
await this.getExtensionLoader().start(this),
630632
]);
631633

634+
// Initialize hook system if enabled
635+
if (this.enableHooks) {
636+
this.hookSystem = new HookSystem(this);
637+
await this.hookSystem.initialize();
638+
}
639+
632640
await this.geminiClient.initialize();
633641
}
634642

@@ -1481,6 +1489,13 @@ export class Config {
14811489
return registry;
14821490
}
14831491

1492+
/**
1493+
* Get the hook system instance
1494+
*/
1495+
getHookSystem(): HookSystem | undefined {
1496+
return this.hookSystem;
1497+
}
1498+
14841499
/**
14851500
* Get hooks configuration
14861501
*/

packages/core/src/core/client.test.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ import type {
4343
ResolvedModelConfig,
4444
} from '../services/modelConfigService.js';
4545
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
46+
import { HookSystem } from '../hooks/hookSystem.js';
4647

4748
vi.mock('../services/chatCompressionService.js');
4849

@@ -120,6 +121,7 @@ vi.mock('../telemetry/uiTelemetry.js', () => ({
120121
getLastPromptTokenCount: vi.fn(),
121122
},
122123
}));
124+
vi.mock('../hooks/hookSystem.js');
123125

124126
/**
125127
* Array.fromAsync ponyfill, which will be available in es 2024.
@@ -211,6 +213,8 @@ describe('Gemini Client (client.ts)', () => {
211213
getModelRouterService: vi.fn().mockReturnValue({
212214
route: vi.fn().mockResolvedValue({ model: 'default-routed-model' }),
213215
}),
216+
getMessageBus: vi.fn().mockReturnValue(undefined),
217+
getEnableHooks: vi.fn().mockReturnValue(false),
214218
isInFallbackMode: vi.fn().mockReturnValue(false),
215219
setFallbackMode: vi.fn(),
216220
getChatCompression: vi.fn().mockReturnValue(undefined),
@@ -243,6 +247,9 @@ describe('Gemini Client (client.ts)', () => {
243247
isInteractive: vi.fn().mockReturnValue(false),
244248
getExperiments: () => {},
245249
} as unknown as Config;
250+
mockConfig.getHookSystem = vi
251+
.fn()
252+
.mockReturnValue(new HookSystem(mockConfig));
246253

247254
client = new GeminiClient(mockConfig);
248255
await client.initialize();

packages/core/src/core/client.ts

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ import {
4242
logContentRetryFailure,
4343
logNextSpeakerCheck,
4444
} from '../telemetry/loggers.js';
45+
import {
46+
fireBeforeAgentHook,
47+
fireAfterAgentHook,
48+
} from './clientHookTriggers.js';
4549
import {
4650
ContentRetryFailureEvent,
4751
NextSpeakerCheckEvent,
@@ -438,6 +442,35 @@ export class GeminiClient {
438442
turns: number = MAX_TURNS,
439443
isInvalidStreamRetry: boolean = false,
440444
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
445+
// Fire BeforeAgent hook through MessageBus (only if hooks are enabled)
446+
const hooksEnabled = this.config.getEnableHooks();
447+
const messageBus = this.config.getMessageBus();
448+
if (hooksEnabled && messageBus) {
449+
const hookOutput = await fireBeforeAgentHook(messageBus, request);
450+
451+
if (
452+
hookOutput?.isBlockingDecision() ||
453+
hookOutput?.shouldStopExecution()
454+
) {
455+
yield {
456+
type: GeminiEventType.Error,
457+
value: {
458+
error: new Error(
459+
`BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
460+
),
461+
},
462+
};
463+
return new Turn(this.getChat(), prompt_id);
464+
}
465+
466+
// Add additional context from hooks to the request
467+
const additionalContext = hookOutput?.getAdditionalContext();
468+
if (additionalContext) {
469+
const requestArray = Array.isArray(request) ? request : [request];
470+
request = [...requestArray, { text: additionalContext }];
471+
}
472+
}
473+
441474
if (this.lastPromptId !== prompt_id) {
442475
this.loopDetector.reset(prompt_id);
443476
this.lastPromptId = prompt_id;
@@ -608,9 +641,9 @@ export class GeminiClient {
608641
);
609642
if (nextSpeakerCheck?.next_speaker === 'model') {
610643
const nextRequest = [{ text: 'Please continue.' }];
611-
// This recursive call's events will be yielded out, but the final
612-
// turn object will be from the top-level call.
613-
yield* this.sendMessageStream(
644+
// This recursive call's events will be yielded out, and the final
645+
// turn object from the recursive call will be returned.
646+
return yield* this.sendMessageStream(
614647
nextRequest,
615648
signal,
616649
prompt_id,
@@ -619,6 +652,32 @@ export class GeminiClient {
619652
);
620653
}
621654
}
655+
656+
// Fire AfterAgent hook through MessageBus (only if hooks are enabled)
657+
if (hooksEnabled && messageBus) {
658+
const responseText = turn.getResponseText() || '[no response text]';
659+
const hookOutput = await fireAfterAgentHook(
660+
messageBus,
661+
request,
662+
responseText,
663+
);
664+
665+
// For AfterAgent hooks, blocking/stop execution should force continuation
666+
if (
667+
hookOutput?.isBlockingDecision() ||
668+
hookOutput?.shouldStopExecution()
669+
) {
670+
const continueReason = hookOutput.getEffectiveReason();
671+
const continueRequest = [{ text: continueReason }];
672+
yield* this.sendMessageStream(
673+
continueRequest,
674+
signal,
675+
prompt_id,
676+
boundedTurns - 1,
677+
);
678+
}
679+
}
680+
622681
return turn;
623682
}
624683

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/**
2+
* @license
3+
* Copyright 2025 Google LLC
4+
* SPDX-License-Identifier: Apache-2.0
5+
*/
6+
7+
import type { PartListUnion } from '@google/genai';
8+
import type { MessageBus } from '../confirmation-bus/message-bus.js';
9+
import {
10+
MessageBusType,
11+
type HookExecutionRequest,
12+
type HookExecutionResponse,
13+
} from '../confirmation-bus/types.js';
14+
import { createHookOutput, type DefaultHookOutput } from '../hooks/types.js';
15+
import { partToString } from '../utils/partUtils.js';
16+
import { debugLogger } from '../utils/debugLogger.js';
17+
18+
/**
19+
* Fires the BeforeAgent hook and returns the hook output.
20+
* This should be called before processing a user prompt.
21+
*
22+
* The caller can use the returned DefaultHookOutput methods:
23+
* - isBlockingDecision() / shouldStopExecution() to check if blocked
24+
* - getEffectiveReason() to get the blocking reason
25+
* - getAdditionalContext() to get additional context to add
26+
*
27+
* @param messageBus The message bus to use for hook communication
28+
* @param request The user's request (prompt)
29+
* @returns The hook output, or undefined if no hook was executed or on error
30+
*/
31+
export async function fireBeforeAgentHook(
32+
messageBus: MessageBus,
33+
request: PartListUnion,
34+
): Promise<DefaultHookOutput | undefined> {
35+
try {
36+
const promptText = partToString(request);
37+
38+
const response = await messageBus.request<
39+
HookExecutionRequest,
40+
HookExecutionResponse
41+
>(
42+
{
43+
type: MessageBusType.HOOK_EXECUTION_REQUEST,
44+
eventName: 'BeforeAgent',
45+
input: {
46+
prompt: promptText,
47+
},
48+
},
49+
MessageBusType.HOOK_EXECUTION_RESPONSE,
50+
);
51+
52+
return response.output
53+
? createHookOutput('BeforeAgent', response.output)
54+
: undefined;
55+
} catch (error) {
56+
debugLogger.warn(`BeforeAgent hook failed: ${error}`);
57+
return undefined;
58+
}
59+
}
60+
61+
/**
62+
* Fires the AfterAgent hook and returns the hook output.
63+
* This should be called after the agent has generated a response.
64+
*
65+
* The caller can use the returned DefaultHookOutput methods:
66+
* - isBlockingDecision() / shouldStopExecution() to check if continuation is requested
67+
* - getEffectiveReason() to get the continuation reason
68+
*
69+
* @param messageBus The message bus to use for hook communication
70+
* @param request The original user's request (prompt)
71+
* @param responseText The agent's response text
72+
* @returns The hook output, or undefined if no hook was executed or on error
73+
*/
74+
export async function fireAfterAgentHook(
75+
messageBus: MessageBus,
76+
request: PartListUnion,
77+
responseText: string,
78+
): Promise<DefaultHookOutput | undefined> {
79+
try {
80+
const promptText = partToString(request);
81+
82+
const response = await messageBus.request<
83+
HookExecutionRequest,
84+
HookExecutionResponse
85+
>(
86+
{
87+
type: MessageBusType.HOOK_EXECUTION_REQUEST,
88+
eventName: 'AfterAgent',
89+
input: {
90+
prompt: promptText,
91+
prompt_response: responseText,
92+
stop_hook_active: false,
93+
},
94+
},
95+
MessageBusType.HOOK_EXECUTION_RESPONSE,
96+
);
97+
98+
return response.output
99+
? createHookOutput('AfterAgent', response.output)
100+
: undefined;
101+
} catch (error) {
102+
debugLogger.warn(`AfterAgent hook failed: ${error}`);
103+
return undefined;
104+
}
105+
}

packages/core/src/core/nonInteractiveToolExecutor.test.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@ import {
1818
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
1919
ToolErrorType,
2020
ApprovalMode,
21+
HookSystem,
2122
} from '../index.js';
2223
import type { Part } from '@google/genai';
2324
import { MockTool } from '../test-utils/mock-tool.js';
25+
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
2426

2527
describe('executeToolCall', () => {
2628
let mockToolRegistry: ToolRegistry;
@@ -66,8 +68,15 @@ describe('executeToolCall', () => {
6668
getPolicyEngine: () => null,
6769
isInteractive: () => false,
6870
getExperiments: () => {},
71+
getEnableHooks: () => false,
6972
} as unknown as Config;
7073

74+
// Use proper MessageBus mocking for Phase 3 preparation
75+
const mockMessageBus = createMockMessageBus();
76+
mockConfig.getMessageBus = vi.fn().mockReturnValue(mockMessageBus);
77+
mockConfig.getHookSystem = vi
78+
.fn()
79+
.mockReturnValue(new HookSystem(mockConfig));
7180
abortController = new AbortController();
7281
});
7382

packages/core/src/core/turn.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,17 @@ export class Turn {
392392
getDebugResponses(): GenerateContentResponse[] {
393393
return this.debugResponses;
394394
}
395+
396+
/**
397+
* Get the concatenated response text from all responses in this turn.
398+
* This extracts and joins all text content from the model's responses.
399+
*/
400+
getResponseText(): string {
401+
return this.debugResponses
402+
.map((response) => getResponseText(response))
403+
.filter((text): text is string => text !== null)
404+
.join(' ');
405+
}
395406
}
396407

397408
function getCitations(resp: GenerateContentResponse): string[] {

0 commit comments

Comments
 (0)