Skip to content

Commit a0a1294

Browse files
authored
Merge branch 'main' into patch-3323
2 parents c5b3b57 + e7b2f82 commit a0a1294

File tree

10 files changed

+527
-47
lines changed

10 files changed

+527
-47
lines changed

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 89 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555

5656

5757
try:
58-
from anthropic import NOT_GIVEN, APIStatusError, AsyncStream, omit as OMIT
58+
from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropicBedrock, AsyncStream, omit as OMIT
5959
from anthropic.types.beta import (
6060
BetaBase64PDFBlockParam,
6161
BetaBase64PDFSourceParam,
@@ -76,6 +76,7 @@
7676
BetaMemoryTool20250818Param,
7777
BetaMessage,
7878
BetaMessageParam,
79+
BetaMessageTokensCount,
7980
BetaMetadataParam,
8081
BetaPlainTextSourceParam,
8182
BetaRawContentBlockDeltaEvent,
@@ -239,6 +240,23 @@ async def request(
239240
model_response = self._process_response(response)
240241
return model_response
241242

243+
async def count_tokens(
244+
self,
245+
messages: list[ModelMessage],
246+
model_settings: ModelSettings | None,
247+
model_request_parameters: ModelRequestParameters,
248+
) -> usage.RequestUsage:
249+
model_settings, model_request_parameters = self.prepare_request(
250+
model_settings,
251+
model_request_parameters,
252+
)
253+
254+
response = await self._messages_count_tokens(
255+
messages, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
256+
)
257+
258+
return usage.RequestUsage(input_tokens=response.input_tokens)
259+
242260
@asynccontextmanager
243261
async def request_stream(
244262
self,
@@ -310,28 +328,12 @@ async def _messages_create(
310328
tools = self._get_tools(model_request_parameters, model_settings)
311329
tools, mcp_servers, beta_features = self._add_builtin_tools(tools, model_request_parameters)
312330

313-
tool_choice: BetaToolChoiceParam | None
314-
315-
if not tools:
316-
tool_choice = None
317-
else:
318-
if not model_request_parameters.allow_text_output:
319-
tool_choice = {'type': 'any'}
320-
else:
321-
tool_choice = {'type': 'auto'}
322-
323-
if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
324-
tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls
331+
tool_choice = self._infer_tool_choice(tools, model_settings, model_request_parameters)
325332

326333
system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings)
327334

328335
try:
329-
extra_headers = model_settings.get('extra_headers', {})
330-
extra_headers.setdefault('User-Agent', get_user_agent())
331-
if beta_features:
332-
if 'anthropic-beta' in extra_headers:
333-
beta_features.insert(0, extra_headers['anthropic-beta'])
334-
extra_headers['anthropic-beta'] = ','.join(beta_features)
336+
extra_headers = self._map_extra_headers(beta_features, model_settings)
335337

336338
return await self.client.beta.messages.create(
337339
max_tokens=model_settings.get('max_tokens', 4096),
@@ -356,6 +358,43 @@ async def _messages_create(
356358
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
357359
raise # pragma: lax no cover
358360

361+
async def _messages_count_tokens(
362+
self,
363+
messages: list[ModelMessage],
364+
model_settings: AnthropicModelSettings,
365+
model_request_parameters: ModelRequestParameters,
366+
) -> BetaMessageTokensCount:
367+
if isinstance(self.client, AsyncAnthropicBedrock):
368+
raise UserError('AsyncAnthropicBedrock client does not support `count_tokens` api.')
369+
370+
# standalone function to make it easier to override
371+
tools = self._get_tools(model_request_parameters, model_settings)
372+
tools, mcp_servers, beta_features = self._add_builtin_tools(tools, model_request_parameters)
373+
374+
tool_choice = self._infer_tool_choice(tools, model_settings, model_request_parameters)
375+
376+
system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings)
377+
378+
try:
379+
extra_headers = self._map_extra_headers(beta_features, model_settings)
380+
381+
return await self.client.beta.messages.count_tokens(
382+
system=system_prompt or OMIT,
383+
messages=anthropic_messages,
384+
model=self._model_name,
385+
tools=tools or OMIT,
386+
tool_choice=tool_choice or OMIT,
387+
mcp_servers=mcp_servers or OMIT,
388+
thinking=model_settings.get('anthropic_thinking', OMIT),
389+
timeout=model_settings.get('timeout', NOT_GIVEN),
390+
extra_headers=extra_headers,
391+
extra_body=model_settings.get('extra_body'),
392+
)
393+
except APIStatusError as e:
394+
if (status_code := e.status_code) >= 400:
395+
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
396+
raise # pragma: lax no cover
397+
359398
def _process_response(self, response: BetaMessage) -> ModelResponse:
360399
"""Process a non-streamed response, and prepare a message to return."""
361400
items: list[ModelResponsePart] = []
@@ -492,6 +531,37 @@ def _add_builtin_tools(
492531
)
493532
return tools, mcp_servers, beta_features
494533

534+
def _infer_tool_choice(
535+
self,
536+
tools: list[BetaToolUnionParam],
537+
model_settings: AnthropicModelSettings,
538+
model_request_parameters: ModelRequestParameters,
539+
) -> BetaToolChoiceParam | None:
540+
if not tools:
541+
return None
542+
else:
543+
tool_choice: BetaToolChoiceParam
544+
545+
if not model_request_parameters.allow_text_output:
546+
tool_choice = {'type': 'any'}
547+
else:
548+
tool_choice = {'type': 'auto'}
549+
550+
if 'parallel_tool_calls' in model_settings:
551+
tool_choice['disable_parallel_tool_use'] = not model_settings['parallel_tool_calls']
552+
553+
return tool_choice
554+
555+
def _map_extra_headers(self, beta_features: list[str], model_settings: AnthropicModelSettings) -> dict[str, str]:
556+
"""Apply beta_features to extra_headers in model_settings."""
557+
extra_headers = model_settings.get('extra_headers', {})
558+
extra_headers.setdefault('User-Agent', get_user_agent())
559+
if beta_features:
560+
if 'anthropic-beta' in extra_headers:
561+
beta_features.insert(0, extra_headers['anthropic-beta'])
562+
extra_headers['anthropic-beta'] = ','.join(beta_features)
563+
return extra_headers
564+
495565
async def _map_message( # noqa: C901
496566
self,
497567
messages: list[ModelMessage],

pydantic_ai_slim/pydantic_ai/run.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import TYPE_CHECKING, Any, Generic, Literal, overload
88

99
from pydantic_graph import BaseNode, End, GraphRunContext
10-
from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTask, JoinItem
10+
from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTaskRequest, JoinItem
1111
from pydantic_graph.beta.step import NodeStep
1212

1313
from . import (
@@ -181,7 +181,7 @@ async def __anext__(
181181
return self._task_to_node(task)
182182

183183
def _task_to_node(
184-
self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTask]
184+
self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTaskRequest]
185185
) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
186186
if isinstance(task, Sequence) and len(task) == 1:
187187
first_task = task[0]
@@ -197,8 +197,8 @@ def _task_to_node(
197197
return End(task.value)
198198
raise exceptions.AgentRunError(f'Unexpected node: {task}') # pragma: no cover
199199

200-
def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTask:
201-
return GraphTask(NodeStep(type(node)).id, inputs=node, fork_stack=())
200+
def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTaskRequest:
201+
return GraphTaskRequest(NodeStep(type(node)).id, inputs=node, fork_stack=())
202202

203203
async def next(
204204
self,

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
from __future__ import annotations as _annotations
99

1010
import sys
11-
import uuid
12-
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence
11+
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Sequence
1312
from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
1413
from dataclasses import dataclass, field
1514
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
@@ -22,7 +21,7 @@
2221
from pydantic_graph import exceptions
2322
from pydantic_graph._utils import AbstractSpan, get_traceparent, infer_obj_name, logfire_span
2423
from pydantic_graph.beta.decision import Decision
25-
from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, GraphRunID, JoinID, NodeID, NodeRunID, TaskID
24+
from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, JoinID, NodeID, NodeRunID, TaskID
2625
from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext
2726
from pydantic_graph.beta.node import (
2827
EndNode,
@@ -306,14 +305,13 @@ def __str__(self) -> str:
306305

307306

308307
@dataclass
309-
class GraphTask:
310-
"""A single task representing the execution of a node in the graph.
308+
class GraphTaskRequest:
309+
"""A request to run a task representing the execution of a node in the graph.
311310
312-
GraphTask encapsulates all the information needed to execute a specific
311+
GraphTaskRequest encapsulates all the information needed to execute a specific
313312
node, including its inputs and the fork context it's executing within.
314313
"""
315314

316-
# With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself
317315
node_id: NodeID
318316
"""The ID of the node to execute."""
319317

@@ -326,9 +324,26 @@ class GraphTask:
326324
Used by the GraphRun to decide when to proceed through joins.
327325
"""
328326

329-
task_id: TaskID = field(default_factory=lambda: TaskID(str(uuid.uuid4())), repr=False)
327+
328+
@dataclass
329+
class GraphTask(GraphTaskRequest):
330+
"""A task representing the execution of a node in the graph.
331+
332+
GraphTask encapsulates all the information needed to execute a specific
333+
node, including its inputs and the fork context it's executing within,
334+
and has a unique ID to identify the task within the graph run.
335+
"""
336+
337+
task_id: TaskID = field(repr=False)
330338
"""Unique identifier for this task."""
331339

340+
@staticmethod
341+
def from_request(request: GraphTaskRequest, get_task_id: Callable[[], TaskID]) -> GraphTask:
342+
# Don't call the get_task_id callable, this is already a task
343+
if isinstance(request, GraphTask):
344+
return request
345+
return GraphTask(request.node_id, request.inputs, request.fork_stack, get_task_id())
346+
332347

333348
class GraphRun(Generic[StateT, DepsT, OutputT]):
334349
"""A single execution instance of a graph.
@@ -378,12 +393,20 @@ def __init__(
378393
self._next: EndMarker[OutputT] | Sequence[GraphTask] | None = None
379394
"""The next item to be processed."""
380395

381-
run_id = GraphRunID(str(uuid.uuid4()))
382-
initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),)
383-
self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack)
396+
self._next_task_id = 0
397+
self._next_node_run_id = 0
398+
initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, self._get_next_node_run_id(), 0),)
399+
self._first_task = GraphTask(
400+
node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack, task_id=self._get_next_task_id()
401+
)
384402
self._iterator_task_group = create_task_group()
385403
self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](
386-
self.graph, self.state, self.deps, self._iterator_task_group
404+
self.graph,
405+
self.state,
406+
self.deps,
407+
self._iterator_task_group,
408+
self._get_next_node_run_id,
409+
self._get_next_task_id,
387410
)
388411
self._iterator = self._iterator_instance.iter_graph(self._first_task)
389412

@@ -449,7 +472,7 @@ async def __anext__(self) -> EndMarker[OutputT] | Sequence[GraphTask]:
449472
return self._next
450473

451474
async def next(
452-
self, value: EndMarker[OutputT] | Sequence[GraphTask] | None = None
475+
self, value: EndMarker[OutputT] | Sequence[GraphTaskRequest] | None = None
453476
) -> EndMarker[OutputT] | Sequence[GraphTask]:
454477
"""Advance the graph execution by one step.
455478
@@ -467,7 +490,10 @@ async def next(
467490
# if `next` is called before the `first_node` has run.
468491
await anext(self)
469492
if value is not None:
470-
self._next = value
493+
if isinstance(value, EndMarker):
494+
self._next = value
495+
else:
496+
self._next = [GraphTask.from_request(gtr, self._get_next_task_id) for gtr in value]
471497
return await anext(self)
472498

473499
@property
@@ -490,6 +516,16 @@ def output(self) -> OutputT | None:
490516
return self._next.value
491517
return None
492518

519+
def _get_next_task_id(self) -> TaskID:
520+
next_id = TaskID(f'task:{self._next_task_id}')
521+
self._next_task_id += 1
522+
return next_id
523+
524+
def _get_next_node_run_id(self) -> NodeRunID:
525+
next_id = NodeRunID(f'task:{self._next_node_run_id}')
526+
self._next_node_run_id += 1
527+
return next_id
528+
493529

494530
@dataclass
495531
class _GraphTaskAsyncIterable:
@@ -510,6 +546,8 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
510546
state: StateT
511547
deps: DepsT
512548
task_group: TaskGroup
549+
get_next_node_run_id: Callable[[], NodeRunID]
550+
get_next_task_id: Callable[[], TaskID]
513551

514552
cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
515553
active_tasks: dict[TaskID, GraphTask] = field(init=False)
@@ -522,6 +560,7 @@ def __post_init__(self):
522560
self.active_tasks = {}
523561
self.active_reducers = {}
524562
self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
563+
self._next_node_run_id = 1
525564

526565
async def iter_graph( # noqa C901
527566
self, first_task: GraphTask
@@ -782,12 +821,12 @@ def _handle_node(
782821
fork_stack: ForkStack,
783822
) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]:
784823
if isinstance(next_node, StepNode):
785-
return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)]
824+
return [GraphTask(next_node.step.id, next_node.inputs, fork_stack, self.get_next_task_id())]
786825
elif isinstance(next_node, JoinNode):
787826
return JoinItem(next_node.join.id, next_node.inputs, fork_stack)
788827
elif isinstance(next_node, BaseNode):
789828
node_step = NodeStep(next_node.__class__)
790-
return [GraphTask(node_step.id, next_node, fork_stack)]
829+
return [GraphTask(node_step.id, next_node, fork_stack, self.get_next_task_id())]
791830
elif isinstance(next_node, End):
792831
return EndMarker(next_node.data)
793832
else:
@@ -821,7 +860,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
821860
'These markers should be removed from paths during graph building'
822861
)
823862
if isinstance(item, DestinationMarker):
824-
return [GraphTask(item.destination_id, inputs, fork_stack)]
863+
return [GraphTask(item.destination_id, inputs, fork_stack, self.get_next_task_id())]
825864
elif isinstance(item, TransformMarker):
826865
inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs))
827866
return self._handle_path(path.next_path, inputs, fork_stack)
@@ -853,7 +892,7 @@ def _handle_fork_edges(
853892
) # this should have already been ensured during graph building
854893

855894
new_tasks: list[GraphTask] = []
856-
node_run_id = NodeRunID(str(uuid.uuid4()))
895+
node_run_id = self.get_next_node_run_id()
857896
if node.is_map:
858897
# If the map specifies a downstream join id, eagerly create a join state for it
859898
if (join_id := node.downstream_join_id) is not None:

pydantic_graph/pydantic_graph/beta/id_types.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@
2424
ForkID = NodeID
2525
"""Alias for NodeId when referring to fork nodes."""
2626

27-
GraphRunID = NewType('GraphRunID', str)
28-
"""Unique identifier for a complete graph execution run."""
29-
3027
TaskID = NewType('TaskID', str)
3128
"""Unique identifier for a task within the graph execution."""
3229

tests/graph/beta/test_graph_iteration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99

1010
from pydantic_graph.beta import GraphBuilder, StepContext
11-
from pydantic_graph.beta.graph import EndMarker, GraphTask
11+
from pydantic_graph.beta.graph import EndMarker, GraphTask, GraphTaskRequest
1212
from pydantic_graph.beta.id_types import NodeID
1313
from pydantic_graph.beta.join import reduce_list_append
1414

@@ -400,7 +400,7 @@ async def second_step(ctx: StepContext[IterState, None, int]) -> int:
400400
# Get the fork_stack from the EndMarker's source
401401
fork_stack = run.next_task[0].fork_stack if isinstance(run.next_task, list) else ()
402402

403-
new_task = GraphTask(
403+
new_task = GraphTaskRequest(
404404
node_id=NodeID('second_step'),
405405
inputs=event.value,
406406
fork_stack=fork_stack,

0 commit comments

Comments
 (0)