Skip to content

Commit 401bb68

Browse files
authored
feat(langchain/agents): readd state schema param (#9217)
1 parent 0c29928 commit 401bb68

File tree

9 files changed

+439
-89
lines changed

9 files changed

+439
-89
lines changed

libs/langchain/src/agents/ReactAgent.ts

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -57,30 +57,40 @@ import type {
5757
InferMiddlewareStates,
5858
InferMiddlewareInputStates,
5959
InferContextInput,
60-
ToAnnotationRoot,
6160
AnyAnnotationRoot,
61+
InferSchemaInput,
62+
ToAnnotationRoot,
6263
} from "./middleware/types.js";
6364
import { type ResponseFormatUndefined } from "./responses.js";
6465

6566
// Helper type to get the state definition with middleware states
6667
type MergedAgentState<
68+
StateSchema extends AnyAnnotationRoot | InteropZodObject | undefined,
6769
StructuredResponseFormat extends
6870
| Record<string, any>
6971
| ResponseFormatUndefined,
7072
TMiddleware extends readonly AgentMiddleware[]
71-
> = (StructuredResponseFormat extends ResponseFormatUndefined
72-
? Omit<BuiltInState, "jumpTo">
73-
: Omit<BuiltInState, "jumpTo"> & {
74-
structuredResponse: StructuredResponseFormat;
75-
}) &
73+
> = InferSchemaInput<StateSchema> &
74+
(StructuredResponseFormat extends ResponseFormatUndefined
75+
? Omit<BuiltInState, "jumpTo">
76+
: Omit<BuiltInState, "jumpTo"> & {
77+
structuredResponse: StructuredResponseFormat;
78+
}) &
7679
InferMiddlewareStates<TMiddleware>;
7780

78-
type InvokeStateParameter<TMiddleware extends readonly AgentMiddleware[]> =
79-
| (UserInput & InferMiddlewareInputStates<TMiddleware>)
81+
type InvokeStateParameter<
82+
StateSchema extends AnyAnnotationRoot | InteropZodObject | undefined,
83+
TMiddleware extends readonly AgentMiddleware[]
84+
> =
85+
| (UserInput<StateSchema> & InferMiddlewareInputStates<TMiddleware>)
8086
| Command<any, any, any>
8187
| null;
8288

8389
type AgentGraph<
90+
StateSchema extends
91+
| AnyAnnotationRoot
92+
| InteropZodObject
93+
| undefined = undefined,
8494
StructuredResponseFormat extends
8595
| Record<string, any>
8696
| ResponseFormatUndefined = Record<string, any>,
@@ -93,7 +103,7 @@ type AgentGraph<
93103
any,
94104
any,
95105
any,
96-
MergedAgentState<StructuredResponseFormat, TMiddleware>,
106+
MergedAgentState<StateSchema, StructuredResponseFormat, TMiddleware>,
97107
ToAnnotationRoot<ContextSchema>["spec"],
98108
unknown
99109
>;
@@ -102,19 +112,32 @@ export class ReactAgent<
102112
StructuredResponseFormat extends
103113
| Record<string, any>
104114
| ResponseFormatUndefined = Record<string, any>,
115+
StateSchema extends
116+
| AnyAnnotationRoot
117+
| InteropZodObject
118+
| undefined = undefined,
105119
ContextSchema extends
106120
| AnyAnnotationRoot
107121
| InteropZodObject = AnyAnnotationRoot,
108122
TMiddleware extends readonly AgentMiddleware[] = readonly AgentMiddleware[]
109123
> {
110-
#graph: AgentGraph<StructuredResponseFormat, ContextSchema, TMiddleware>;
124+
#graph: AgentGraph<
125+
StateSchema,
126+
StructuredResponseFormat,
127+
ContextSchema,
128+
TMiddleware
129+
>;
111130

112131
#toolBehaviorVersion: "v1" | "v2" = "v2";
113132

114133
#agentNode: AgentNode<any, AnyAnnotationRoot>;
115134

116135
constructor(
117-
public options: CreateAgentParams<StructuredResponseFormat, ContextSchema>
136+
public options: CreateAgentParams<
137+
StructuredResponseFormat,
138+
StateSchema,
139+
ContextSchema
140+
>
118141
) {
119142
this.#toolBehaviorVersion = options.version ?? this.#toolBehaviorVersion;
120143

@@ -155,8 +178,9 @@ export class ReactAgent<
155178
* Create a schema that merges agent base schema with middleware state schemas
156179
* Using Zod with withLangGraph ensures LangGraph Studio gets proper metadata
157180
*/
158-
const schema = createAgentAnnotationConditional<TMiddleware>(
181+
const schema = createAgentAnnotationConditional<StateSchema, TMiddleware>(
159182
this.options.responseFormat !== undefined,
183+
this.options.stateSchema as StateSchema,
160184
this.options.middleware as TMiddleware
161185
);
162186

@@ -595,13 +619,19 @@ export class ReactAgent<
595619
store: this.options.store,
596620
name: this.options.name,
597621
description: this.options.description,
598-
}) as AgentGraph<StructuredResponseFormat, ContextSchema, TMiddleware>;
622+
}) as AgentGraph<
623+
StateSchema,
624+
StructuredResponseFormat,
625+
ContextSchema,
626+
TMiddleware
627+
>;
599628
}
600629

601630
/**
602631
* Get the compiled {@link https://docs.langchain.com/oss/javascript/langgraph/use-graph-api | StateGraph}.
603632
*/
604633
get graph(): AgentGraph<
634+
StateSchema,
605635
StructuredResponseFormat,
606636
ContextSchema,
607637
TMiddleware
@@ -913,8 +943,8 @@ export class ReactAgent<
913943
* Initialize middleware states if not already present in the input state.
914944
*/
915945
async #initializeMiddlewareStates(
916-
state: InvokeStateParameter<TMiddleware>
917-
): Promise<InvokeStateParameter<TMiddleware>> {
946+
state: InvokeStateParameter<StateSchema, TMiddleware>
947+
): Promise<InvokeStateParameter<StateSchema, TMiddleware>> {
918948
if (
919949
!this.options.middleware ||
920950
this.options.middleware.length === 0 ||
@@ -928,7 +958,10 @@ export class ReactAgent<
928958
this.options.middleware,
929959
state
930960
);
931-
const updatedState = { ...state } as InvokeStateParameter<TMiddleware>;
961+
const updatedState = { ...state } as InvokeStateParameter<
962+
StateSchema,
963+
TMiddleware
964+
>;
932965
if (!updatedState) {
933966
return updatedState;
934967
}
@@ -1016,13 +1049,17 @@ export class ReactAgent<
10161049
* ```
10171050
*/
10181051
async invoke(
1019-
state: InvokeStateParameter<TMiddleware>,
1052+
state: InvokeStateParameter<StateSchema, TMiddleware>,
10201053
config?: InvokeConfiguration<
10211054
InferContextInput<ContextSchema> &
10221055
InferMiddlewareContextInputs<TMiddleware>
10231056
>
10241057
) {
1025-
type FullState = MergedAgentState<StructuredResponseFormat, TMiddleware>;
1058+
type FullState = MergedAgentState<
1059+
StateSchema,
1060+
StructuredResponseFormat,
1061+
TMiddleware
1062+
>;
10261063
const initializedState = await this.#initializeMiddlewareStates(state);
10271064
await this.#populatePrivateState(config);
10281065

@@ -1076,7 +1113,7 @@ export class ReactAgent<
10761113
* ```
10771114
*/
10781115
async stream(
1079-
state: InvokeStateParameter<TMiddleware>,
1116+
state: InvokeStateParameter<StateSchema, TMiddleware>,
10801117
config?: StreamConfiguration<
10811118
InferContextInput<ContextSchema> &
10821119
InferMiddlewareContextInputs<TMiddleware>
@@ -1142,7 +1179,7 @@ export class ReactAgent<
11421179
* @internal
11431180
*/
11441181
streamEvents(
1145-
state: InvokeStateParameter<TMiddleware>,
1182+
state: InvokeStateParameter<StateSchema, TMiddleware>,
11461183
config?: StreamConfiguration<
11471184
InferContextInput<ContextSchema> &
11481185
InferMiddlewareContextInputs<TMiddleware>

libs/langchain/src/agents/annotation.ts

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,18 @@ import {
99
} from "@langchain/langgraph";
1010
import { withLangGraph } from "@langchain/langgraph/zod";
1111

12-
import type { AgentMiddleware } from "./middleware/types.js";
12+
import type { AgentMiddleware, AnyAnnotationRoot } from "./middleware/types.js";
13+
import { InteropZodObject } from "@langchain/core/utils/types";
1314

1415
export function createAgentAnnotationConditional<
16+
TStateSchema extends
17+
| AnyAnnotationRoot
18+
| InteropZodObject
19+
| undefined = undefined,
1520
TMiddleware extends readonly AgentMiddleware<any, any, any>[] = []
1621
>(
1722
hasStructuredResponse = true,
23+
stateSchema: TStateSchema,
1824
middlewareList: TMiddleware = [] as unknown as TMiddleware
1925
) {
2026
/**
@@ -28,25 +34,31 @@ export function createAgentAnnotationConditional<
2834
.optional(),
2935
};
3036

31-
/**
32-
* Add middleware state properties to the Zod schema
33-
*/
34-
for (const middleware of middlewareList) {
35-
if (middleware.stateSchema) {
36-
const { shape } = middleware.stateSchema;
37-
for (const [key, schema] of Object.entries(shape)) {
38-
/**
39-
* Skip private state properties
40-
*/
41-
if (key.startsWith("_")) {
42-
continue;
43-
}
37+
const applySchema = (schema: { shape: Record<string, any> }) => {
38+
const { shape } = schema;
39+
for (const [key, schema] of Object.entries(shape)) {
40+
/**
41+
* Skip private state properties
42+
*/
43+
if (key.startsWith("_")) {
44+
continue;
45+
}
4446

45-
if (!(key in zodSchema)) {
46-
zodSchema[key] = schema;
47-
}
47+
if (!(key in zodSchema)) {
48+
zodSchema[key] = schema;
4849
}
4950
}
51+
};
52+
53+
// Add state schema properties to the Zod schema
54+
if (stateSchema && "shape" in stateSchema) {
55+
applySchema(stateSchema);
56+
}
57+
58+
for (const middleware of middlewareList) {
59+
if (middleware.stateSchema) {
60+
applySchema(middleware.stateSchema);
61+
}
5062
}
5163

5264
// Only include structuredResponse when responseFormat is defined
@@ -77,4 +89,5 @@ export const PreHookAnnotation: AnnotationRoot<{
7789
llmInputMessages: BinaryOperatorAggregate<BaseMessage[], Messages>;
7890
messages: BinaryOperatorAggregate<BaseMessage[], Messages>;
7991
}>;
92+
8093
export type PreHookAnnotation = typeof PreHookAnnotation;

0 commit comments

Comments
 (0)