Skip to content

Commit 42930b5

Browse files
fix(langchain): improved state schema typing (langchain-ai#9285)
1 parent 59fa6e9 commit 42930b5

File tree

7 files changed

+79
-64
lines changed

7 files changed

+79
-64
lines changed

.changeset/twenty-clocks-raise.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"langchain": patch
3+
---
4+
5+
fix(langchain): improved state schema typing

libs/langchain/src/agents/middleware/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ export {
4040
export {
4141
modelCallLimitMiddleware,
4242
type ModelCallLimitMiddlewareConfig,
43-
} from "./callLimit.js";
43+
} from "./modelCallLimit.js";
4444
export { modelFallbackMiddleware } from "./modelFallback.js";
4545
export {
4646
toolRetryMiddleware,

libs/langchain/src/agents/middleware/tests/callLimit.test.ts renamed to libs/langchain/src/agents/middleware/tests/modelCallLimit.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import { tool } from "@langchain/core/tools";
44
import { MemorySaver } from "@langchain/langgraph-checkpoint";
55

66
import { FakeToolCallingChatModel } from "../../tests/utils.js";
7-
import { modelCallLimitMiddleware } from "../callLimit.js";
7+
import { modelCallLimitMiddleware } from "../modelCallLimit.js";
88
import { createAgent } from "../../index.js";
99

1010
const toolCallMessage1 = new AIMessage({

libs/langchain/src/agents/middleware/types.ts

Lines changed: 47 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,24 @@ type PromiseOrValue<T> = T | Promise<T>;
2121

2222
export type AnyAnnotationRoot = AnnotationRoot<any>;
2323

24-
type NormalizedSchemaInput<TSchema extends InteropZodObject | undefined = any> =
25-
TSchema extends InteropZodObject ? InferInteropZodInput<TSchema> : {};
24+
type NormalizedSchemaInput<
25+
TSchema extends InteropZodObject | undefined | never = any
26+
> = [TSchema] extends [never]
27+
? AgentBuiltInState
28+
: TSchema extends InteropZodObject
29+
? InferInteropZodOutput<TSchema> & AgentBuiltInState
30+
: TSchema extends Record<string, unknown>
31+
? TSchema & AgentBuiltInState
32+
: AgentBuiltInState;
2633

2734
/**
2835
* Result type for middleware functions.
2936
*/
30-
export type MiddlewareResult<TState> = TState | void;
37+
export type MiddlewareResult<TState> =
38+
| (TState & {
39+
jumpTo?: JumpToTarget;
40+
})
41+
| void;
3142

3243
/**
3344
* Represents a tool call request for the wrapToolCall hook.
@@ -61,28 +72,22 @@ export interface ToolCallRequest<
6172
* Takes a tool call request and returns the tool result or a command.
6273
*/
6374
export type ToolCallHandler<
64-
TSchema extends InteropZodObject | undefined = any,
75+
TSchema extends Record<string, unknown> = AgentBuiltInState,
6576
TContext = unknown
6677
> = (
67-
request: ToolCallRequest<
68-
NormalizedSchemaInput<TSchema> & AgentBuiltInState,
69-
TContext
70-
>
78+
request: ToolCallRequest<TSchema, TContext>
7179
) => PromiseOrValue<ToolMessage | Command>;
7280

7381
/**
7482
* Wrapper function type for the wrapToolCall hook.
7583
* Allows middleware to intercept and modify tool execution.
7684
*/
7785
export type WrapToolCallHook<
78-
TSchema extends InteropZodObject | undefined = any,
86+
TSchema extends InteropZodObject | undefined = undefined,
7987
TContext = unknown
8088
> = (
81-
request: ToolCallRequest<
82-
NormalizedSchemaInput<TSchema> & AgentBuiltInState,
83-
TContext
84-
>,
85-
handler: ToolCallHandler<TSchema, TContext>
89+
request: ToolCallRequest<NormalizedSchemaInput<TSchema>, TContext>,
90+
handler: ToolCallHandler<NormalizedSchemaInput<TSchema>, TContext>
8691
) => PromiseOrValue<ToolMessage | Command>;
8792

8893
/**
@@ -93,13 +98,10 @@ export type WrapToolCallHook<
9398
* @returns The AI message response from the model
9499
*/
95100
export type WrapModelCallHandler<
96-
TSchema extends InteropZodObject | undefined = any,
101+
TSchema extends InteropZodObject | undefined = undefined,
97102
TContext = unknown
98103
> = (
99-
request: ModelRequest<
100-
NormalizedSchemaInput<TSchema> & AgentBuiltInState,
101-
TContext
102-
>
104+
request: ModelRequest<NormalizedSchemaInput<TSchema>, TContext>
103105
) => PromiseOrValue<AIMessage>;
104106

105107
/**
@@ -116,13 +118,10 @@ export type WrapModelCallHandler<
116118
* @returns The AI message response from the model (or a modified version)
117119
*/
118120
export type WrapModelCallHook<
119-
TSchema extends InteropZodObject | undefined = any,
121+
TSchema extends InteropZodObject | undefined = undefined,
120122
TContext = unknown
121123
> = (
122-
request: ModelRequest<
123-
NormalizedSchemaInput<TSchema> & AgentBuiltInState,
124-
TContext
125-
>,
124+
request: ModelRequest<NormalizedSchemaInput<TSchema>, TContext>,
126125
handler: WrapModelCallHandler<TSchema, TContext>
127126
) => PromiseOrValue<AIMessage>;
128127

@@ -134,26 +133,23 @@ export type WrapModelCallHook<
134133
* @param runtime - The runtime context containing metadata, signal, writer, interrupt, etc.
135134
* @returns A middleware result containing partial state updates or undefined to pass through
136135
*/
137-
export type BeforeAgentHandler<
138-
TSchema extends InteropZodObject | undefined = any,
139-
TContext = unknown
140-
> = (
141-
state: NormalizedSchemaInput<TSchema> & AgentBuiltInState,
136+
type BeforeAgentHandler<TSchema, TContext> = (
137+
state: TSchema,
142138
runtime: Runtime<TContext>
143-
) => PromiseOrValue<MiddlewareResult<Partial<NormalizedSchemaInput<TSchema>>>>;
139+
) => PromiseOrValue<MiddlewareResult<Partial<TSchema>>>;
144140

145141
/**
146142
* Hook type for the beforeAgent lifecycle event.
147143
* Can be either a handler function or an object with a handler and optional jump targets.
148144
* This hook is called once at the start of the agent invocation.
149145
*/
150146
export type BeforeAgentHook<
151-
TSchema extends InteropZodObject | undefined = any,
147+
TSchema extends InteropZodObject | undefined = undefined,
152148
TContext = unknown
153149
> =
154-
| BeforeAgentHandler<TSchema, TContext>
150+
| BeforeAgentHandler<NormalizedSchemaInput<TSchema>, TContext>
155151
| {
156-
hook: BeforeAgentHandler<TSchema, TContext>;
152+
hook: BeforeAgentHandler<NormalizedSchemaInput<TSchema>, TContext>;
157153
canJumpTo?: JumpToTarget[];
158154
};
159155

@@ -165,26 +161,23 @@ export type BeforeAgentHook<
165161
* @param runtime - The runtime context containing metadata, signal, writer, interrupt, etc.
166162
* @returns A middleware result containing partial state updates or undefined to pass through
167163
*/
168-
export type BeforeModelHandler<
169-
TSchema extends InteropZodObject | undefined = any,
170-
TContext = unknown
171-
> = (
172-
state: NormalizedSchemaInput<TSchema> & AgentBuiltInState,
164+
type BeforeModelHandler<TSchema, TContext> = (
165+
state: TSchema,
173166
runtime: Runtime<TContext>
174-
) => PromiseOrValue<MiddlewareResult<Partial<NormalizedSchemaInput<TSchema>>>>;
167+
) => PromiseOrValue<MiddlewareResult<Partial<TSchema>>>;
175168

176169
/**
177170
* Hook type for the beforeModel lifecycle event.
178171
* Can be either a handler function or an object with a handler and optional jump targets.
179172
* This hook is called before each model invocation.
180173
*/
181174
export type BeforeModelHook<
182-
TSchema extends InteropZodObject | undefined = any,
175+
TSchema extends InteropZodObject | undefined = undefined,
183176
TContext = unknown
184177
> =
185-
| BeforeModelHandler<TSchema, TContext>
178+
| BeforeModelHandler<NormalizedSchemaInput<TSchema>, TContext>
186179
| {
187-
hook: BeforeModelHandler<TSchema, TContext>;
180+
hook: BeforeModelHandler<NormalizedSchemaInput<TSchema>, TContext>;
188181
canJumpTo?: JumpToTarget[];
189182
};
190183

@@ -197,26 +190,23 @@ export type BeforeModelHook<
197190
* @param runtime - The runtime context containing metadata, signal, writer, interrupt, etc.
198191
* @returns A middleware result containing partial state updates or undefined to pass through
199192
*/
200-
export type AfterModelHandler<
201-
TSchema extends InteropZodObject | undefined = any,
202-
TContext = unknown
203-
> = (
204-
state: NormalizedSchemaInput<TSchema> & AgentBuiltInState,
193+
type AfterModelHandler<TSchema, TContext> = (
194+
state: TSchema,
205195
runtime: Runtime<TContext>
206-
) => PromiseOrValue<MiddlewareResult<Partial<NormalizedSchemaInput<TSchema>>>>;
196+
) => PromiseOrValue<MiddlewareResult<Partial<TSchema>>>;
207197

208198
/**
209199
* Hook type for the afterModel lifecycle event.
210200
* Can be either a handler function or an object with a handler and optional jump targets.
211201
* This hook is called after each model invocation.
212202
*/
213203
export type AfterModelHook<
214-
TSchema extends InteropZodObject | undefined = any,
204+
TSchema extends InteropZodObject | undefined = undefined,
215205
TContext = unknown
216206
> =
217-
| AfterModelHandler<TSchema, TContext>
207+
| AfterModelHandler<NormalizedSchemaInput<TSchema>, TContext>
218208
| {
219-
hook: AfterModelHandler<TSchema, TContext>;
209+
hook: AfterModelHandler<NormalizedSchemaInput<TSchema>, TContext>;
220210
canJumpTo?: JumpToTarget[];
221211
};
222212

@@ -228,26 +218,23 @@ export type AfterModelHook<
228218
* @param runtime - The runtime context containing metadata, signal, writer, interrupt, etc.
229219
* @returns A middleware result containing partial state updates or undefined to pass through
230220
*/
231-
export type AfterAgentHandler<
232-
TSchema extends InteropZodObject | undefined = any,
233-
TContext = unknown
234-
> = (
235-
state: NormalizedSchemaInput<TSchema> & AgentBuiltInState,
221+
type AfterAgentHandler<TSchema, TContext> = (
222+
state: TSchema,
236223
runtime: Runtime<TContext>
237-
) => PromiseOrValue<MiddlewareResult<Partial<NormalizedSchemaInput<TSchema>>>>;
224+
) => PromiseOrValue<MiddlewareResult<Partial<TSchema>>>;
238225

239226
/**
240227
* Hook type for the afterAgent lifecycle event.
241228
* Can be either a handler function or an object with a handler and optional jump targets.
242229
* This hook is called once at the end of the agent invocation.
243230
*/
244231
export type AfterAgentHook<
245-
TSchema extends InteropZodObject | undefined = any,
232+
TSchema extends InteropZodObject | undefined = undefined,
246233
TContext = unknown
247234
> =
248-
| AfterAgentHandler<TSchema, TContext>
235+
| AfterAgentHandler<NormalizedSchemaInput<TSchema>, TContext>
249236
| {
250-
hook: AfterAgentHandler<TSchema, TContext>;
237+
hook: AfterAgentHandler<NormalizedSchemaInput<TSchema>, TContext>;
251238
canJumpTo?: JumpToTarget[];
252239
};
253240

libs/langchain/src/agents/tests/middleware.test-d.ts

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { HumanMessage, BaseMessage, AIMessage } from "@langchain/core/messages";
44
import { tool } from "@langchain/core/tools";
55

66
import { createAgent, createMiddleware } from "../index.js";
7+
import type { AgentBuiltInState } from "../runtime.js";
78
import type { ServerTool, ClientTool } from "../tools.js";
89

910
describe("middleware types", () => {
@@ -169,14 +170,33 @@ describe("middleware types", () => {
169170
.default({
170171
customRequiredContextProp: "default value",
171172
}),
172-
beforeModel: async (_state, runtime) => {
173+
stateSchema: z.object({
174+
customDefaultStateProp: z.string().default("default value"),
175+
customOptionalStateProp: z.string().optional(),
176+
customRequiredStateProp: z.string(),
177+
}),
178+
beforeModel: async (state, runtime) => {
179+
expectTypeOf(state).toEqualTypeOf<
180+
{
181+
customDefaultStateProp: string;
182+
customOptionalStateProp?: string;
183+
customRequiredStateProp: string;
184+
} & AgentBuiltInState
185+
>();
173186
expectTypeOf(runtime.context).toEqualTypeOf<{
174187
customDefaultContextProp: string;
175188
customOptionalContextProp?: string;
176189
customRequiredContextProp: string;
177190
}>();
178191
},
179-
afterModel: async (_state, runtime) => {
192+
afterModel: async (state, runtime) => {
193+
expectTypeOf(state).toEqualTypeOf<
194+
{
195+
customDefaultStateProp: string;
196+
customOptionalStateProp?: string;
197+
customRequiredStateProp: string;
198+
} & AgentBuiltInState
199+
>();
180200
expectTypeOf(runtime.context).toEqualTypeOf<{
181201
customDefaultContextProp: string;
182202
customOptionalContextProp?: string;
@@ -209,6 +229,7 @@ describe("middleware types", () => {
209229
await agent.invoke(
210230
{
211231
messages: [new HumanMessage("Hello, world!")],
232+
customRequiredStateProp: "default value",
212233
},
213234
{
214235
configurable: {

libs/langchain/src/agents/tests/state.test.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ describe("middleware state management", () => {
146146
const model = new FakeToolCallingModel({});
147147
const middleware = createMiddleware({
148148
name: "middleware",
149+
// @ts-expect-error _privateState is not an expected return type
149150
beforeModel: async (_, runtime) => {
150151
expect(runtime.threadLevelCallCount).toBe(0);
151152
expect(runtime.runModelCallCount).toBe(0);
@@ -165,6 +166,7 @@ describe("middleware state management", () => {
165166
expect(request.runtime.runModelCallCount).toBe(0);
166167
return handler(request);
167168
},
169+
// @ts-expect-error _privateState is not an expected return type
168170
afterModel: async (_, runtime) => {
169171
expect(runtime.threadLevelCallCount).toBe(1);
170172
expect(runtime.runModelCallCount).toBe(1);

0 commit comments

Comments
 (0)