Skip to content

Commit 797650f

Browse files
authored
feat(langchain): Add type inference for ReactAgent stream method (#9368)
1 parent 671a949 commit 797650f

File tree

3 files changed

+88
-12
lines changed

3 files changed

+88
-12
lines changed

libs/langchain/src/agents/ReactAgent.ts

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import {
1111
CompiledStateGraph,
1212
type GetStateOptions,
1313
type LangGraphRunnableConfig,
14+
type StreamMode,
15+
type StreamOutputMap,
1416
} from "@langchain/langgraph";
1517
import type { CheckpointListOptions } from "@langchain/langgraph-checkpoint";
1618
import { ToolMessage, AIMessage } from "@langchain/core/messages";
@@ -1076,18 +1078,39 @@ export class ReactAgent<
10761078
* }
10771079
* ```
10781080
*/
1079-
async stream(
1081+
async stream<
1082+
TStreamMode extends StreamMode | StreamMode[] | undefined,
1083+
TEncoding extends "text/event-stream" | undefined
1084+
>(
10801085
state: InvokeStateParameter<StateSchema, TMiddleware>,
10811086
config?: StreamConfiguration<
10821087
InferContextInput<ContextSchema> &
1083-
InferMiddlewareContextInputs<TMiddleware>
1088+
InferMiddlewareContextInputs<TMiddleware>,
1089+
TStreamMode,
1090+
TEncoding
10841091
>
1085-
): Promise<IterableReadableStream<any>> {
1092+
) {
10861093
const initializedState = await this.#initializeMiddlewareStates(
10871094
state,
10881095
config as RunnableConfig
10891096
);
1090-
return this.#graph.stream(initializedState, config as Record<string, any>);
1097+
return this.#graph.stream(
1098+
initializedState,
1099+
config as Record<string, any>
1100+
) as Promise<
1101+
IterableReadableStream<
1102+
StreamOutputMap<
1103+
TStreamMode,
1104+
false,
1105+
MergedAgentState<StateSchema, StructuredResponseFormat, TMiddleware>,
1106+
MergedAgentState<StateSchema, StructuredResponseFormat, TMiddleware>,
1107+
string,
1108+
unknown,
1109+
unknown,
1110+
TEncoding
1111+
>
1112+
>
1113+
>;
10911114
}
10921115

10931116
/**
@@ -1149,7 +1172,9 @@ export class ReactAgent<
11491172
state: InvokeStateParameter<StateSchema, TMiddleware>,
11501173
config?: StreamConfiguration<
11511174
InferContextInput<ContextSchema> &
1152-
InferMiddlewareContextInputs<TMiddleware>
1175+
InferMiddlewareContextInputs<TMiddleware>,
1176+
StreamMode | StreamMode[] | undefined,
1177+
"text/event-stream" | undefined
11531178
> & { version?: "v1" | "v2" },
11541179
streamOptions?: Parameters<Runnable["streamEvents"]>[2]
11551180
): IterableReadableStream<StreamEvent> {

libs/langchain/src/agents/runtime.ts

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import type { InteropZodDefault } from "@langchain/core/utils/types";
44
import type {
55
Runtime as LangGraphRuntime,
66
PregelOptions,
7+
StreamMode,
78
} from "@langchain/langgraph";
89
import type { BaseMessage } from "@langchain/core/messages";
910
import type { BaseCallbackConfig } from "@langchain/core/callbacks/manager";
@@ -151,12 +152,21 @@ export type InvokeConfiguration<ContextSchema extends Record<string, any>> =
151152
Partial<Pick<PregelOptions<any, any, any>, CreateAgentPregelOptions>> &
152153
WithMaybeContext<ContextSchema>;
153154

154-
export type StreamConfiguration<ContextSchema extends Record<string, any>> =
155+
export type StreamConfiguration<
156+
ContextSchema extends Record<string, any>,
157+
TStreamMode extends StreamMode | StreamMode[] | undefined,
158+
TEncoding extends "text/event-stream" | undefined
159+
> =
155160
/**
156161
* If the context schema is a default object, `context` can be optional
157162
*/
158163
ContextSchema extends InteropZodDefault<any>
159-
? Partial<Pick<PregelOptions<any, any, any>, CreateAgentPregelOptions>> & {
164+
? Partial<
165+
Pick<
166+
PregelOptions<any, any, any, TStreamMode, boolean, TEncoding>,
167+
CreateAgentPregelOptions
168+
>
169+
> & {
160170
context?: Partial<ContextSchema>;
161171
}
162172
: /**
@@ -165,15 +175,15 @@ export type StreamConfiguration<ContextSchema extends Record<string, any>> =
165175
IsAllOptional<ContextSchema> extends true
166176
? Partial<
167177
Pick<
168-
PregelOptions<any, any, any>,
178+
PregelOptions<any, any, any, TStreamMode, boolean, TEncoding>,
169179
CreateAgentPregelOptions | CreateAgentPregelStreamOptions
170180
>
171181
> & {
172182
context?: Partial<ContextSchema>;
173183
}
174184
: Partial<
175185
Pick<
176-
PregelOptions<any, any, any>,
186+
PregelOptions<any, any, any, TStreamMode, boolean, TEncoding>,
177187
CreateAgentPregelOptions | CreateAgentPregelStreamOptions
178188
>
179189
> &

libs/langchain/src/agents/tests/reactAgent.test-d.ts

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ import { LanguageModelLike } from "@langchain/core/language_models/base";
44
import { describe, it, expectTypeOf } from "vitest";
55
import type { IterableReadableStream } from "@langchain/core/utils/stream";
66

7-
import { createAgent } from "../index.js";
7+
import { type BuiltInState, createAgent } from "../index.js";
8+
import type { StreamOutputMap } from "@langchain/langgraph";
89

910
describe("reactAgent", () => {
1011
it("should require model as only required property", async () => {
@@ -79,8 +80,48 @@ describe("reactAgent", () => {
7980
recursionLimit: 10,
8081
}
8182
);
82-
// eslint-disable-next-line @typescript-eslint/no-explicit-any
83-
expectTypeOf(stream).toEqualTypeOf<IterableReadableStream<any>>();
83+
expectTypeOf(stream).toEqualTypeOf<
84+
IterableReadableStream<
85+
StreamOutputMap<
86+
"values" | "updates" | "messages",
87+
false,
88+
Record<string, unknown>,
89+
Record<string, unknown>,
90+
string,
91+
unknown,
92+
unknown,
93+
"text/event-stream"
94+
>
95+
>
96+
>();
97+
98+
for await (const chunk of stream) {
99+
expectTypeOf(chunk).toEqualTypeOf<Uint8Array>();
100+
}
101+
102+
const multiModeStream = await agent.stream(
103+
{
104+
messages: [new HumanMessage("Hello, world!")],
105+
},
106+
{
107+
streamMode: ["updates", "messages", "values"],
108+
}
109+
);
110+
111+
for await (const chunk of multiModeStream) {
112+
const [mode, value] = chunk;
113+
expectTypeOf(mode).toEqualTypeOf<"updates" | "messages" | "values">();
114+
if (mode === "messages") {
115+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
116+
expectTypeOf(value).toEqualTypeOf<[BaseMessage, Record<string, any>]>();
117+
} else if (mode === "updates") {
118+
expectTypeOf(value).toEqualTypeOf<
119+
Record<string, Omit<BuiltInState, "jumpTo">>
120+
>();
121+
} else {
122+
expectTypeOf(value.messages).toEqualTypeOf<BaseMessage[]>();
123+
}
124+
}
84125

85126
await agent.invoke(
86127
{

0 commit comments

Comments
 (0)