Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions pydantic_evals/pydantic_evals/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,20 @@

import asyncio
import inspect
from collections.abc import Awaitable, Callable, Sequence
import warnings
from collections.abc import Awaitable, Callable, Generator, Sequence
from contextlib import contextmanager
from functools import partial
from typing import Any, TypeVar
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar

import anyio
import logfire_api
from typing_extensions import ParamSpec, TypeIs

_logfire = logfire_api.Logfire(otel_scope='pydantic-evals')
logfire_api.add_non_user_code_prefix(Path(__file__).parent.absolute())


class Unset:
"""A singleton to represent an unset value.
Expand Down Expand Up @@ -101,3 +108,28 @@ async def _run_task(tsk: Callable[[], Awaitable[T]], index: int) -> None:
tg.start_soon(_run_task, task, i)

return results


try:
from logfire._internal.config import (
LogfireNotConfiguredWarning, # pyright: ignore[reportAssignmentType,reportPrivateImportUsage]
)
# TODO: Remove this `pragma: no cover` once we test evals without pydantic-ai (which includes logfire)
except ImportError: # pragma: no cover

class LogfireNotConfiguredWarning(UserWarning):
pass


if TYPE_CHECKING:
logfire_span = _logfire.span
else:

@contextmanager
def logfire_span(*args: Any, **kwargs: Any) -> Generator[logfire_api.LogfireSpan, None, None]:
"""Create a Logfire span without warning if logfire is not configured."""
# TODO: Remove once Logfire has the ability to suppress this warning from non-user code
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=LogfireNotConfiguredWarning)
with _logfire.span(*args, **kwargs) as span:
yield span
8 changes: 4 additions & 4 deletions pydantic_evals/pydantic_evals/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from pydantic_evals._utils import get_event_loop

from ._utils import get_unwrapped_function_name, task_group_gather
from ._utils import get_unwrapped_function_name, logfire_span, task_group_gather
from .evaluators import EvaluationResult, Evaluator
from .evaluators._run_evaluator import run_evaluator
from .evaluators.common import DEFAULT_EVALUATORS
Expand Down Expand Up @@ -283,7 +283,7 @@ async def evaluate(
limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack()

with (
_logfire.span('evaluate {name}', name=name, n_cases=len(self.cases)) as eval_span,
logfire_span('evaluate {name}', name=name, n_cases=len(self.cases)) as eval_span,
progress_bar or nullcontext(),
):
task_id = progress_bar.add_task(f'Evaluating {name}', total=total_cases) if progress_bar else None
Expand Down Expand Up @@ -858,7 +858,7 @@ async def _run_once():
token = _CURRENT_TASK_RUN.set(task_run_)
try:
with (
_logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span,
logfire_span('execute {task}', task=get_unwrapped_function_name(task)) as task_span,
context_subtree() as span_tree_,
):
t0 = time.perf_counter()
Expand Down Expand Up @@ -933,7 +933,7 @@ async def _run_task_and_evaluators(
trace_id: str | None = None
span_id: str | None = None
try:
with _logfire.span(
with logfire_span(
'case: {case_name}',
task_name=get_unwrapped_function_name(task),
case_name=report_case_name,
Expand Down
8 changes: 3 additions & 5 deletions pydantic_evals/pydantic_evals/evaluators/_run_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@

import traceback
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any

import logfire_api
from pydantic import (
TypeAdapter,
ValidationError,
)
from typing_extensions import TypeVar

from pydantic_evals._utils import logfire_span

from .context import EvaluatorContext
from .evaluator import (
EvaluationReason,
Expand All @@ -25,8 +25,6 @@
if TYPE_CHECKING:
from pydantic_ai.retries import RetryConfig

_logfire = logfire_api.Logfire(otel_scope='pydantic-evals')
logfire_api.add_non_user_code_prefix(Path(__file__).parent.absolute())

InputsT = TypeVar('InputsT', default=Any, contravariant=True)
OutputT = TypeVar('OutputT', default=Any, contravariant=True)
Expand Down Expand Up @@ -62,7 +60,7 @@ async def run_evaluator(
evaluate = tenacity_retry(**retry)(evaluate)

try:
with _logfire.span(
with logfire_span(
'evaluator: {evaluator_name}',
evaluator_name=evaluator.get_default_evaluation_name(),
):
Expand Down
31 changes: 29 additions & 2 deletions pydantic_graph/pydantic_graph/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@

import asyncio
import types
from collections.abc import Callable
import warnings
from collections.abc import Callable, Generator
from contextlib import contextmanager
from functools import partial
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, get_args, get_origin

from logfire_api import LogfireSpan
from logfire_api import Logfire, LogfireSpan
from typing_extensions import ParamSpec, TypeIs
from typing_inspection import typing_objects
from typing_inspection.introspection import is_union_origin

if TYPE_CHECKING:
from opentelemetry.trace import Span

_logfire = Logfire(otel_scope='pydantic-graph')

AbstractSpan: TypeAlias = 'LogfireSpan | Span'

Expand Down Expand Up @@ -136,3 +139,27 @@ async def run_in_executor(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.k
return await asyncio.get_running_loop().run_in_executor(None, partial(func, *args, **kwargs))
else:
return await asyncio.get_running_loop().run_in_executor(None, func, *args) # type: ignore


try:
from logfire._internal.config import (
LogfireNotConfiguredWarning, # pyright: ignore[reportAssignmentType,reportPrivateImportUsage]
)
except ImportError:

class LogfireNotConfiguredWarning(UserWarning):
pass


if TYPE_CHECKING:
logfire_span = _logfire.span
else:

@contextmanager
def logfire_span(*args: Any, **kwargs: Any) -> Generator[LogfireSpan, None, None]:
"""Create a Logfire span without warning if logfire is not configured."""
# TODO: Remove once Logfire has the ability to suppress this warning from non-user code
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=LogfireNotConfiguredWarning)
with _logfire.span(*args, **kwargs) as span:
yield span
11 changes: 4 additions & 7 deletions pydantic_graph/pydantic_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,17 @@
from pathlib import Path
from typing import Any, Generic, cast, overload

import logfire_api
import typing_extensions
from typing_inspection import typing_objects

from . import _utils, exceptions, mermaid
from ._utils import AbstractSpan, get_traceparent
from ._utils import AbstractSpan, get_traceparent, logfire_span
from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT, StateT
from .persistence import BaseStatePersistence
from .persistence.in_mem import SimpleStatePersistence

__all__ = 'Graph', 'GraphRun', 'GraphRunResult'

_logfire = logfire_api.Logfire(otel_scope='pydantic-graph')


@dataclass(init=False)
class Graph(Generic[StateT, DepsT, RunEndT]):
Expand Down Expand Up @@ -242,7 +239,7 @@ async def iter(
entered_span: AbstractSpan | None = None
if span is None:
if self.auto_instrument:
entered_span = stack.enter_context(logfire_api.span('run graph {graph.name}', graph=self))
entered_span = stack.enter_context(logfire_span('run graph {graph.name}', graph=self))
else:
entered_span = stack.enter_context(span)
traceparent = None if entered_span is None else get_traceparent(entered_span)
Expand Down Expand Up @@ -291,7 +288,7 @@ async def iter_from_persistence(
snapshot.node.set_snapshot_id(snapshot.id)

if self.auto_instrument and span is None: # pragma: no branch
span = logfire_api.span('run graph {graph.name}', graph=self)
span = logfire_span('run graph {graph.name}', graph=self)

with ExitStack() as stack:
entered_span = None if span is None else stack.enter_context(span)
Expand Down Expand Up @@ -727,7 +724,7 @@ async def main():

with ExitStack() as stack:
if self.graph.auto_instrument:
stack.enter_context(_logfire.span('run node {node_id}', node_id=node_id, node=node))
stack.enter_context(logfire_span('run node {node_id}', node_id=node_id, node=node))

async with self.persistence.record_run(node_snapshot_id):
ctx = GraphRunContext(state=self.state, deps=self.deps)
Expand Down