Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 59 additions & 3 deletions libs/langchain/src/agents/middleware/hitl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,19 @@ const contextSchema = z.object({
* - `true` -> pause for approval and allow approve/edit/reject decisions
* - `false` -> auto-approve (no human review)
* - `InterruptOnConfig` -> explicitly specify which decisions are allowed for this tool
* - `(toolCall: ToolCall) => InterruptOnConfig | boolean` -> conditionally interrupt on a tool call
*/
interruptOn: z
.record(z.union([z.boolean(), InterruptOnConfigSchema]))
.record(
z.union([
z.boolean(),
InterruptOnConfigSchema,
z
.function()
.args(z.custom<ToolCall>())
.returns(z.union([z.boolean(), InterruptOnConfigSchema])),
])
)
.optional(),
/**
* Prefix used when constructing human-facing approval messages.
Expand Down Expand Up @@ -452,6 +462,41 @@ export type HumanInTheLoopMiddlewareConfig = InferInteropZodInput<
* });
* ```
*
* @example
* Using conditional interrupt functions
* ```typescript
* import { type ToolCall, type InterruptOnConfig } from "langchain";
*
* // Define a conditional function that decides whether to interrupt
* // based on tool call arguments
* const conditionalInterrupt = (toolCall: ToolCall): boolean | InterruptOnConfig => {
* const filename = toolCall.args.filename as string;
* // Only interrupt if filename contains "dangerous" or "admin"
* if (filename.includes("dangerous") || filename.includes("admin")) {
* return { allowedDecisions: ["approve", "edit"] };
* }
* // Auto-approve safe files
* return false;
* };
*
* const hitlMiddleware = humanInTheLoopMiddleware({
* interruptOn: {
* // Use a function to conditionally interrupt based on tool call arguments
* "write_file": conditionalInterrupt,
* // Or use an inline function
* "execute_sql": (toolCall: ToolCall) => {
* const query = toolCall.args.query as string;
* // Only interrupt for write operations (INSERT, UPDATE, DELETE)
* if (/^\s*(INSERT|UPDATE|DELETE)/i.test(query)) {
* return { allowedDecisions: ["approve", "reject"] };
* }
* // Auto-approve SELECT queries
* return false;
* }
* }
* });
* ```
*
* @remarks
* - Tool calls are processed in the order they appear in the AI message
* - Auto-approved tools execute immediately without interruption
Expand Down Expand Up @@ -633,16 +678,27 @@ export function humanInTheLoopMiddleware(
* Resolve per-tool configs (boolean true -> all decisions allowed; false -> auto-approve)
*/
const resolvedConfigs: Record<string, InterruptOnConfig> = {};
for (const [toolName, toolConfig] of Object.entries(
for (const [toolName, toolConfigOrToolConfigFactory] of Object.entries(
config.interruptOn
)) {
const toolCall = lastMessage.tool_calls.find(
(toolCall) => toolCall.name === toolName
);

const toolConfig =
typeof toolConfigOrToolConfigFactory === "function"
? toolCall
? await toolConfigOrToolConfigFactory(toolCall)
: undefined
: toolConfigOrToolConfigFactory;

if (typeof toolConfig === "boolean") {
if (toolConfig === true) {
resolvedConfigs[toolName] = {
allowedDecisions: [...ALLOWED_DECISIONS],
};
}
} else if (toolConfig.allowedDecisions) {
} else if (toolConfig?.allowedDecisions) {
resolvedConfigs[toolName] = toolConfig as InterruptOnConfig;
}
}
Expand Down
184 changes: 183 additions & 1 deletion libs/langchain/src/agents/middleware/tests/hitl.test.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import { z } from "zod/v3";
import { describe, it, expect, vi, beforeEach } from "vitest";
import { tool } from "@langchain/core/tools";
import { AIMessage, HumanMessage, ToolMessage } from "@langchain/core/messages";
import {
AIMessage,
HumanMessage,
ToolMessage,
ToolCall,
} from "@langchain/core/messages";
import { Command } from "@langchain/langgraph";
import { MemorySaver } from "@langchain/langgraph-checkpoint";

import { createAgent } from "../../index.js";
import {
humanInTheLoopMiddleware,
type InterruptOnConfig,
type HITLRequest,
type HITLResponse,
type Decision,
Expand Down Expand Up @@ -691,6 +697,182 @@ describe("humanInTheLoopMiddleware", () => {
);
});

it("should support conditional interrupt on", async () => {
// Create a conditional function that only interrupts for dangerous files
const conditionalInterruptFn = (
toolCall: ToolCall
): boolean | InterruptOnConfig => {
const filename = toolCall.args.filename as string;
// Only interrupt if filename contains "dangerous"
if (filename.includes("dangerous")) {
return { allowedDecisions: ["approve"] };
}
// Auto-approve safe files
return false;
};
const conditionalInterrupt = vi.fn(conditionalInterruptFn);

const hitlMiddleware = humanInTheLoopMiddleware({
interruptOn: {
write_file: conditionalInterrupt,
},
});

const model = new FakeToolCallingModel({
toolCalls: [
// First call: dangerous file (should interrupt)
[
{
id: "call_1",
name: "write_file",
args: {
filename: "dangerous.txt",
content: "Dangerous content",
},
},
],
// Second call: safe file (should auto-approve)
[
{
id: "call_2",
name: "write_file",
args: {
filename: "safe.txt",
content: "Safe content",
},
},
],
[],
],
});

const checkpointer = new MemorySaver();
const agent = createAgent({
model,
checkpointer,
tools: [writeFileTool],
middleware: [hitlMiddleware],
});

const config = {
configurable: {
thread_id: "test-conditional",
},
};

// Test 1: Dangerous file should interrupt
model.index = 0;
await agent.invoke(
{
messages: [new HumanMessage("Write to dangerous file")],
},
config
);

// Verify conditional function was called
expect(conditionalInterrupt).toHaveBeenCalledTimes(1);
expect(conditionalInterrupt).toHaveBeenCalledWith(
expect.objectContaining({
id: "call_1",
name: "write_file",
args: {
filename: "dangerous.txt",
content: "Dangerous content",
},
})
);

// Verify write_file was NOT called yet (interrupted)
expect(writeFileFn).not.toHaveBeenCalled();

// Check if agent is paused for approval
const state = await agent.graph.getState(config);
expect(state.next).toBeDefined();
expect(state.next.length).toBe(1);

// Verify interrupt data
const task = state.tasks?.[0];
expect(task).toBeDefined();
expect(task.interrupts).toBeDefined();
expect(task.interrupts.length).toBe(1);

const hitlRequest = task.interrupts[0].value as HITLRequest;
expect(hitlRequest.actionRequests).toHaveLength(1);
expect(hitlRequest.actionRequests[0]).toEqual(
expect.objectContaining({
name: "write_file",
args: {
filename: "dangerous.txt",
content: "Dangerous content",
},
})
);
expect(hitlRequest.reviewConfigs[0].allowedDecisions).toEqual(["approve"]);

// Resume with approval
model.index = 0;
await agent.invoke(
new Command({
resume: { decisions: [{ type: "approve" }] } as HITLResponse,
}),
config
);

// Verify write_file was called after approval
expect(writeFileFn).toHaveBeenCalledTimes(1);
expect(writeFileFn).toHaveBeenCalledWith(
{
filename: "dangerous.txt",
content: "Dangerous content",
},
expect.anything()
);

// Test 2: Safe file should auto-approve (no interrupt)
model.index = 1;
writeFileFn.mockClear();
conditionalInterrupt.mockClear();

const safeResult = await agent.invoke(
{
messages: [new HumanMessage("Write to safe file")],
},
config
);

// Verify conditional function was called
expect(conditionalInterrupt).toHaveBeenCalledTimes(1);
expect(conditionalInterrupt).toHaveBeenCalledWith(
expect.objectContaining({
id: "call_2",
name: "write_file",
args: {
filename: "safe.txt",
content: "Safe content",
},
})
);

// Verify write_file was called immediately (auto-approved)
expect(writeFileFn).toHaveBeenCalledTimes(1);
expect(writeFileFn).toHaveBeenCalledWith(
{
filename: "safe.txt",
content: "Safe content",
},
expect.anything()
);

// Verify no interrupt occurred
expect("__interrupt__" in safeResult).toBe(false);

// Verify final response
const safeMessages = safeResult.messages;
expect(safeMessages[safeMessages.length - 1].content).toContain(
"Successfully wrote 12 characters to safe.txt"
);
});

it("should support dynamic description factory functions", async () => {
// Create a description factory that formats based on tool call details
const descriptionFactory = vi.fn((toolCall, _state, _runtime) => {
Expand Down
Loading