diff --git a/libs/langchain/src/agents/middleware/hitl.ts b/libs/langchain/src/agents/middleware/hitl.ts index 738c5a34d485..c8162ce365f6 100644 --- a/libs/langchain/src/agents/middleware/hitl.ts +++ b/libs/langchain/src/agents/middleware/hitl.ts @@ -236,9 +236,19 @@ const contextSchema = z.object({ * - `true` -> pause for approval and allow approve/edit/reject decisions * - `false` -> auto-approve (no human review) * - `InterruptOnConfig` -> explicitly specify which decisions are allowed for this tool + * - `(toolCall: ToolCall) => InterruptOnConfig | boolean` -> conditionally interrupt on a tool call */ interruptOn: z - .record(z.union([z.boolean(), InterruptOnConfigSchema])) + .record( + z.union([ + z.boolean(), + InterruptOnConfigSchema, + z + .function() + .args(z.custom()) + .returns(z.union([z.boolean(), InterruptOnConfigSchema])), + ]) + ) .optional(), /** * Prefix used when constructing human-facing approval messages. @@ -452,6 +462,41 @@ export type HumanInTheLoopMiddlewareConfig = InferInteropZodInput< * }); * ``` * + * @example + * Using conditional interrupt functions + * ```typescript + * import { type ToolCall, type InterruptOnConfig } from "langchain"; + * + * // Define a conditional function that decides whether to interrupt + * // based on tool call arguments + * const conditionalInterrupt = (toolCall: ToolCall): boolean | InterruptOnConfig => { + * const filename = toolCall.args.filename as string; + * // Only interrupt if filename contains "dangerous" or "admin" + * if (filename.includes("dangerous") || filename.includes("admin")) { + * return { allowedDecisions: ["approve", "edit"] }; + * } + * // Auto-approve safe files + * return false; + * }; + * + * const hitlMiddleware = humanInTheLoopMiddleware({ + * interruptOn: { + * // Use a function to conditionally interrupt based on tool call arguments + * "write_file": conditionalInterrupt, + * // Or use an inline function + * "execute_sql": (toolCall: ToolCall) => { + * const query = toolCall.args.query as string; + * // Only interrupt for write operations (INSERT, UPDATE, DELETE) + * if (/^\s*(INSERT|UPDATE|DELETE)/i.test(query)) { + * return { allowedDecisions: ["approve", "reject"] }; + * } + * // Auto-approve SELECT queries + * return false; + * } + * } + * }); + * ``` + * * @remarks * - Tool calls are processed in the order they appear in the AI message * - Auto-approved tools execute immediately without interruption @@ -633,16 +678,27 @@ export function humanInTheLoopMiddleware( * Resolve per-tool configs (boolean true -> all decisions allowed; false -> auto-approve) */ const resolvedConfigs: Record = {}; - for (const [toolName, toolConfig] of Object.entries( + for (const [toolName, toolConfigOrToolConfigFactory] of Object.entries( config.interruptOn )) { + const toolCall = lastMessage.tool_calls.find( + (toolCall) => toolCall.name === toolName + ); + + const toolConfig = + typeof toolConfigOrToolConfigFactory === "function" + ? toolCall + ? await toolConfigOrToolConfigFactory(toolCall) + : undefined + : toolConfigOrToolConfigFactory; + if (typeof toolConfig === "boolean") { if (toolConfig === true) { resolvedConfigs[toolName] = { allowedDecisions: [...ALLOWED_DECISIONS], }; } - } else if (toolConfig.allowedDecisions) { + } else if (toolConfig?.allowedDecisions) { resolvedConfigs[toolName] = toolConfig as InterruptOnConfig; } } diff --git a/libs/langchain/src/agents/middleware/tests/hitl.test.ts b/libs/langchain/src/agents/middleware/tests/hitl.test.ts index 96c01fe8c054..7ca4f945b437 100644 --- a/libs/langchain/src/agents/middleware/tests/hitl.test.ts +++ b/libs/langchain/src/agents/middleware/tests/hitl.test.ts @@ -1,13 +1,19 @@ import { z } from "zod/v3"; import { describe, it, expect, vi, beforeEach } from "vitest"; import { tool } from "@langchain/core/tools"; -import { AIMessage, HumanMessage, ToolMessage } from "@langchain/core/messages"; +import { + AIMessage, + HumanMessage, + ToolMessage, + ToolCall, +} from "@langchain/core/messages"; import { Command } from "@langchain/langgraph"; import { MemorySaver } from "@langchain/langgraph-checkpoint"; import { createAgent } from "../../index.js"; import { humanInTheLoopMiddleware, + type InterruptOnConfig, type HITLRequest, type HITLResponse, type Decision, @@ -691,6 +697,182 @@ describe("humanInTheLoopMiddleware", () => { ); }); + it("should support conditional interrupt on", async () => { + // Create a conditional function that only interrupts for dangerous files + const conditionalInterruptFn = ( + toolCall: ToolCall + ): boolean | InterruptOnConfig => { + const filename = toolCall.args.filename as string; + // Only interrupt if filename contains "dangerous" + if (filename.includes("dangerous")) { + return { allowedDecisions: ["approve"] }; + } + // Auto-approve safe files + return false; + }; + const conditionalInterrupt = vi.fn(conditionalInterruptFn); + + const hitlMiddleware = humanInTheLoopMiddleware({ + interruptOn: { + write_file: conditionalInterrupt, + }, + }); + + const model = new FakeToolCallingModel({ + toolCalls: [ + // First call: dangerous file (should interrupt) + [ + { + id: "call_1", + name: "write_file", + args: { + filename: "dangerous.txt", + content: "Dangerous content", + }, + }, + ], + // Second call: safe file (should auto-approve) + [ + { + id: "call_2", + name: "write_file", + args: { + filename: "safe.txt", + content: "Safe content", + }, + }, + ], + [], + ], + }); + + const checkpointer = new MemorySaver(); + const agent = createAgent({ + model, + checkpointer, + tools: [writeFileTool], + middleware: [hitlMiddleware], + }); + + const config = { + configurable: { + thread_id: "test-conditional", + }, + }; + + // Test 1: Dangerous file should interrupt + model.index = 0; + await agent.invoke( + { + messages: [new HumanMessage("Write to dangerous file")], + }, + config + ); + + // Verify conditional function was called + expect(conditionalInterrupt).toHaveBeenCalledTimes(1); + expect(conditionalInterrupt).toHaveBeenCalledWith( + expect.objectContaining({ + id: "call_1", + name: "write_file", + args: { + filename: "dangerous.txt", + content: "Dangerous content", + }, + }) + ); + + // Verify write_file was NOT called yet (interrupted) + expect(writeFileFn).not.toHaveBeenCalled(); + + // Check if agent is paused for approval + const state = await agent.graph.getState(config); + expect(state.next).toBeDefined(); + expect(state.next.length).toBe(1); + + // Verify interrupt data + const task = state.tasks?.[0]; + expect(task).toBeDefined(); + expect(task.interrupts).toBeDefined(); + expect(task.interrupts.length).toBe(1); + + const hitlRequest = task.interrupts[0].value as HITLRequest; + expect(hitlRequest.actionRequests).toHaveLength(1); + expect(hitlRequest.actionRequests[0]).toEqual( + expect.objectContaining({ + name: "write_file", + args: { + filename: "dangerous.txt", + content: "Dangerous content", + }, + }) + ); + expect(hitlRequest.reviewConfigs[0].allowedDecisions).toEqual(["approve"]); + + // Resume with approval + model.index = 0; + await agent.invoke( + new Command({ + resume: { decisions: [{ type: "approve" }] } as HITLResponse, + }), + config + ); + + // Verify write_file was called after approval + expect(writeFileFn).toHaveBeenCalledTimes(1); + expect(writeFileFn).toHaveBeenCalledWith( + { + filename: "dangerous.txt", + content: "Dangerous content", + }, + expect.anything() + ); + + // Test 2: Safe file should auto-approve (no interrupt) + model.index = 1; + writeFileFn.mockClear(); + conditionalInterrupt.mockClear(); + + const safeResult = await agent.invoke( + { + messages: [new HumanMessage("Write to safe file")], + }, + config + ); + + // Verify conditional function was called + expect(conditionalInterrupt).toHaveBeenCalledTimes(1); + expect(conditionalInterrupt).toHaveBeenCalledWith( + expect.objectContaining({ + id: "call_2", + name: "write_file", + args: { + filename: "safe.txt", + content: "Safe content", + }, + }) + ); + + // Verify write_file was called immediately (auto-approved) + expect(writeFileFn).toHaveBeenCalledTimes(1); + expect(writeFileFn).toHaveBeenCalledWith( + { + filename: "safe.txt", + content: "Safe content", + }, + expect.anything() + ); + + // Verify no interrupt occurred + expect("__interrupt__" in safeResult).toBe(false); + + // Verify final response + const safeMessages = safeResult.messages; + expect(safeMessages[safeMessages.length - 1].content).toContain( + "Successfully wrote 12 characters to safe.txt" + ); + }); + it("should support dynamic description factory functions", async () => { // Create a description factory that formats based on tool call details const descriptionFactory = vi.fn((toolCall, _state, _runtime) => {