Skip to content

Commit 1838473

Browse files
fix(langchain): fix HITL middleware and allow to return to model with feedback (#9314)
1 parent e2c0b2e commit 1838473

File tree

3 files changed

+147
-6
lines changed

3 files changed

+147
-6
lines changed

libs/langchain/src/agents/middleware/hitl.ts

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import { interrupt } from "@langchain/langgraph";
99

1010
import { createMiddleware } from "../middleware.js";
1111
import type { AgentBuiltInState, Runtime } from "../runtime.js";
12+
import type { JumpToTarget } from "../constants.js";
1213

1314
const DescriptionFunctionSchema = z
1415
.function()
@@ -720,6 +721,9 @@ export function humanInTheLoopMiddleware(
720721

721722
const revisedToolCalls: ToolCall[] = [...autoApprovedToolCalls];
722723
const artificialToolMessages: ToolMessage[] = [];
724+
const hasRejectedToolCalls = decisions.some(
725+
(decision) => decision.type === "reject"
726+
);
723727

724728
/**
725729
* Process each decision using helper method
@@ -735,7 +739,15 @@ export function humanInTheLoopMiddleware(
735739
interruptConfig
736740
);
737741

738-
if (revisedToolCall) {
742+
if (
743+
revisedToolCall &&
744+
/**
745+
* If any decision is a rejected, we are going back to the model
746+
* with only the tool calls that were rejected as we don't know
747+
* the results of the approved/updated tool calls at this point.
748+
*/
749+
(!hasRejectedToolCalls || decision.type === "reject")
750+
) {
739751
revisedToolCalls.push(revisedToolCall);
740752
}
741753
if (toolMessage) {
@@ -750,7 +762,13 @@ export function humanInTheLoopMiddleware(
750762
lastMessage.tool_calls = revisedToolCalls;
751763
}
752764

753-
return { messages: [lastMessage, ...artificialToolMessages] };
765+
const jumpTo: JumpToTarget | undefined = hasRejectedToolCalls
766+
? "model"
767+
: undefined;
768+
return {
769+
messages: [lastMessage, ...artificialToolMessages],
770+
jumpTo,
771+
};
754772
},
755773
},
756774
});

libs/langchain/src/agents/middleware/summarization.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,13 @@ export function summarizationMiddleware(
107107
) {
108108
return createMiddleware({
109109
name: "SummarizationMiddleware",
110-
contextSchema,
110+
contextSchema: contextSchema.extend({
111+
/**
112+
* `model` should be required when initializing the middleware,
113+
* but can be omitted within context when invoking the middleware.
114+
*/
115+
model: z.custom<BaseLanguageModel>().optional(),
116+
}),
111117
beforeModel: async (state, runtime) => {
112118
/**
113119
* Parse user options to get their explicit values

libs/langchain/src/agents/middleware/tests/hitl.int.test.ts

Lines changed: 120 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ import { ChatOpenAI } from "@langchain/openai";
55
import { HumanMessage, AIMessage } from "@langchain/core/messages";
66
import { MemorySaver } from "@langchain/langgraph-checkpoint";
77
import { Command } from "@langchain/langgraph";
8+
import { ToolMessage } from "@langchain/core/messages";
89

910
import { tool } from "@langchain/core/tools";
10-
import { createAgent } from "../../index.js";
11+
import { createAgent, type Interrupt } from "../../index.js";
1112
import {
1213
type HITLRequest,
1314
type HITLResponse,
@@ -299,7 +300,6 @@ describe("humanInTheLoopMiddleware", () => {
299300
decisions: [
300301
{
301302
type: "reject",
302-
message: "The calculation result is 500 (custom override)",
303303
},
304304
{
305305
type: "approve",
@@ -309,10 +309,127 @@ describe("humanInTheLoopMiddleware", () => {
309309
}),
310310
thread
311311
);
312-
expect(resume.structuredResponse).toEqual({
312+
313+
/**
314+
* we expect another interrupt as model updates the tool call
315+
*/
316+
expect("__interrupt__" in resume).toBe(true);
317+
318+
const lastMessage = resume.messages.at(-1) as AIMessage;
319+
const finalResume = await agent.invoke(
320+
new Command({
321+
resume: {
322+
decisions:
323+
lastMessage.tool_calls?.map(() => ({
324+
type: "approve",
325+
})) ?? [],
326+
} satisfies HITLResponse,
327+
}),
328+
thread
329+
);
330+
331+
expect(finalResume.structuredResponse).toEqual({
313332
result: expect.toBeOneOf([500, 579]),
314333
name: "Thomas",
315334
});
335+
/**
336+
* we expect the final resume to have 8 messages:
337+
* 1. human message
338+
* 2. AI message with 2 calls
339+
* 3. Rejected tool message
340+
* 4. new tool call
341+
* 5. approved tool message
342+
* 6. approved tool message
343+
* 7. AI message with final response
344+
*/
345+
expect(finalResume.messages).toHaveLength(7);
346+
});
347+
348+
it("should allow to reject tool calls and give model feedback", async () => {
349+
const checkpointer = new MemorySaver();
350+
const sendEmailTool = tool(
351+
() => {
352+
return "Email sent!";
353+
},
354+
{
355+
name: "send_email",
356+
description: "Sends an email",
357+
schema: z.object({
358+
message: z.string(),
359+
to: z.array(z.string()),
360+
subject: z.string(),
361+
}),
362+
}
363+
);
364+
const agent = createAgent({
365+
model,
366+
middleware: [
367+
humanInTheLoopMiddleware({
368+
interruptOn: {
369+
send_email: true,
370+
},
371+
}),
372+
] as const,
373+
tools: [sendEmailTool],
374+
checkpointer,
375+
});
376+
377+
const result = await agent.invoke(
378+
{
379+
messages: [
380+
new HumanMessage(
381+
"Send an email to [email protected], saying hello!"
382+
),
383+
],
384+
},
385+
thread
386+
);
387+
388+
/**
389+
* first interception
390+
*/
391+
expect("__interrupt__" in result).toBe(true);
392+
const resume = await agent.invoke(
393+
new Command({
394+
resume: {
395+
decisions: [
396+
{
397+
type: "reject",
398+
message:
399+
"Send the email speaking like a pirate starting the message with 'Arrr, matey!'",
400+
},
401+
],
402+
} satisfies HITLResponse,
403+
}),
404+
thread
405+
);
406+
407+
/**
408+
* second interception, verify model as updated the tool call and approve
409+
*/
410+
const interrupt = resume.__interrupt__?.[0] as Interrupt<HITLRequest>;
411+
expect(
412+
interrupt?.value?.actionRequests[0].args.message.startsWith(
413+
"Arrr, matey!"
414+
)
415+
).toBe(true);
416+
const finalResume = await agent.invoke(
417+
new Command({
418+
resume: {
419+
decisions: [
420+
{
421+
type: "approve",
422+
},
423+
],
424+
} satisfies HITLResponse,
425+
}),
426+
thread
427+
);
428+
const toolMessage = [...finalResume.messages]
429+
.reverse()
430+
.find(ToolMessage.isInstance);
431+
expect(toolMessage).toBeDefined();
432+
expect(toolMessage?.content).toBe("Email sent!");
316433
});
317434
});
318435
});

0 commit comments

Comments
 (0)