diff --git a/pydantic_evals/pydantic_evals/_utils.py b/pydantic_evals/pydantic_evals/_utils.py index 21dfbff790..267045d7d1 100644 --- a/pydantic_evals/pydantic_evals/_utils.py +++ b/pydantic_evals/pydantic_evals/_utils.py @@ -2,13 +2,20 @@ import asyncio import inspect -from collections.abc import Awaitable, Callable, Sequence +import warnings +from collections.abc import Awaitable, Callable, Generator, Sequence +from contextlib import contextmanager from functools import partial -from typing import Any, TypeVar +from pathlib import Path +from typing import TYPE_CHECKING, Any, TypeVar import anyio +import logfire_api from typing_extensions import ParamSpec, TypeIs +_logfire = logfire_api.Logfire(otel_scope='pydantic-evals') +logfire_api.add_non_user_code_prefix(Path(__file__).parent.absolute()) + class Unset: """A singleton to represent an unset value. @@ -101,3 +108,28 @@ async def _run_task(tsk: Callable[[], Awaitable[T]], index: int) -> None: tg.start_soon(_run_task, task, i) return results + + +try: + from logfire._internal.config import ( + LogfireNotConfiguredWarning, # pyright: ignore[reportAssignmentType,reportPrivateImportUsage] + ) +# TODO: Remove this `pragma: no cover` once we test evals without pydantic-ai (which includes logfire) +except ImportError: # pragma: no cover + + class LogfireNotConfiguredWarning(UserWarning): + pass + + +if TYPE_CHECKING: + logfire_span = _logfire.span +else: + + @contextmanager + def logfire_span(*args: Any, **kwargs: Any) -> Generator[logfire_api.LogfireSpan, None, None]: + """Create a Logfire span without warning if logfire is not configured.""" + # TODO: Remove once Logfire has the ability to suppress this warning from non-user code + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=LogfireNotConfiguredWarning) + with _logfire.span(*args, **kwargs) as span: + yield span diff --git a/pydantic_evals/pydantic_evals/dataset.py b/pydantic_evals/pydantic_evals/dataset.py index 80663b4f8d..dc2fb3ef44 100644 --- a/pydantic_evals/pydantic_evals/dataset.py +++ b/pydantic_evals/pydantic_evals/dataset.py @@ -36,7 +36,7 @@ from pydantic_evals._utils import get_event_loop -from ._utils import get_unwrapped_function_name, task_group_gather +from ._utils import get_unwrapped_function_name, logfire_span, task_group_gather from .evaluators import EvaluationResult, Evaluator from .evaluators._run_evaluator import run_evaluator from .evaluators.common import DEFAULT_EVALUATORS @@ -283,7 +283,7 @@ async def evaluate( limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack() with ( - _logfire.span('evaluate {name}', name=name, n_cases=len(self.cases)) as eval_span, + logfire_span('evaluate {name}', name=name, n_cases=len(self.cases)) as eval_span, progress_bar or nullcontext(), ): task_id = progress_bar.add_task(f'Evaluating {name}', total=total_cases) if progress_bar else None @@ -858,7 +858,7 @@ async def _run_once(): token = _CURRENT_TASK_RUN.set(task_run_) try: with ( - _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span, + logfire_span('execute {task}', task=get_unwrapped_function_name(task)) as task_span, context_subtree() as span_tree_, ): t0 = time.perf_counter() @@ -933,7 +933,7 @@ async def _run_task_and_evaluators( trace_id: str | None = None span_id: str | None = None try: - with _logfire.span( + with logfire_span( 'case: {case_name}', task_name=get_unwrapped_function_name(task), case_name=report_case_name, diff --git a/pydantic_evals/pydantic_evals/evaluators/_run_evaluator.py b/pydantic_evals/pydantic_evals/evaluators/_run_evaluator.py index be2f803aee..c553ee6174 100644 --- a/pydantic_evals/pydantic_evals/evaluators/_run_evaluator.py +++ b/pydantic_evals/pydantic_evals/evaluators/_run_evaluator.py @@ -2,16 +2,16 @@ import traceback from collections.abc import Mapping -from pathlib import Path from typing import TYPE_CHECKING, Any -import logfire_api from pydantic import ( TypeAdapter, ValidationError, ) from typing_extensions import TypeVar +from pydantic_evals._utils import logfire_span + from .context import EvaluatorContext from .evaluator import ( EvaluationReason, @@ -25,8 +25,6 @@ if TYPE_CHECKING: from pydantic_ai.retries import RetryConfig -_logfire = logfire_api.Logfire(otel_scope='pydantic-evals') -logfire_api.add_non_user_code_prefix(Path(__file__).parent.absolute()) InputsT = TypeVar('InputsT', default=Any, contravariant=True) OutputT = TypeVar('OutputT', default=Any, contravariant=True) @@ -62,7 +60,7 @@ async def run_evaluator( evaluate = tenacity_retry(**retry)(evaluate) try: - with _logfire.span( + with logfire_span( 'evaluator: {evaluator_name}', evaluator_name=evaluator.get_default_evaluation_name(), ): diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index 1b18a9cf3a..e725862fcf 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -2,11 +2,13 @@ import asyncio import types -from collections.abc import Callable +import warnings +from collections.abc import Callable, Generator +from contextlib import contextmanager from functools import partial from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, get_args, get_origin -from logfire_api import LogfireSpan +from logfire_api import Logfire, LogfireSpan from typing_extensions import ParamSpec, TypeIs from typing_inspection import typing_objects from typing_inspection.introspection import is_union_origin @@ -14,6 +16,7 @@ if TYPE_CHECKING: from opentelemetry.trace import Span +_logfire = Logfire(otel_scope='pydantic-graph') AbstractSpan: TypeAlias = 'LogfireSpan | Span' @@ -136,3 +139,27 @@ async def run_in_executor(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.k return await asyncio.get_running_loop().run_in_executor(None, partial(func, *args, **kwargs)) else: return await asyncio.get_running_loop().run_in_executor(None, func, *args) # type: ignore + + +try: + from logfire._internal.config import ( + LogfireNotConfiguredWarning, # pyright: ignore[reportAssignmentType,reportPrivateImportUsage] + ) +except ImportError: + + class LogfireNotConfiguredWarning(UserWarning): + pass + + +if TYPE_CHECKING: + logfire_span = _logfire.span +else: + + @contextmanager + def logfire_span(*args: Any, **kwargs: Any) -> Generator[LogfireSpan, None, None]: + """Create a Logfire span without warning if logfire is not configured.""" + # TODO: Remove once Logfire has the ability to suppress this warning from non-user code + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=LogfireNotConfiguredWarning) + with _logfire.span(*args, **kwargs) as span: + yield span diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 6dfa8ab4c7..1f7795e743 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -9,20 +9,17 @@ from pathlib import Path from typing import Any, Generic, cast, overload -import logfire_api import typing_extensions from typing_inspection import typing_objects from . import _utils, exceptions, mermaid -from ._utils import AbstractSpan, get_traceparent +from ._utils import AbstractSpan, get_traceparent, logfire_span from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT, StateT from .persistence import BaseStatePersistence from .persistence.in_mem import SimpleStatePersistence __all__ = 'Graph', 'GraphRun', 'GraphRunResult' -_logfire = logfire_api.Logfire(otel_scope='pydantic-graph') - @dataclass(init=False) class Graph(Generic[StateT, DepsT, RunEndT]): @@ -242,7 +239,7 @@ async def iter( entered_span: AbstractSpan | None = None if span is None: if self.auto_instrument: - entered_span = stack.enter_context(logfire_api.span('run graph {graph.name}', graph=self)) + entered_span = stack.enter_context(logfire_span('run graph {graph.name}', graph=self)) else: entered_span = stack.enter_context(span) traceparent = None if entered_span is None else get_traceparent(entered_span) @@ -291,7 +288,7 @@ async def iter_from_persistence( snapshot.node.set_snapshot_id(snapshot.id) if self.auto_instrument and span is None: # pragma: no branch - span = logfire_api.span('run graph {graph.name}', graph=self) + span = logfire_span('run graph {graph.name}', graph=self) with ExitStack() as stack: entered_span = None if span is None else stack.enter_context(span) @@ -727,7 +724,7 @@ async def main(): with ExitStack() as stack: if self.graph.auto_instrument: - stack.enter_context(_logfire.span('run node {node_id}', node_id=node_id, node=node)) + stack.enter_context(logfire_span('run node {node_id}', node_id=node_id, node=node)) async with self.persistence.record_run(node_snapshot_id): ctx = GraphRunContext(state=self.state, deps=self.deps)