Draft
Commits
44 commits
575ed9e  chore: scaffold hatchet agent, model, and mcp server  (mrkaye97, Sep 16, 2025)
581584e  chore: add hatchet dep  (mrkaye97, Sep 16, 2025)
c03c271  feat: add utils for task config  (mrkaye97, Sep 16, 2025)
faa891f  feat: first pass at model impl  (mrkaye97, Sep 16, 2025)
7bd951c  feat: tool calling implementation for mcp server  (mrkaye97, Sep 16, 2025)
3d4aae2  feat: implement request method for model  (mrkaye97, Sep 16, 2025)
88e39be  feat: more work on agent  (mrkaye97, Sep 16, 2025)
63893e1  feat: agent run method impl  (mrkaye97, Sep 16, 2025)
be36721  fix: allow arbitrary types in i/o models  (mrkaye97, Sep 16, 2025)
479dba9  fix: clean up types + workflow registration  (mrkaye97, Sep 16, 2025)
19d9e22  feat: first pass at toolsets  (mrkaye97, Sep 16, 2025)
ab357c3  feat: hatchet run context  (mrkaye97, Sep 18, 2025)
6ac4a19  chore: ignore local files for testing  (mrkaye97, Sep 18, 2025)
4df5d54  chore: bump hatchet  (mrkaye97, Sep 18, 2025)
59139f0  feat: use hatchet run context in mcp server  (mrkaye97, Sep 18, 2025)
830f347  feat: stricter typing on run context  (mrkaye97, Sep 18, 2025)
a3c4601  fix: pass types around  (mrkaye97, Sep 18, 2025)
dfc2064  fix: make tool and tool def serializable a la temporal impl  (mrkaye97, Sep 18, 2025)
449fd1a  fix: comment  (mrkaye97, Sep 18, 2025)
635d594  feat: more work on making the hatchet function toolset impl more simi…  (mrkaye97, Sep 18, 2025)
3ec8bda  fix: improve typing on tasks  (mrkaye97, Sep 18, 2025)
0c5cda5  fix: use tasks for everything except the agent itself  (mrkaye97, Sep 18, 2025)
bc0fbbc  fix: tool naming  (mrkaye97, Sep 18, 2025)
6e9c897  feat: add hatchet metadata on agent run  (mrkaye97, Sep 18, 2025)
b477dfb  feat: add run_sync method to the hatchet agent  (mrkaye97, Sep 19, 2025)
86e4574  feat: add run_stream and iter methods  (mrkaye97, Sep 19, 2025)
7cc44ed  fix: return list of workflows so we don't need a cast  (mrkaye97, Sep 19, 2025)
551bd44  fix: add hacks around run_stream from inside of a task  (mrkaye97, Sep 25, 2025)
89bb21c  fix: start implementing event stream handler for agent properly  (mrkaye97, Sep 25, 2025)
a7bb7a1  fix: recursion  (mrkaye97, Sep 25, 2025)
deab728  feat: streaming impl, part i  (mrkaye97, Sep 25, 2025)
d979729  fix: streaming  (mrkaye97, Sep 25, 2025)
006fb1c  hack: partially working streaming implementation  (mrkaye97, Sep 25, 2025)
8fb6d98  feat: incremental progress on streaming  (mrkaye97, Sep 26, 2025)
81841ca  feat: more incremental streaming progress  (mrkaye97, Sep 26, 2025)
23c3d05  fix: use temporal-style stream handler for now  (mrkaye97, Oct 2, 2025)
bcf7689  fix: generic  (mrkaye97, Oct 2, 2025)
c6d1a8f  feat: temporal-ish event stream handler  (mrkaye97, Oct 2, 2025)
399ea63  fix: stream handler  (mrkaye97, Oct 2, 2025)
aa41ad3  chore: lint  (mrkaye97, Oct 2, 2025)
cc53c21  chore: lock  (mrkaye97, Oct 2, 2025)
6f2d52f  fix: appease the type checker  (mrkaye97, Oct 2, 2025)
ea13ab9  chore: add testcontainers dev dep  (mrkaye97, Oct 3, 2025)
240e332  chore: lockfile  (mrkaye97, Nov 13, 2025)
@@ -0,0 +1 @@
local
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/hatchet/__init__.py
@@ -0,0 +1,5 @@
from ._agent import HatchetAgent
from ._mcp_server import HatchetMCPServer
from ._model import HatchetModel

__all__ = ['HatchetAgent', 'HatchetModel', 'HatchetMCPServer']
671 changes: 671 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/hatchet/_agent.py

Large diffs are not rendered by default.
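Since the _agent.py diff is not rendered, here is a hypothetical usage sketch of the exported HatchetAgent, inferred from the toolset wrappers later in this diff and from the existing Temporal integration. The constructor arguments and the workflows attribute are assumptions, not confirmed by the rendered files.

# Hypothetical usage sketch: the HatchetAgent constructor arguments and the
# workflows attribute below are assumptions inferred from this diff and the
# Temporal integration, not confirmed by the rendered files.
from hatchet_sdk import Hatchet

from pydantic_ai import Agent
from pydantic_ai.durable_exec.hatchet import HatchetAgent

hatchet = Hatchet()
agent = Agent('openai:gpt-4o', instructions='Be concise.', name='durable_agent')

# Wrapping the agent is what turns model requests and tool calls into durable Hatchet tasks.
hatchet_agent = HatchetAgent(agent, hatchet=hatchet)  # assumed signature

# The wrapper exposes its Hatchet workflows so they can be registered on a worker
# (per the "return list of workflows so we don't need a cast" commit); the attribute name is assumed.
worker = hatchet.worker('pydantic-ai-worker', workflows=hatchet_agent.hatchet_workflows)
worker.start()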

114 changes: 114 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/hatchet/_function_toolset.py
@@ -0,0 +1,114 @@
from typing import Any

from hatchet_sdk import Context, Hatchet
from hatchet_sdk.runnables.workflow import Standalone
from pydantic import BaseModel, ConfigDict

from pydantic_ai.exceptions import UserError
from pydantic_ai.tools import AgentDepsT, RunContext
from pydantic_ai.toolsets import FunctionToolset, ToolsetTool

from ._mcp_server import CallToolInput
from ._run_context import HatchetRunContext
from ._toolset import HatchetWrapperToolset
from ._utils import TaskConfig


class ToolOutput(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

result: Any


class HatchetFunctionToolset(HatchetWrapperToolset[AgentDepsT]):
"""A wrapper for FunctionToolset that integrates with Hatchet, turning tool calls into Hatchet tasks."""

def __init__(
self,
wrapped: FunctionToolset[AgentDepsT],
*,
hatchet: Hatchet,
task_name_prefix: str,
task_config: TaskConfig,
deps_type: type[AgentDepsT],
run_context_type: type[HatchetRunContext[AgentDepsT]] = HatchetRunContext[AgentDepsT],
):
super().__init__(wrapped)
self._task_config = task_config
self._task_name_prefix = task_name_prefix
self._hatchet = hatchet
self._tool_tasks: dict[str, Standalone[CallToolInput[AgentDepsT], ToolOutput]] = {}
self.run_context_type = run_context_type

for tool_name in wrapped.tools.keys():
task_name = f'{task_name_prefix}__function_tool__{tool_name}'

def make_tool_task(current_tool_name: str):
@hatchet.task(
name=task_name,
description=self._task_config.description,
input_validator=CallToolInput[AgentDepsT],
version=self._task_config.version,
sticky=self._task_config.sticky,
default_priority=self._task_config.default_priority,
concurrency=self._task_config.concurrency,
schedule_timeout=self._task_config.schedule_timeout,
execution_timeout=self._task_config.execution_timeout,
retries=self._task_config.retries,
rate_limits=self._task_config.rate_limits,
desired_worker_labels=self._task_config.desired_worker_labels,
backoff_factor=self._task_config.backoff_factor,
backoff_max_seconds=self._task_config.backoff_max_seconds,
default_filters=self._task_config.default_filters,
)
async def tool_task(
input: CallToolInput[AgentDepsT],
_ctx: Context,
) -> ToolOutput:
run_context = self.run_context_type.deserialize_run_context(
input.serialized_run_context, deps=input.deps
)
tool = (await wrapped.get_tools(run_context))[current_tool_name]

result = await super(HatchetFunctionToolset, self).call_tool(
current_tool_name, input.tool_args, run_context, tool
)

return ToolOutput(result=result)

return tool_task

self._tool_tasks[tool_name] = make_tool_task(tool_name)

@property
def hatchet_tasks(self) -> list[Standalone[Any, Any]]:
"""Return the list of Hatchet tasks for this toolset."""
return list(self._tool_tasks.values())

async def call_tool(
self,
name: str,
tool_args: dict[str, Any],
ctx: RunContext[AgentDepsT],
tool: ToolsetTool[AgentDepsT],
) -> Any:
if name not in self._tool_tasks:
raise UserError(
f'Tool {name!r} not found in toolset {self.id!r}. '
'Removing or renaming tools during an agent run is not supported with Hatchet.'
)

tool_task: Standalone[CallToolInput[AgentDepsT], ToolOutput] = self._tool_tasks[name]
serialized_run_context = self.run_context_type.serialize_run_context(ctx)

output = await tool_task.aio_run(
CallToolInput(
name=name,
tool_args=tool_args,
tool_def=tool.tool_def,
serialized_run_context=serialized_run_context,
deps=ctx.deps,
)
)

return output.result
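One note on the make_tool_task factory above: routing each task definition through a factory that takes current_tool_name as an argument is what binds every generated Hatchet task to its own tool. A closure that read the loop variable directly would late-bind and leave every task pointing at the last tool. A minimal, Hatchet-free sketch of the same pattern, using plain functions:

from collections.abc import Callable


def build_handlers(tool_names: list[str]) -> dict[str, Callable[[], str]]:
    handlers: dict[str, Callable[[], str]] = {}

    for name in tool_names:

        def make_handler(current_name: str) -> Callable[[], str]:
            # current_name is fixed at factory-call time, so each handler keeps its own tool name.
            def handler() -> str:
                return f'calling {current_name}'

            return handler

        handlers[name] = make_handler(name)

    return handlers


handlers = build_handlers(['get_weather', 'get_time'])
assert handlers['get_weather']() == 'calling get_weather'
assert handlers['get_time']() == 'calling get_time'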
172 changes: 172 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/hatchet/_mcp_server.py
@@ -0,0 +1,172 @@
from abc import ABC
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from hatchet_sdk import Context, Hatchet
from hatchet_sdk.runnables.workflow import Standalone
from pydantic import BaseModel, ConfigDict

from pydantic_ai.tools import AgentDepsT, RunContext
from pydantic_ai.toolsets.abstract import (
ToolDefinition,
ToolsetTool,
)

from ._run_context import HatchetRunContext, SerializedHatchetRunContext
from ._toolset import HatchetWrapperToolset
from ._utils import TaskConfig

if TYPE_CHECKING:
from pydantic_ai.mcp import MCPServer, ToolResult

T = TypeVar('T')


class GetToolsInput(BaseModel, Generic[AgentDepsT]):
model_config = ConfigDict(arbitrary_types_allowed=True)

serialized_run_context: SerializedHatchetRunContext
deps: AgentDepsT


class CallToolInput(BaseModel, Generic[AgentDepsT]):
model_config = ConfigDict(arbitrary_types_allowed=True)

name: str
tool_args: dict[str, Any]
tool_def: ToolDefinition

serialized_run_context: SerializedHatchetRunContext
deps: AgentDepsT


class CallToolOutput(BaseModel, Generic[AgentDepsT]):
model_config = ConfigDict(arbitrary_types_allowed=True)

result: 'ToolResult'


class HatchetMCPServer(HatchetWrapperToolset[AgentDepsT], ABC):
"""A wrapper for MCPServer that integrates with Hatchet, turning call_tool and get_tools to Hatchet tasks."""

def __init__(
self,
wrapped: 'MCPServer',
*,
hatchet: Hatchet,
task_name_prefix: str,
task_config: TaskConfig,
deps_type: type[AgentDepsT],
run_context_type: type[HatchetRunContext[AgentDepsT]] = HatchetRunContext[AgentDepsT],
):
super().__init__(wrapped)
self._task_config = task_config
self._task_name_prefix = task_name_prefix
self._hatchet = hatchet
id_suffix = f'__{wrapped.id}' if wrapped.id else ''
self._name = f'{task_name_prefix}__mcp_server{id_suffix}'
self.run_context_type: type[HatchetRunContext[AgentDepsT]] = run_context_type

@hatchet.task(
name=f'{self._name}.get_tools',
description=self._task_config.description,
input_validator=GetToolsInput[AgentDepsT],
version=self._task_config.version,
sticky=self._task_config.sticky,
default_priority=self._task_config.default_priority,
concurrency=self._task_config.concurrency,
schedule_timeout=self._task_config.schedule_timeout,
execution_timeout=self._task_config.execution_timeout,
retries=self._task_config.retries,
rate_limits=self._task_config.rate_limits,
desired_worker_labels=self._task_config.desired_worker_labels,
backoff_factor=self._task_config.backoff_factor,
backoff_max_seconds=self._task_config.backoff_max_seconds,
default_filters=self._task_config.default_filters,
)
async def wrapped_get_tools_task(
input: GetToolsInput[AgentDepsT],
_ctx: Context,
) -> dict[str, ToolDefinition]:
run_context = self.run_context_type.deserialize_run_context(input.serialized_run_context, deps=input.deps)

            # ToolsetTool is not serializable because it holds a SchemaValidator (which is the same for every MCP tool,
            # so there's no need to send it over the wire each time), so we return only the ToolDefinitions and wrap
            # them in ToolsetTool outside of the task.
tools = await super(HatchetMCPServer, self).get_tools(run_context)

return {name: tool.tool_def for name, tool in tools.items()}

self.hatchet_wrapped_get_tools_task = wrapped_get_tools_task

@hatchet.task(
name=f'{self._name}.call_tool',
description=self._task_config.description,
input_validator=CallToolInput[AgentDepsT],
version=self._task_config.version,
sticky=self._task_config.sticky,
default_priority=self._task_config.default_priority,
concurrency=self._task_config.concurrency,
schedule_timeout=self._task_config.schedule_timeout,
execution_timeout=self._task_config.execution_timeout,
retries=self._task_config.retries,
rate_limits=self._task_config.rate_limits,
desired_worker_labels=self._task_config.desired_worker_labels,
backoff_factor=self._task_config.backoff_factor,
backoff_max_seconds=self._task_config.backoff_max_seconds,
default_filters=self._task_config.default_filters,
)
async def wrapped_call_tool_task(
input: CallToolInput[AgentDepsT],
_ctx: Context,
) -> CallToolOutput[AgentDepsT]:
run_context = self.run_context_type.deserialize_run_context(input.serialized_run_context, deps=input.deps)
tool = self.tool_for_tool_def(input.tool_def)

result = await super(HatchetMCPServer, self).call_tool(input.name, input.tool_args, run_context, tool)

return CallToolOutput[AgentDepsT](result=result)

self.hatchet_wrapped_call_tool_task = wrapped_call_tool_task

@property
def hatchet_tasks(self) -> list[Standalone[Any, Any]]:
"""Return the list of Hatchet tasks for this toolset."""
return [
self.hatchet_wrapped_get_tools_task,
self.hatchet_wrapped_call_tool_task,
]

    def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
        from pydantic_ai.mcp import MCPServer  # local import: the module-level import is TYPE_CHECKING-only

        assert isinstance(self.wrapped, MCPServer)
        return self.wrapped.tool_for_tool_def(tool_def)

async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
tool_defs = await self.hatchet_wrapped_get_tools_task.aio_run(
GetToolsInput(
serialized_run_context=serialized_run_context,
deps=ctx.deps,
)
)

return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.items()}

async def call_tool(
self,
name: str,
tool_args: dict[str, Any],
ctx: RunContext[AgentDepsT],
tool: ToolsetTool[AgentDepsT],
) -> 'ToolResult':
serialized_run_context = self.run_context_type.serialize_run_context(ctx)

wrapped_tool_output = await self.hatchet_wrapped_call_tool_task.aio_run(
CallToolInput(
name=name,
tool_args=tool_args,
tool_def=tool.tool_def,
serialized_run_context=serialized_run_context,
deps=ctx.deps,
)
)

return wrapped_tool_output.result
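For orientation, the wrapper above would normally be built by HatchetAgent rather than by hand, but since its constructor signature is visible in this file, a manual wiring sketch looks roughly like the following. The TaskConfig defaults, the deps type, and the MCP server command are illustrative assumptions.

# Illustrative wiring only: TaskConfig defaults, the deps type, and the MCP
# server command are assumptions; in practice HatchetAgent presumably builds
# this wrapper itself.
from hatchet_sdk import Hatchet

from pydantic_ai.durable_exec.hatchet import HatchetMCPServer
from pydantic_ai.durable_exec.hatchet._utils import TaskConfig  # module added in this PR
from pydantic_ai.mcp import MCPServerStdio

hatchet = Hatchet()
server = MCPServerStdio('python', args=['-m', 'my_mcp_server'])  # hypothetical MCP server command

durable_server = HatchetMCPServer(
    server,
    hatchet=hatchet,
    task_name_prefix='my_agent',
    task_config=TaskConfig(),  # assumes every TaskConfig field has a default
    deps_type=type(None),
)

# get_tools and call_tool now run as the 'my_agent__mcp_server__<server id>.get_tools' and
# '.call_tool' Hatchet tasks; only ToolDefinitions, tool args, and the serialized run context
# cross the task boundary, and ToolsetTool is rebuilt caller-side via tool_for_tool_def.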