Skip to content

Commit 5d8b75f

Browse files
fix(langchain): remove _privateState (langchain-ai#9289)
1 parent 66fc10c commit 5d8b75f

File tree

13 files changed

+197
-461
lines changed

13 files changed

+197
-461
lines changed

libs/langchain/src/agents/ReactAgent.ts

Lines changed: 33 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,7 @@ import type {
4646
JumpTo,
4747
UserInput,
4848
} from "./types.js";
49-
import type {
50-
PrivateState,
51-
InvokeConfiguration,
52-
StreamConfiguration,
53-
} from "./runtime.js";
49+
import type { InvokeConfiguration, StreamConfiguration } from "./runtime.js";
5450
import type {
5551
AgentMiddleware,
5652
InferMiddlewareContextInputs,
@@ -249,10 +245,20 @@ export class ReactAgent<
249245
throw new Error(`Middleware ${m.name} is defined multiple times`);
250246
}
251247

248+
const getState = () => {
249+
return {
250+
...beforeAgentNode?.getState(),
251+
...beforeModelNode?.getState(),
252+
...afterModelNode?.getState(),
253+
...afterAgentNode?.getState(),
254+
...this.#agentNode.getState(),
255+
};
256+
};
257+
252258
middlewareNames.add(m.name);
253259
if (m.beforeAgent) {
254260
beforeAgentNode = new BeforeAgentNode(m, {
255-
getPrivateState: () => this.#agentNode.getState()._privateState,
261+
getState,
256262
});
257263
const name = `${m.name}.before_agent`;
258264
beforeAgentNodes.push({
@@ -268,7 +274,7 @@ export class ReactAgent<
268274
}
269275
if (m.beforeModel) {
270276
beforeModelNode = new BeforeModelNode(m, {
271-
getPrivateState: () => this.#agentNode.getState()._privateState,
277+
getState,
272278
});
273279
const name = `${m.name}.before_model`;
274280
beforeModelNodes.push({
@@ -284,7 +290,7 @@ export class ReactAgent<
284290
}
285291
if (m.afterModel) {
286292
afterModelNode = new AfterModelNode(m, {
287-
getPrivateState: () => this.#agentNode.getState()._privateState,
293+
getState,
288294
});
289295
const name = `${m.name}.after_model`;
290296
afterModelNodes.push({
@@ -300,7 +306,7 @@ export class ReactAgent<
300306
}
301307
if (m.afterAgent) {
302308
afterAgentNode = new AfterAgentNode(m, {
303-
getPrivateState: () => this.#agentNode.getState()._privateState,
309+
getState,
304310
});
305311
const name = `${m.name}.after_agent`;
306312
afterAgentNodes.push({
@@ -316,15 +322,7 @@ export class ReactAgent<
316322
}
317323

318324
if (m.wrapModelCall) {
319-
wrapModelCallHookMiddleware.push([
320-
m,
321-
() => ({
322-
...beforeAgentNode?.getState(),
323-
...beforeModelNode?.getState(),
324-
...afterModelNode?.getState(),
325-
...afterAgentNode?.getState(),
326-
}),
327-
]);
325+
wrapModelCallHookMiddleware.push([m, getState]);
328326
}
329327
}
330328

@@ -350,7 +348,6 @@ export class ReactAgent<
350348
const toolNode = new ToolNode(toolClasses.filter(isClientTool), {
351349
signal: this.options.signal,
352350
wrapToolCall: wrapToolCallHandler,
353-
getPrivateState: () => this.#agentNode.getState()._privateState,
354351
});
355352
allNodeWorkflows.addNode("tools", toolNode);
356353
}
@@ -944,7 +941,8 @@ export class ReactAgent<
944941
* Initialize middleware states if not already present in the input state.
945942
*/
946943
async #initializeMiddlewareStates(
947-
state: InvokeStateParameter<StateSchema, TMiddleware>
944+
state: InvokeStateParameter<StateSchema, TMiddleware>,
945+
config: RunnableConfig
948946
): Promise<InvokeStateParameter<StateSchema, TMiddleware>> {
949947
if (
950948
!this.options.middleware ||
@@ -959,10 +957,13 @@ export class ReactAgent<
959957
this.options.middleware,
960958
state
961959
);
962-
const updatedState = { ...state } as InvokeStateParameter<
963-
StateSchema,
964-
TMiddleware
965-
>;
960+
const threadState = await this.#graph
961+
.getState(config)
962+
.catch(() => ({ values: {} }));
963+
const updatedState = {
964+
...threadState.values,
965+
...state,
966+
} as InvokeStateParameter<StateSchema, TMiddleware>;
966967
if (!updatedState) {
967968
return updatedState;
968969
}
@@ -977,35 +978,6 @@ export class ReactAgent<
977978
return updatedState;
978979
}
979980

980-
/**
981-
* Populate the private state of the agent node from the previous state.
982-
*/
983-
async #populatePrivateState(config?: RunnableConfig) {
984-
/**
985-
* not needed if thread_id is not provided
986-
*/
987-
if (!config?.configurable?.thread_id) {
988-
return;
989-
}
990-
const prevState = (await this.#graph.getState(config as any)) as {
991-
values: {
992-
_privateState: PrivateState;
993-
};
994-
};
995-
996-
/**
997-
* not need if state is empty
998-
*/
999-
if (!prevState.values._privateState) {
1000-
return;
1001-
}
1002-
1003-
this.#agentNode.setState({
1004-
structuredResponse: undefined,
1005-
_privateState: prevState.values._privateState,
1006-
});
1007-
}
1008-
1009981
/**
1010982
* Executes the agent with the given state and returns the final state after all processing.
1011983
*
@@ -1061,8 +1033,10 @@ export class ReactAgent<
10611033
StructuredResponseFormat,
10621034
TMiddleware
10631035
>;
1064-
const initializedState = await this.#initializeMiddlewareStates(state);
1065-
await this.#populatePrivateState(config);
1036+
const initializedState = await this.#initializeMiddlewareStates(
1037+
state,
1038+
config as RunnableConfig
1039+
);
10661040

10671041
return this.#graph.invoke(
10681042
initializedState,
@@ -1120,7 +1094,10 @@ export class ReactAgent<
11201094
InferMiddlewareContextInputs<TMiddleware>
11211095
>
11221096
): Promise<IterableReadableStream<any>> {
1123-
const initializedState = await this.#initializeMiddlewareStates(state);
1097+
const initializedState = await this.#initializeMiddlewareStates(
1098+
state,
1099+
config as RunnableConfig
1100+
);
11241101
return this.#graph.stream(initializedState, config as Record<string, any>);
11251102
}
11261103

libs/langchain/src/agents/annotation.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,12 @@ export function createAgentAnnotationConditional<
3030
const zodSchema: Record<string, any> = {
3131
messages: withLangGraph(z.custom<BaseMessage[]>(), MessagesZodMeta),
3232
jumpTo: z
33-
.union([z.literal("model_request"), z.literal("tools"), z.undefined()])
33+
.union([
34+
z.literal("model_request"),
35+
z.literal("tools"),
36+
z.literal("end"),
37+
z.undefined(),
38+
])
3439
.optional(),
3540
};
3641

libs/langchain/src/agents/middleware/modelCallLimit.ts

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ export type ModelCallLimitMiddlewareConfig = Partial<
2727
InferInteropZodInput<typeof contextSchema>
2828
>;
2929

30+
/**
31+
* Middleware state schema to track the number of model calls made at the thread and run level.
32+
*/
33+
const stateSchema = z.object({
34+
threadModelCallCount: z.number().default(0),
35+
runModelCallCount: z.number().default(0),
36+
});
37+
3038
/**
3139
* Error thrown when the model call limit is exceeded.
3240
*
@@ -133,6 +141,7 @@ export function modelCallLimitMiddleware(
133141
return createMiddleware({
134142
name: "ModelCallLimitMiddleware",
135143
contextSchema,
144+
stateSchema,
136145
beforeModel: {
137146
canJumpTo: ["end"],
138147
hook: (state, runtime) => {
@@ -145,13 +154,12 @@ export function modelCallLimitMiddleware(
145154
const runLimit =
146155
runtime.context.runLimit ?? middlewareOptions?.runLimit;
147156

148-
if (
149-
typeof threadLimit === "number" &&
150-
threadLimit <= runtime.threadLevelCallCount
151-
) {
157+
const threadCount = state.threadModelCallCount;
158+
const runCount = state.runModelCallCount;
159+
if (typeof threadLimit === "number" && threadLimit <= threadCount) {
152160
const error = new ModelCallLimitMiddlewareError({
153161
threadLimit,
154-
threadCount: runtime.threadLevelCallCount,
162+
threadCount,
155163
});
156164
if (exitBehavior === "end") {
157165
return {
@@ -162,13 +170,10 @@ export function modelCallLimitMiddleware(
162170

163171
throw error;
164172
}
165-
if (
166-
typeof runLimit === "number" &&
167-
runLimit <= runtime.runModelCallCount
168-
) {
173+
if (typeof runLimit === "number" && runLimit <= runCount) {
169174
const error = new ModelCallLimitMiddlewareError({
170175
runLimit,
171-
runCount: runtime.runModelCallCount,
176+
runCount,
172177
});
173178
if (exitBehavior === "end") {
174179
return {
@@ -183,5 +188,9 @@ export function modelCallLimitMiddleware(
183188
return state;
184189
},
185190
},
191+
afterModel: (state) => ({
192+
runModelCallCount: state.runModelCallCount + 1,
193+
threadModelCallCount: state.threadModelCallCount + 1,
194+
}),
186195
});
187196
}

libs/langchain/src/agents/middleware/tests/modelCallLimit.test.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ const tools = [
5252
name: "tool_1",
5353
description: "tool_1",
5454
}),
55+
tool(() => "barfoo", {
56+
name: "tool_2",
57+
description: "tool_2",
58+
}),
5559
];
5660

5761
describe("ModelCallLimitMiddleware", () => {
@@ -143,14 +147,19 @@ describe("ModelCallLimitMiddleware", () => {
143147
checkpointer,
144148
});
145149
if (exitBehavior === "throw") {
146-
await expect(
147-
agent2.invoke({ messages: ["Hello, world!"] }, config)
148-
).resolves.not.toThrow();
150+
const result = await agent2.invoke(
151+
{ messages: ["Hello, world!"] },
152+
config
153+
);
154+
await expect(result.runModelCallCount).toBe(3);
155+
await expect(result.threadModelCallCount).toBe(3);
149156
} else {
150157
const result = await agent2.invoke(
151158
{ messages: ["Hello, world!"] },
152159
config
153160
);
161+
await expect(result.runModelCallCount).toBe(3);
162+
await expect(result.threadModelCallCount).toBe(3);
154163
expect(result.messages.at(-1)?.content).not.toContain(
155164
"Model call limits exceeded"
156165
);

0 commit comments

Comments
 (0)