Draft
Changes from all commits (43 commits)
080d2bc  chore: scaffold hatchet agent, model, and mcp server (mrkaye97, Sep 16, 2025)
1aa689c  chore: add hatchet dep (mrkaye97, Sep 16, 2025)
5e8dedf  feat: add utils for task config (mrkaye97, Sep 16, 2025)
d3d818a  feat: first pass at model impl (mrkaye97, Sep 16, 2025)
16f6e6a  feat: tool calling implementation for mcp server (mrkaye97, Sep 16, 2025)
698dfdf  feat: implement request method for model (mrkaye97, Sep 16, 2025)
db0d2ff  feat: more work on agent (mrkaye97, Sep 16, 2025)
3ae9fcf  feat: agent run method impl (mrkaye97, Sep 16, 2025)
937aa97  fix: allow arbitrary types in i/o models (mrkaye97, Sep 16, 2025)
27036c7  fix: clean up types + workflow registration (mrkaye97, Sep 16, 2025)
f993b69  feat: first pass at toolsets (mrkaye97, Sep 16, 2025)
ee5e872  feat: hatchet run context (mrkaye97, Sep 18, 2025)
fecdc7e  chore: ignore local files for testing (mrkaye97, Sep 18, 2025)
052276f  chore: bump hatchet (mrkaye97, Sep 18, 2025)
6476b40  feat: use hatchet run context in mcp server (mrkaye97, Sep 18, 2025)
1fb3881  feat: stricter typing on run context (mrkaye97, Sep 18, 2025)
e791cd5  fix: pass types around (mrkaye97, Sep 18, 2025)
0626318  fix: make tool and tool def serializable a la temporal impl (mrkaye97, Sep 18, 2025)
c89bc10  fix: comment (mrkaye97, Sep 18, 2025)
9b7a2c3  feat: more work on making the hatchet function toolset impl more simi… (mrkaye97, Sep 18, 2025)
d76c42c  fix: improve typing on tasks (mrkaye97, Sep 18, 2025)
722f356  fix: use tasks for everything except the agent itself (mrkaye97, Sep 18, 2025)
f093749  fix: tool naming (mrkaye97, Sep 18, 2025)
10063c3  feat: add hatchet metadata on agent run (mrkaye97, Sep 18, 2025)
fad04c8  feat: add run_sync method to the hatchet agent (mrkaye97, Sep 19, 2025)
aff17c0  feat: add run_stream and iter methods (mrkaye97, Sep 19, 2025)
756da33  fix: return list of workflows so we don't need a cast (mrkaye97, Sep 19, 2025)
29e79cd  fix: add hacks around run_stream from inside of a task (mrkaye97, Sep 25, 2025)
e012415  fix: start implementing event stream handler for agent properly (mrkaye97, Sep 25, 2025)
2519554  fix: recursion (mrkaye97, Sep 25, 2025)
68bb231  feat: streaming impl, part i (mrkaye97, Sep 25, 2025)
84eac57  fix: streaming (mrkaye97, Sep 25, 2025)
af7e924  hack: partially working streaming implementation (mrkaye97, Sep 25, 2025)
7bfac33  feat: incremental progress on streaming (mrkaye97, Sep 26, 2025)
4efad74  feat: more incremental streaming progress (mrkaye97, Sep 26, 2025)
c77f051  fix: use temporal-style stream handler for now (mrkaye97, Oct 2, 2025)
5a99248  fix: generic (mrkaye97, Oct 2, 2025)
5b6a732  feat: temporal-ish event stream handler (mrkaye97, Oct 2, 2025)
4286ddf  fix: stream handler (mrkaye97, Oct 2, 2025)
fd02470  chore: lint (mrkaye97, Oct 2, 2025)
29d3dae  Merge branch 'main' into mk/hatchet-durable-execution-backend (mrkaye97, Oct 2, 2025)
5d8f971  chore: lock (mrkaye97, Oct 2, 2025)
92e60ff  fix: appease the type checker (mrkaye97, Oct 2, 2025)
@@ -0,0 +1 @@
local
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/hatchet/__init__.py
@@ -0,0 +1,5 @@
from ._agent import HatchetAgent
from ._mcp_server import HatchetMCPServer
from ._model import HatchetModel

__all__ = ['HatchetAgent', 'HatchetModel', 'HatchetMCPServer']
671 changes: 671 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/hatchet/_agent.py

Large diffs are not rendered by default.
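
The _agent.py diff is collapsed, so the overall wiring is easiest to see from the exports above and the toolset wrappers below. A minimal usage sketch, assuming HatchetAgent follows the same wrapper pattern as the Temporal implementation referenced in the commits; the constructor arguments, the workflows attribute, and the agent/worker names are assumptions for illustration, not taken from this diff:

# Hypothetical usage sketch. HatchetAgent is assumed to wrap an existing Agent
# so that its model requests, tool calls, and MCP calls run as Hatchet tasks.
from hatchet_sdk import Hatchet

from pydantic_ai import Agent
from pydantic_ai.durable_exec.hatchet import HatchetAgent

hatchet = Hatchet()

agent = Agent(
    'openai:gpt-4o',
    instructions='Be concise.',
    name='support_agent',  # a stable name is presumably needed for deterministic task naming
)

# Wrap the agent; the wrapper turns the model, function toolsets, and MCP
# servers it wraps into Hatchet tasks.
hatchet_agent = HatchetAgent(agent, hatchet=hatchet)

# Register the generated tasks/workflows on a worker. The `workflows`
# attribute is an assumption based on the "return list of workflows" commit.
worker = hatchet.worker('support-worker', workflows=hatchet_agent.workflows)

if __name__ == '__main__':
    worker.start()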

114 changes: 114 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/hatchet/_function_toolset.py
@@ -0,0 +1,114 @@
from typing import Any

from hatchet_sdk import Context, Hatchet
from hatchet_sdk.runnables.workflow import Standalone
from pydantic import BaseModel, ConfigDict

from pydantic_ai.exceptions import UserError
from pydantic_ai.tools import AgentDepsT, RunContext
from pydantic_ai.toolsets import FunctionToolset, ToolsetTool

from ._mcp_server import CallToolInput
from ._run_context import HatchetRunContext
from ._toolset import HatchetWrapperToolset
from ._utils import TaskConfig


class ToolOutput(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

result: Any


class HatchetFunctionToolset(HatchetWrapperToolset[AgentDepsT]):
"""A wrapper for FunctionToolset that integrates with Hatchet, turning tool calls into Hatchet tasks."""

def __init__(
self,
wrapped: FunctionToolset[AgentDepsT],
*,
hatchet: Hatchet,
task_name_prefix: str,
task_config: TaskConfig,
deps_type: type[AgentDepsT],
run_context_type: type[HatchetRunContext[AgentDepsT]] = HatchetRunContext[AgentDepsT],
):
super().__init__(wrapped)
self._task_config = task_config
self._task_name_prefix = task_name_prefix
self._hatchet = hatchet
self._tool_tasks: dict[str, Standalone[CallToolInput[AgentDepsT], ToolOutput]] = {}
self.run_context_type = run_context_type

for tool_name in wrapped.tools.keys():
task_name = f'{task_name_prefix}__function_tool__{tool_name}'

def make_tool_task(current_tool_name: str):
@hatchet.task(
name=task_name,
description=self._task_config.description,
input_validator=CallToolInput[AgentDepsT],
version=self._task_config.version,
sticky=self._task_config.sticky,
default_priority=self._task_config.default_priority,
concurrency=self._task_config.concurrency,
schedule_timeout=self._task_config.schedule_timeout,
execution_timeout=self._task_config.execution_timeout,
retries=self._task_config.retries,
rate_limits=self._task_config.rate_limits,
desired_worker_labels=self._task_config.desired_worker_labels,
backoff_factor=self._task_config.backoff_factor,
backoff_max_seconds=self._task_config.backoff_max_seconds,
default_filters=self._task_config.default_filters,
)
async def tool_task(
input: CallToolInput[AgentDepsT],
_ctx: Context,
) -> ToolOutput:
run_context = self.run_context_type.deserialize_run_context(
input.serialized_run_context, deps=input.deps
)
tool = (await wrapped.get_tools(run_context))[current_tool_name]

result = await super(HatchetFunctionToolset, self).call_tool(
current_tool_name, input.tool_args, run_context, tool
)

return ToolOutput(result=result)

return tool_task

self._tool_tasks[tool_name] = make_tool_task(tool_name)

@property
def hatchet_tasks(self) -> list[Standalone[Any, Any]]:
"""Return the list of Hatchet tasks for this toolset."""
return list(self._tool_tasks.values())

async def call_tool(
self,
name: str,
tool_args: dict[str, Any],
ctx: RunContext[AgentDepsT],
tool: ToolsetTool[AgentDepsT],
) -> Any:
if name not in self._tool_tasks:
raise UserError(
f'Tool {name!r} not found in toolset {self.id!r}. '
'Removing or renaming tools during an agent run is not supported with Hatchet.'
)

tool_task: Standalone[CallToolInput[AgentDepsT], ToolOutput] = self._tool_tasks[name]
serialized_run_context = self.run_context_type.serialize_run_context(ctx)

output = await tool_task.aio_run(
CallToolInput(
name=name,
tool_args=tool_args,
tool_def=tool.tool_def,
serialized_run_context=serialized_run_context,
deps=ctx.deps,
)
)

return output.result
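
The make_tool_task indirection above exists because the task body runs later, when the tool is actually called: a closure captures the variable, not its value, so without the factory every generated task would resolve to the last tool in the loop. A standalone sketch of the pitfall and the fix, in plain Python and not tied to this PR:

# Late-binding pitfall: closures capture variables, not values.
naive = {}
for tool_name in ('add', 'subtract'):
    naive[tool_name] = lambda: tool_name  # every entry sees the final loop value

# Factory fix (the make_tool_task pattern): pass the current value in as an
# argument so it is bound at definition time.
def make_handler(current_tool_name: str):
    def handler() -> str:
        return current_tool_name

    return handler

fixed = {name: make_handler(name) for name in ('add', 'subtract')}

assert naive['add']() == 'subtract'  # late binding: both entries return the last value
assert fixed['add']() == 'add'       # bound correctly per tool

The same binding is why the generated task body references current_tool_name rather than the loop variable tool_name.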
172 changes: 172 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/hatchet/_mcp_server.py
@@ -0,0 +1,172 @@
from abc import ABC
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from hatchet_sdk import Context, Hatchet
from hatchet_sdk.runnables.workflow import Standalone
from pydantic import BaseModel, ConfigDict

from pydantic_ai.tools import AgentDepsT, RunContext
from pydantic_ai.toolsets.abstract import (
ToolDefinition,
ToolsetTool,
)

from ._run_context import HatchetRunContext, SerializedHatchetRunContext
from ._toolset import HatchetWrapperToolset
from ._utils import TaskConfig

if TYPE_CHECKING:
from pydantic_ai.mcp import MCPServer, ToolResult

T = TypeVar('T')


class GetToolsInput(BaseModel, Generic[AgentDepsT]):
model_config = ConfigDict(arbitrary_types_allowed=True)

serialized_run_context: SerializedHatchetRunContext
deps: AgentDepsT


class CallToolInput(BaseModel, Generic[AgentDepsT]):
model_config = ConfigDict(arbitrary_types_allowed=True)

name: str
tool_args: dict[str, Any]
tool_def: ToolDefinition

serialized_run_context: SerializedHatchetRunContext
deps: AgentDepsT


class CallToolOutput(BaseModel, Generic[AgentDepsT]):
model_config = ConfigDict(arbitrary_types_allowed=True)

result: 'ToolResult'


class HatchetMCPServer(HatchetWrapperToolset[AgentDepsT], ABC):
"""A wrapper for MCPServer that integrates with Hatchet, turning call_tool and get_tools to Hatchet tasks."""

def __init__(
self,
wrapped: 'MCPServer',
*,
hatchet: Hatchet,
task_name_prefix: str,
task_config: TaskConfig,
deps_type: type[AgentDepsT],
run_context_type: type[HatchetRunContext[AgentDepsT]] = HatchetRunContext[AgentDepsT],
):
super().__init__(wrapped)
self._task_config = task_config
self._task_name_prefix = task_name_prefix
self._hatchet = hatchet
id_suffix = f'__{wrapped.id}' if wrapped.id else ''
self._name = f'{task_name_prefix}__mcp_server{id_suffix}'
self.run_context_type: type[HatchetRunContext[AgentDepsT]] = run_context_type

@hatchet.task(
name=f'{self._name}.get_tools',
description=self._task_config.description,
input_validator=GetToolsInput[AgentDepsT],
version=self._task_config.version,
sticky=self._task_config.sticky,
default_priority=self._task_config.default_priority,
concurrency=self._task_config.concurrency,
schedule_timeout=self._task_config.schedule_timeout,
execution_timeout=self._task_config.execution_timeout,
retries=self._task_config.retries,
rate_limits=self._task_config.rate_limits,
desired_worker_labels=self._task_config.desired_worker_labels,
backoff_factor=self._task_config.backoff_factor,
backoff_max_seconds=self._task_config.backoff_max_seconds,
default_filters=self._task_config.default_filters,
)
async def wrapped_get_tools_task(
input: GetToolsInput[AgentDepsT],
_ctx: Context,
) -> dict[str, ToolDefinition]:
run_context = self.run_context_type.deserialize_run_context(input.serialized_run_context, deps=input.deps)

            # ToolsetTool is not serializable because it holds a SchemaValidator (which is the same for every MCP
            # tool anyway, so there is no need to send it over the wire each time). We return only the
            # ToolDefinitions and wrap them in ToolsetTool outside of the task.
tools = await super(HatchetMCPServer, self).get_tools(run_context)

return {name: tool.tool_def for name, tool in tools.items()}

self.hatchet_wrapped_get_tools_task = wrapped_get_tools_task

@hatchet.task(
name=f'{self._name}.call_tool',
description=self._task_config.description,
input_validator=CallToolInput[AgentDepsT],
version=self._task_config.version,
sticky=self._task_config.sticky,
default_priority=self._task_config.default_priority,
concurrency=self._task_config.concurrency,
schedule_timeout=self._task_config.schedule_timeout,
execution_timeout=self._task_config.execution_timeout,
retries=self._task_config.retries,
rate_limits=self._task_config.rate_limits,
desired_worker_labels=self._task_config.desired_worker_labels,
backoff_factor=self._task_config.backoff_factor,
backoff_max_seconds=self._task_config.backoff_max_seconds,
default_filters=self._task_config.default_filters,
)
async def wrapped_call_tool_task(
input: CallToolInput[AgentDepsT],
_ctx: Context,
) -> CallToolOutput[AgentDepsT]:
run_context = self.run_context_type.deserialize_run_context(input.serialized_run_context, deps=input.deps)
tool = self.tool_for_tool_def(input.tool_def)

result = await super(HatchetMCPServer, self).call_tool(input.name, input.tool_args, run_context, tool)

return CallToolOutput[AgentDepsT](result=result)

self.hatchet_wrapped_call_tool_task = wrapped_call_tool_task

@property
def hatchet_tasks(self) -> list[Standalone[Any, Any]]:
"""Return the list of Hatchet tasks for this toolset."""
return [
self.hatchet_wrapped_get_tools_task,
self.hatchet_wrapped_call_tool_task,
]

    def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
        from pydantic_ai.mcp import MCPServer  # runtime import: the module-level import is TYPE_CHECKING-only

        assert isinstance(self.wrapped, MCPServer)
        return self.wrapped.tool_for_tool_def(tool_def)

async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
tool_defs = await self.hatchet_wrapped_get_tools_task.aio_run(
GetToolsInput(
serialized_run_context=serialized_run_context,
deps=ctx.deps,
)
)

return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.items()}

async def call_tool(
self,
name: str,
tool_args: dict[str, Any],
ctx: RunContext[AgentDepsT],
tool: ToolsetTool[AgentDepsT],
) -> 'ToolResult':
serialized_run_context = self.run_context_type.serialize_run_context(ctx)

wrapped_tool_output = await self.hatchet_wrapped_call_tool_task.aio_run(
CallToolInput(
name=name,
tool_args=tool_args,
tool_def=tool.tool_def,
serialized_run_context=serialized_run_context,
deps=ctx.deps,
)
)

return wrapped_tool_output.result
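
Both wrappers follow the same wire-format rule: only plain, serializable data crosses the Hatchet task boundary (the input models, ToolDefinition, the serialized run context), while objects holding live state (a ToolsetTool with its SchemaValidator, the deps-bound RunContext) are rebuilt on each side. A small sketch of why ToolDefinition is safe to ship while ToolsetTool is not; the weather tool below is made up for illustration:

# Sketch: ToolDefinition round-trips through JSON, which is why get_tools
# returns ToolDefinitions and re-wraps them into ToolsetTool locally via
# tool_for_tool_def.
from pydantic import TypeAdapter

from pydantic_ai.toolsets.abstract import ToolDefinition

adapter = TypeAdapter(ToolDefinition)

tool_def = ToolDefinition(
    name='get_weather',
    description='Look up the current weather for a city.',
    parameters_json_schema={
        'type': 'object',
        'properties': {'city': {'type': 'string'}},
        'required': ['city'],
    },
)

payload = adapter.dump_json(tool_def)      # crosses the task boundary as JSON
restored = adapter.validate_json(payload)  # reconstructed inside the task

assert restored == tool_def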