Skip to content
Draft
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
080d2bc
chore: scaffold hatchet agent, model, and mcp server
mrkaye97 Sep 16, 2025
1aa689c
chore: add hatchet dep
mrkaye97 Sep 16, 2025
5e8dedf
feat: add utils for task config
mrkaye97 Sep 16, 2025
d3d818a
feat: first pass at model impl
mrkaye97 Sep 16, 2025
16f6e6a
feat: tool calling implementation for mcp server
mrkaye97 Sep 16, 2025
698dfdf
feat: implement request method for model
mrkaye97 Sep 16, 2025
db0d2ff
feat: more work on agent
mrkaye97 Sep 16, 2025
3ae9fcf
feat: agent run method impl
mrkaye97 Sep 16, 2025
937aa97
fix: allow arbitrary types in i/o models
mrkaye97 Sep 16, 2025
27036c7
fix: clean up types + workflow registration
mrkaye97 Sep 16, 2025
f993b69
feat: first pass at toolsets
mrkaye97 Sep 16, 2025
ee5e872
feat: hatchet run context
mrkaye97 Sep 18, 2025
fecdc7e
chore: ignore local files for testing
mrkaye97 Sep 18, 2025
052276f
chore: bump hatchet
mrkaye97 Sep 18, 2025
6476b40
feat: use hatchet run context in mcp server
mrkaye97 Sep 18, 2025
1fb3881
feat: stricter typing on run context
mrkaye97 Sep 18, 2025
e791cd5
fix: pass types around
mrkaye97 Sep 18, 2025
0626318
fix: make tool and tool def serializable a la temporal impl
mrkaye97 Sep 18, 2025
c89bc10
fix: comment
mrkaye97 Sep 18, 2025
9b7a2c3
feat: more work on making the hatchet function toolset impl more simi…
mrkaye97 Sep 18, 2025
d76c42c
fix: improve typing on tasks
mrkaye97 Sep 18, 2025
722f356
fix: use tasks for everything except the agent itself
mrkaye97 Sep 18, 2025
f093749
fix: tool naming
mrkaye97 Sep 18, 2025
10063c3
feat: add hatchet metadata on agent run
mrkaye97 Sep 18, 2025
fad04c8
feat: add run_sync method to the hatchet agent
mrkaye97 Sep 19, 2025
aff17c0
feat: add run_stream and iter methods
mrkaye97 Sep 19, 2025
756da33
fix: return list of workflows so we don't need a cast
mrkaye97 Sep 19, 2025
29e79cd
fix: add hacks around run_stream from inside of a task
mrkaye97 Sep 25, 2025
e012415
fix: start implementing event stream handler for agent properly
mrkaye97 Sep 25, 2025
2519554
fix: recursion
mrkaye97 Sep 25, 2025
68bb231
feat: streaming impl, part i
mrkaye97 Sep 25, 2025
84eac57
fix: streaming
mrkaye97 Sep 25, 2025
af7e924
hack: partially working streaming implementation
mrkaye97 Sep 25, 2025
7bfac33
feat: incremental progress on streaming
mrkaye97 Sep 26, 2025
4efad74
feat: more incremental streaming progress
mrkaye97 Sep 26, 2025
c77f051
fix: use temporal-style stream handler for now
mrkaye97 Oct 2, 2025
5a99248
fix: generic
mrkaye97 Oct 2, 2025
5b6a732
feat: temporal-ish event stream handler
mrkaye97 Oct 2, 2025
4286ddf
fix: stream handler
mrkaye97 Oct 2, 2025
fd02470
chore: lint
mrkaye97 Oct 2, 2025
29d3dae
Merge branch 'main' into mk/hatchet-durable-execution-backend
mrkaye97 Oct 2, 2025
5d8f971
chore: lock
mrkaye97 Oct 2, 2025
92e60ff
fix: appease the type checker
mrkaye97 Oct 2, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
local
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/hatchet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._agent import HatchetAgent
from ._mcp_server import HatchetMCPServer
from ._model import HatchetModel

__all__ = ['HatchetAgent', 'HatchetModel', 'HatchetMCPServer']
259 changes: 259 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/hatchet/_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
from __future__ import annotations

from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from typing import Any, Generic, overload
from uuid import uuid4

from hatchet_sdk import DurableContext, Hatchet, TriggerWorkflowOptions
from hatchet_sdk.runnables.workflow import BaseWorkflow
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Never

from pydantic_ai import (
messages as _messages,
models,
usage as _usage,
)
from pydantic_ai.agent import AbstractAgent, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent
from pydantic_ai.exceptions import UserError
from pydantic_ai.models import Model
from pydantic_ai.output import OutputDataT, OutputSpec
from pydantic_ai.settings import ModelSettings
from pydantic_ai.tools import (
AgentDepsT,
DeferredToolResults,
)
from pydantic_ai.toolsets import AbstractToolset

from ._model import HatchetModel
from ._run_context import HatchetRunContext
from ._utils import TaskConfig


class RunAgentInput(BaseModel, Generic[RunOutputDataT, AgentDepsT]):
    """Serializable payload for the Hatchet-wrapped agent run workflow.

    Mirrors the keyword arguments of `AbstractAgent.run` so a complete run
    request can be validated and shipped to a Hatchet durable task as a single
    input object.
    """

    # Several fields (e.g. `Model`, toolsets, the event stream handler) are not
    # Pydantic models, so arbitrary types must be permitted.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    user_prompt: str | Sequence[_messages.UserContent] | None = None
    output_type: OutputSpec[RunOutputDataT] | None = None
    message_history: list[_messages.ModelMessage] | None = None
    deferred_tool_results: DeferredToolResults | None = None
    model: models.Model | models.KnownModelName | str | None = None
    deps: AgentDepsT
    model_settings: ModelSettings | None = None
    usage_limits: _usage.UsageLimits | None = None
    usage: _usage.RunUsage | None = None
    infer_name: bool = True
    toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None
    event_stream_handler: EventStreamHandler[AgentDepsT] | None = None
    # Catch-all for deprecated keyword arguments forwarded to the wrapped agent.
    deprecated_kwargs: dict[str, Any] = Field(default_factory=dict)


class HatchetAgent(WrapperAgent[AgentDepsT, OutputDataT]):
    """An agent wrapper that executes runs as Hatchet durable workflows.

    Model requests, tool calls, and MCP server communication are offloaded to
    Hatchet tasks. The original (wrapped) agent remains usable as normal
    outside of the Hatchet workflow.
    """

    def __init__(
        self,
        wrapped: AbstractAgent[AgentDepsT, OutputDataT],
        hatchet: Hatchet,
        *,
        name: str | None = None,
        mcp_task_config: TaskConfig | None = None,
        model_task_config: TaskConfig | None = None,
        run_context_type: type[HatchetRunContext[AgentDepsT]] = HatchetRunContext[AgentDepsT],
    ):
        """Wrap an agent to enable it with Hatchet durable tasks, by automatically offloading model requests, tool calls, and MCP server communication to Hatchet tasks.

        After wrapping, the original agent can still be used as normal outside of the Hatchet workflow.

        Args:
            wrapped: The agent to wrap.
            hatchet: The Hatchet instance to use for creating tasks.
            name: Optional unique agent name to use in the Hatchet tasks' names. If not provided, the agent's `name` will be used.
            mcp_task_config: The base Hatchet task config to use for MCP server tasks. If no config is provided, use the default settings.
            model_task_config: The Hatchet task config to use for model request tasks. If no config is provided, use the default settings.
            run_context_type: The `HatchetRunContext` (sub)class that's used to serialize and deserialize the run context.

        Raises:
            UserError: If the agent has no `name`, or its `model` is not a `Model` instance.
        """
        super().__init__(wrapped)

        self._name = name or wrapped.name
        self._hatchet = hatchet

        if not self._name:
            raise UserError(
                "An agent needs to have a unique `name` in order to be used with Hatchet. The name will be used to identify the agent's workflows and tasks."
            )

        if not isinstance(wrapped.model, Model):
            raise UserError(
                'An agent needs to have a `model` in order to be used with Hatchet, it cannot be set at agent run time.'
            )

        # Wrap the model so that every model request runs as its own Hatchet task.
        self._model = HatchetModel(
            wrapped.model,
            task_name_prefix=self._name,
            task_config=model_task_config or TaskConfig(),
            hatchet=self._hatchet,
        )
        hatchet_agent_name = self._name
        self.run_context_type: type[HatchetRunContext[AgentDepsT]] = run_context_type

        def hatchetify_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]:
            # Imported lazily — presumably to avoid a circular import at module load time.
            from ._toolset import hatchetize_toolset

            return hatchetize_toolset(
                toolset,
                hatchet=hatchet,
                task_name_prefix=hatchet_agent_name,
                task_config=mcp_task_config or TaskConfig(),
                deps_type=self.deps_type,
                run_context_type=run_context_type,
            )

        self._toolsets = [toolset.visit_and_replace(hatchetify_toolset) for toolset in wrapped.toolsets]

        @hatchet.durable_task(name=f'{self._name}.run', input_validator=RunAgentInput[Any, Any])
        async def wrapped_run_workflow(
            input: RunAgentInput[RunOutputDataT, AgentDepsT],
            _ctx: DurableContext,
        ) -> AgentRunResult[Any]:
            # `super(WrapperAgent, self)` skips WrapperAgent's delegation in the MRO,
            # so the run executes with the Hatchet model/toolset overrides applied
            # rather than re-entering this wrapper.
            with self._hatchet_overrides():
                return await super(WrapperAgent, self).run(
                    input.user_prompt,
                    output_type=input.output_type,
                    message_history=input.message_history,
                    deferred_tool_results=input.deferred_tool_results,
                    model=input.model,
                    deps=input.deps,
                    model_settings=input.model_settings,
                    usage_limits=input.usage_limits,
                    usage=input.usage,
                    infer_name=input.infer_name,
                    toolsets=input.toolsets,
                    event_stream_handler=input.event_stream_handler,
                    **input.deprecated_kwargs,
                )

        self.hatchet_wrapped_run_workflow = wrapped_run_workflow

    @property
    def name(self) -> str | None:
        return self._name

    @name.setter
    def name(self, value: str | None) -> None:  # pragma: no cover
        raise UserError(
            'The agent name cannot be changed after creation. If you need to change the name, create a new agent.'
        )

    @property
    def model(self) -> Model:
        return self._model

    @property
    def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
        with self._hatchet_overrides():
            return super().toolsets

    @contextmanager
    def _hatchet_overrides(self) -> Iterator[None]:
        """Temporarily swap in the Hatchet-wrapped model and toolsets on the wrapped agent."""
        with super().override(model=self._model, toolsets=self._toolsets, tools=[]):
            yield

    @property
    def workflows(self) -> Sequence[BaseWorkflow[Any]]:
        """All Hatchet workflows owned by this agent: the run workflow, the model request task, and any toolset tool tasks."""
        workflows: list[BaseWorkflow[Any]] = [
            self.hatchet_wrapped_run_workflow,
            self._model.hatchet_wrapped_request_task,
        ]

        for toolset in self._toolsets:
            from ._toolset import HatchetWrapperToolset

            if isinstance(toolset, HatchetWrapperToolset):
                workflows.extend(toolset.hatchet_tasks)

        return workflows

    @overload
    async def run(
        self,
        user_prompt: str | Sequence[_messages.UserContent] | None = None,
        *,
        output_type: None = None,
        message_history: list[_messages.ModelMessage] | None = None,
        deferred_tool_results: DeferredToolResults | None = None,
        model: models.Model | models.KnownModelName | str | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.RunUsage | None = None,
        infer_name: bool = True,
        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
        event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
    ) -> AgentRunResult[OutputDataT]: ...

    @overload
    async def run(
        self,
        user_prompt: str | Sequence[_messages.UserContent] | None = None,
        *,
        output_type: OutputSpec[RunOutputDataT],
        message_history: list[_messages.ModelMessage] | None = None,
        deferred_tool_results: DeferredToolResults | None = None,
        model: models.Model | models.KnownModelName | str | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.RunUsage | None = None,
        infer_name: bool = True,
        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
        event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
    ) -> AgentRunResult[RunOutputDataT]: ...

    async def run(
        self,
        user_prompt: str | Sequence[_messages.UserContent] | None = None,
        *,
        output_type: OutputSpec[RunOutputDataT] | None = None,
        message_history: list[_messages.ModelMessage] | None = None,
        deferred_tool_results: DeferredToolResults | None = None,
        model: models.Model | models.KnownModelName | str | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.RunUsage | None = None,
        infer_name: bool = True,
        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
        event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
        **_deprecated_kwargs: Never,
    ) -> AgentRunResult[Any]:
        """Run the agent with a user prompt in async mode."""
        # NOTE: the docstring above was previously a dead string expression placed
        # after the first statement; it is now a real docstring.
        agent_run_id = uuid4()

        result = await self.hatchet_wrapped_run_workflow.aio_run(
            RunAgentInput[RunOutputDataT, AgentDepsT](
                user_prompt=user_prompt,
                output_type=output_type,
                message_history=message_history,
                deferred_tool_results=deferred_tool_results,
                model=model,
                deps=deps,
                model_settings=model_settings,
                usage_limits=usage_limits,
                usage=usage,
                infer_name=infer_name,
                toolsets=toolsets,
                event_stream_handler=event_stream_handler,
                deprecated_kwargs=_deprecated_kwargs,
            ),
            options=TriggerWorkflowOptions(
                additional_metadata={
                    'hatchet__agent_name': self._name,
                    'hatchet__agent_run_id': str(agent_run_id),
                }
            ),
        )

        # Hatchet may hand back the serialized (dict) form of the result; revalidate
        # it into an AgentRunResult in that case.
        if isinstance(result, dict):
            return TypeAdapter(AgentRunResult[Any]).validate_python(result)

        return result
114 changes: 114 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/hatchet/_function_toolset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from typing import Any

from hatchet_sdk import Context, Hatchet
from hatchet_sdk.runnables.workflow import Standalone
from pydantic import BaseModel, ConfigDict

from pydantic_ai.exceptions import UserError
from pydantic_ai.tools import AgentDepsT, RunContext
from pydantic_ai.toolsets import FunctionToolset, ToolsetTool

from ._mcp_server import CallToolInput
from ._run_context import HatchetRunContext
from ._toolset import HatchetWrapperToolset
from ._utils import TaskConfig


class ToolOutput(BaseModel):
    """Envelope for the value returned by a tool call executed as a Hatchet task."""

    # Tool results may be arbitrary Python objects, not just Pydantic-native types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    result: Any


class HatchetFunctionToolset(HatchetWrapperToolset[AgentDepsT]):
    """A wrapper for FunctionToolset that integrates with Hatchet, turning tool calls into Hatchet tasks."""

    def __init__(
        self,
        wrapped: FunctionToolset[AgentDepsT],
        *,
        hatchet: Hatchet,
        task_name_prefix: str,
        task_config: TaskConfig,
        deps_type: type[AgentDepsT],
        run_context_type: type[HatchetRunContext[AgentDepsT]] = HatchetRunContext[AgentDepsT],
    ):
        """Register one Hatchet task per tool in the wrapped toolset.

        Args:
            wrapped: The function toolset whose tools are turned into Hatchet tasks.
            hatchet: The Hatchet instance used to register the tasks.
            task_name_prefix: Prefix used to build each task's name.
            task_config: Base task configuration applied to every tool task.
            deps_type: The agent dependencies type (used for input validation typing).
            run_context_type: The `HatchetRunContext` (sub)class used to (de)serialize the run context.
        """
        super().__init__(wrapped)
        self._task_config = task_config
        self._task_name_prefix = task_name_prefix
        self._hatchet = hatchet
        # Maps tool name -> the registered Hatchet task that executes it.
        self._tool_tasks: dict[str, Standalone[CallToolInput[AgentDepsT], ToolOutput]] = {}
        self.run_context_type = run_context_type

        for tool_name in wrapped.tools.keys():
            task_name = f'{task_name_prefix}__function_tool__{tool_name}'

            # The factory function binds `current_tool_name` per iteration so the
            # async task body does not close over the mutating loop variable.
            # `task_name` is evaluated at decoration time, within this same iteration.
            def make_tool_task(current_tool_name: str):
                @hatchet.task(
                    name=task_name,
                    description=self._task_config.description,
                    input_validator=CallToolInput[AgentDepsT],
                    version=self._task_config.version,
                    sticky=self._task_config.sticky,
                    default_priority=self._task_config.default_priority,
                    concurrency=self._task_config.concurrency,
                    schedule_timeout=self._task_config.schedule_timeout,
                    execution_timeout=self._task_config.execution_timeout,
                    retries=self._task_config.retries,
                    rate_limits=self._task_config.rate_limits,
                    desired_worker_labels=self._task_config.desired_worker_labels,
                    backoff_factor=self._task_config.backoff_factor,
                    backoff_max_seconds=self._task_config.backoff_max_seconds,
                    default_filters=self._task_config.default_filters,
                )
                async def tool_task(
                    input: CallToolInput[AgentDepsT],
                    _ctx: Context,
                ) -> ToolOutput:
                    # Rebuild the run context on the worker from its serialized form.
                    run_context = self.run_context_type.deserialize_run_context(
                        input.serialized_run_context, deps=input.deps
                    )
                    tool = (await wrapped.get_tools(run_context))[current_tool_name]

                    # Explicit two-argument super() bypasses this class's `call_tool`
                    # override, delegating to the wrapper base so the task does not
                    # recursively re-enqueue itself as another Hatchet task.
                    result = await super(HatchetFunctionToolset, self).call_tool(
                        current_tool_name, input.tool_args, run_context, tool
                    )

                    return ToolOutput(result=result)

                return tool_task

            self._tool_tasks[tool_name] = make_tool_task(tool_name)

    @property
    def hatchet_tasks(self) -> list[Standalone[Any, Any]]:
        """Return the list of Hatchet tasks for this toolset."""
        return list(self._tool_tasks.values())

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[AgentDepsT],
        tool: ToolsetTool[AgentDepsT],
    ) -> Any:
        """Execute the named tool by running its registered Hatchet task.

        Args:
            name: The name of the tool to call.
            tool_args: The arguments to pass to the tool.
            ctx: The current run context, serialized and shipped with the task input.
            tool: The toolset tool definition for the call.

        Returns:
            The tool's result, unwrapped from the task's `ToolOutput` envelope.

        Raises:
            UserError: If no task was registered for `name` at construction time.
        """
        if name not in self._tool_tasks:
            raise UserError(
                f'Tool {name!r} not found in toolset {self.id!r}. '
                'Removing or renaming tools during an agent run is not supported with Hatchet.'
            )

        tool_task: Standalone[CallToolInput[AgentDepsT], ToolOutput] = self._tool_tasks[name]
        serialized_run_context = self.run_context_type.serialize_run_context(ctx)

        output = await tool_task.aio_run(
            CallToolInput(
                name=name,
                tool_args=tool_args,
                tool_def=tool.tool_def,
                serialized_run_context=serialized_run_context,
                deps=ctx.deps,
            )
        )

        return output.result
Loading