56 changes: 29 additions & 27 deletions libs/langchain_v1/langchain/agents/factory.py
@@ -28,7 +28,9 @@
 from langchain.agents.middleware.types import (
     AgentMiddleware,
     AgentState,
+    AsyncModelCallHandler,
     JumpTo,
+    ModelCallHandler,
     ModelRequest,
     ModelResponse,
     OmitFromSchema,
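
Note: the two new imports are the heart of this PR. A minimal sketch of what these aliases plausibly look like; the real definitions live in langchain.agents.middleware.types and may differ, and the ModelRequest/ModelResponse classes here are stand-ins:

from collections.abc import Awaitable, Callable
from typing import Generic, TypeVar

ContextT = TypeVar("ContextT")

class ModelRequest(Generic[ContextT]):  # stand-in for the real class
    ...

class ModelResponse:  # stand-in for the real class
    ...

# Assumed shape of the aliases: "the rest of the handler chain", taking the
# (possibly rewritten) request and producing the model response. Subscripting
# with ContextT threads the runtime context type through the chain.
ModelCallHandler = Callable[[ModelRequest[ContextT]], ModelResponse]
AsyncModelCallHandler = Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]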
@@ -86,13 +88,13 @@ def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
 def _chain_model_call_handlers(
     handlers: Sequence[
         Callable[
-            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+            [ModelRequest[ContextT], ModelCallHandler[ContextT]],
             ModelResponse | AIMessage,
         ]
     ],
 ) -> (
     Callable[
-        [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+        [ModelRequest[ContextT], ModelCallHandler[ContextT]],
         ModelResponse,
     ]
     | None
@@ -140,8 +142,8 @@ def retry(req, state, runtime, handler):
     single_handler = handlers[0]

     def normalized_single(
-        request: ModelRequest,
-        handler: Callable[[ModelRequest], ModelResponse],
+        request: ModelRequest[ContextT],
+        handler: ModelCallHandler[ContextT],
     ) -> ModelResponse:
         result = single_handler(request, handler)
         return _normalize_to_model_response(result)
@@ -150,25 +152,25 @@ def normalized_single(

     def compose_two(
         outer: Callable[
-            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+            [ModelRequest[ContextT], ModelCallHandler[ContextT]],
             ModelResponse | AIMessage,
         ],
         inner: Callable[
-            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+            [ModelRequest[ContextT], ModelCallHandler[ContextT]],
             ModelResponse | AIMessage,
         ],
     ) -> Callable[
-        [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+        [ModelRequest[ContextT], ModelCallHandler[ContextT]],
         ModelResponse,
     ]:
         """Compose two handlers where outer wraps inner."""

         def composed(
-            request: ModelRequest,
-            handler: Callable[[ModelRequest], ModelResponse],
+            request: ModelRequest[ContextT],
+            handler: ModelCallHandler[ContextT],
         ) -> ModelResponse:
             # Create a wrapper that calls inner with the base handler and normalizes
-            def inner_handler(req: ModelRequest) -> ModelResponse:
+            def inner_handler(req: ModelRequest[ContextT]) -> ModelResponse:
                 inner_result = inner(req, handler)
                 return _normalize_to_model_response(inner_result)

@@ -185,8 +187,8 @@ def inner_handler(req: ModelRequest) -> ModelResponse:

     # Wrap to ensure final return type is exactly ModelResponse
     def final_normalized(
-        request: ModelRequest,
-        handler: Callable[[ModelRequest], ModelResponse],
+        request: ModelRequest[ContextT],
+        handler: ModelCallHandler[ContextT],
     ) -> ModelResponse:
         # result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
        final_result = result(request, handler)
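
The fold over `compose_two` that produces `result` falls outside these hunks. The pattern in isolation, as a self-contained sketch with plain strings standing in for the request/response types (names here are illustrative, not the library API):

from collections.abc import Callable, Sequence
from functools import reduce

Handler = Callable[[str], str]           # stand-in for ModelCallHandler
Wrapper = Callable[[str, Handler], str]  # stand-in for a wrap_model_call hook

def chain(wrappers: Sequence[Wrapper]) -> Wrapper | None:
    if not wrappers:
        return None

    def compose_two(outer: Wrapper, inner: Wrapper) -> Wrapper:
        # outer receives a handler that runs inner, which in turn calls the
        # real (base) handler -- the same trick as inner_handler above
        def composed(request: str, handler: Handler) -> str:
            return outer(request, lambda req: inner(req, handler))
        return composed

    return reduce(compose_two, wrappers)

def log(req: str, call: Handler) -> str:
    return f"log({call(req)})"

def shout(req: str, call: Handler) -> str:
    return call(req).upper()

# In this sketch the first wrapper ends up outermost, so "log" sees
# what "shout" produced on the way back out.
chained = chain([log, shout])
assert chained is not None
assert chained("hi", lambda r: f"re:{r}") == "log(RE:HI)"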
@@ -198,13 +200,13 @@ def final_normalized(
 def _chain_async_model_call_handlers(
     handlers: Sequence[
         Callable[
-            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+            [ModelRequest[ContextT], AsyncModelCallHandler[ContextT]],
             Awaitable[ModelResponse | AIMessage],
         ]
     ],
 ) -> (
     Callable[
-        [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+        [ModelRequest[ContextT], AsyncModelCallHandler[ContextT]],
         Awaitable[ModelResponse],
     ]
     | None
@@ -225,8 +227,8 @@ def _chain_async_model_call_handlers(
     single_handler = handlers[0]

     async def normalized_single(
-        request: ModelRequest,
-        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+        request: ModelRequest[ContextT],
+        handler: AsyncModelCallHandler[ContextT],
     ) -> ModelResponse:
         result = await single_handler(request, handler)
         return _normalize_to_model_response(result)
@@ -235,25 +237,25 @@ async def normalized_single(

     def compose_two(
         outer: Callable[
-            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+            [ModelRequest[ContextT], AsyncModelCallHandler[ContextT]],
             Awaitable[ModelResponse | AIMessage],
         ],
         inner: Callable[
-            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+            [ModelRequest[ContextT], AsyncModelCallHandler[ContextT]],
             Awaitable[ModelResponse | AIMessage],
         ],
     ) -> Callable[
-        [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+        [ModelRequest[ContextT], AsyncModelCallHandler[ContextT]],
         Awaitable[ModelResponse],
     ]:
         """Compose two async handlers where outer wraps inner."""

         async def composed(
-            request: ModelRequest,
-            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+            request: ModelRequest[ContextT],
+            handler: AsyncModelCallHandler[ContextT],
         ) -> ModelResponse:
             # Create a wrapper that calls inner with the base handler and normalizes
-            async def inner_handler(req: ModelRequest) -> ModelResponse:
+            async def inner_handler(req: ModelRequest[ContextT]) -> ModelResponse:
                 inner_result = await inner(req, handler)
                 return _normalize_to_model_response(inner_result)

@@ -270,8 +272,8 @@ async def inner_handler(req: ModelRequest) -> ModelResponse:

     # Wrap to ensure final return type is exactly ModelResponse
     async def final_normalized(
-        request: ModelRequest,
-        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+        request: ModelRequest[ContextT],
+        handler: AsyncModelCallHandler[ContextT],
     ) -> ModelResponse:
         # result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
        final_result = await result(request, handler)
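
Both chains lean on `_normalize_to_model_response` so a middleware hook may return either the full response or a bare `AIMessage`. A toy version of that coercion, with stand-in dataclasses rather than the real types:

from dataclasses import dataclass, field

@dataclass
class Message:           # stand-in for AIMessage
    text: str

@dataclass
class Response:          # stand-in for ModelResponse
    result: list[Message] = field(default_factory=list)

def normalize(value: Response | Message) -> Response:
    # Mirrors the idea of _normalize_to_model_response: wrap a bare
    # message; pass a full response through untouched.
    return value if isinstance(value, Response) else Response(result=[value])

assert normalize(Message("hi")).result[0].text == "hi"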
@@ -968,7 +970,7 @@ def _handle_model_output(

         return {"messages": [output]}

-    def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
+    def _get_bound_model(request: ModelRequest[ContextT]) -> tuple[Runnable, ResponseFormat | None]:
         """Get the model with appropriate tool bindings.

         Performs auto-detection of strategy if needed based on model capabilities.
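
For context on what "tool bindings" means here: `request.model` is specialized with the chat-model `bind_tools(...)` API when tools are present, else plain `.bind(...)` (that fallback call survives in the next hunk). A simplified sketch; `request.tools` and the branch shape are assumptions, and the real `_get_bound_model` also auto-detects the structured-output strategy:

def get_bound_model_sketch(request):
    # Hypothetical simplification of _get_bound_model: bind tools when
    # present, otherwise just apply the per-request model settings.
    if request.tools:
        return request.model.bind_tools(request.tools, **request.model_settings), None
    return request.model.bind(**request.model_settings), None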
@@ -1082,7 +1084,7 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
             )
         return request.model.bind(**request.model_settings), None

-    def _execute_model_sync(request: ModelRequest) -> ModelResponse:
+    def _execute_model_sync(request: ModelRequest[ContextT]) -> ModelResponse:
         """Execute model and return response.

         This is the core model execution logic wrapped by `wrap_model_call` handlers.
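
How the pieces presumably meet at the call site, which falls outside the hunks shown: the chained middleware hooks wrap `_execute_model_sync` as the base handler. Hypothetical glue; `wrap_handlers` and `request` are assumed names:

wrapped = _chain_model_call_handlers(wrap_handlers)
response = (
    wrapped(request, _execute_model_sync)   # middleware wraps the core call
    if wrapped is not None
    else _execute_model_sync(request)       # no middleware: call directly
)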
@@ -1133,7 +1135,7 @@ def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:

         return state_updates

-    async def _execute_model_async(request: ModelRequest) -> ModelResponse:
+    async def _execute_model_async(request: ModelRequest[ContextT]) -> ModelResponse:
         """Execute model asynchronously and return response.

         This is the core async model execution logic wrapped by `wrap_model_call`
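
And the async path mirrors the sync wiring above (again hypothetical glue with assumed names):

wrapped = _chain_async_model_call_handlers(async_wrap_handlers)
response = await (
    wrapped(request, _execute_model_async)
    if wrapped is not None
    else _execute_model_async(request)
)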