diff --git a/sdk/agenta/sdk/decorators/tracing.py b/sdk/agenta/sdk/decorators/tracing.py index e94dd7f3f1..5e1e45ebe1 100644 --- a/sdk/agenta/sdk/decorators/tracing.py +++ b/sdk/agenta/sdk/decorators/tracing.py @@ -2,6 +2,8 @@ from typing import Callable, Optional, Any, Dict, List, Union +import warnings + from opentelemetry import context as otel_context from opentelemetry.context import attach, detach @@ -34,6 +36,37 @@ log = get_module_logger(__name__) +_PREINIT_INSTRUMENTATION_WARNING_EMITTED = False + + +def _is_tracing_initialized() -> bool: + singleton = getattr(ag, "DEFAULT_AGENTA_SINGLETON_INSTANCE", None) + tracing = getattr(singleton, "tracing", None) if singleton is not None else None + + return bool( + tracing is not None + and getattr(tracing, "tracer_provider", None) is not None + and getattr(tracing, "tracer", None) is not None + ) + + +def _warn_instrumentation_before_init_once(handler_name: str) -> None: + global _PREINIT_INSTRUMENTATION_WARNING_EMITTED # pylint: disable=global-statement + + if _PREINIT_INSTRUMENTATION_WARNING_EMITTED: + return + + _PREINIT_INSTRUMENTATION_WARNING_EMITTED = True + + message = ( + "Agenta SDK warning: an instrumented function was called before `ag.init()`.\n" + f"- Function: {handler_name}\n" + "- Impact: this call will run without Agenta tracing/export.\n" + "- Fix: call `ag.init()` once at startup (before invoking any `@ag.instrument()` / `@ag.workflow` code)." + ) + + warnings.warn(message, RuntimeWarning, stacklevel=4) + def _has_instrument(handler: Callable[..., Any]) -> bool: return bool(getattr(handler, "__has_instrument__", False)) @@ -80,12 +113,18 @@ def __call__(self, handler: Callable[..., Any]): is_sync_generator = isgeneratorfunction(handler) is_async_generator = isasyncgenfunction(handler) + handler_name = f"{getattr(handler, '__module__', '')}.{getattr(handler, '__qualname__', getattr(handler, '__name__', ''))}" + # ---- ASYNC GENERATOR ---- if is_async_generator: @wraps(handler) def astream_wrapper(*args, **kwargs): with tracing_context_manager(context=TracingContext.get()): + if not _is_tracing_initialized(): + _warn_instrumentation_before_init_once(handler_name) + return handler(*args, **kwargs) + # debug_otel_context("[BEFORE STREAM] [BEFORE SETUP]") captured_ctx = otel_context.get_current() @@ -154,6 +193,10 @@ async def wrapped_generator(): @wraps(handler) def stream_wrapper(*args, **kwargs): with tracing_context_manager(context=TracingContext.get()): + if not _is_tracing_initialized(): + _warn_instrumentation_before_init_once(handler_name) + return handler(*args, **kwargs) + self._parse_type_and_kind() token = self._attach_baggage() @@ -217,6 +260,10 @@ def wrapped_generator(): @wraps(handler) async def awrapper(*args, **kwargs): with tracing_context_manager(context=TracingContext.get()): + if not _is_tracing_initialized(): + _warn_instrumentation_before_init_once(handler_name) + return await handler(*args, **kwargs) + self._parse_type_and_kind() token = self._attach_baggage() @@ -250,6 +297,10 @@ async def awrapper(*args, **kwargs): @wraps(handler) def wrapper(*args, **kwargs): with tracing_context_manager(context=TracingContext.get()): + if not _is_tracing_initialized(): + _warn_instrumentation_before_init_once(handler_name) + return handler(*args, **kwargs) + self._parse_type_and_kind() token = self._attach_baggage() diff --git a/sdk/tests/unit/test_tracing_decorators.py b/sdk/tests/unit/test_tracing_decorators.py index 67ffc59da5..470ca2c3a7 100644 --- a/sdk/tests/unit/test_tracing_decorators.py +++ b/sdk/tests/unit/test_tracing_decorators.py @@ -46,9 +46,11 @@ import pytest import asyncio -from unittest.mock import Mock, MagicMock, patch +from types import SimpleNamespace +from unittest.mock import Mock, patch from agenta.sdk.decorators.tracing import instrument +import agenta.sdk.decorators.tracing as tracing_decorators class TestExistingFunctionality: @@ -680,3 +682,102 @@ def parameterized_generator(prompt): # Verify span was set to OK status self.mock_span.set_status.assert_called_with("OK") + + +class TestPreInitInstrumentationWarnings: + def setup_method(self): + tracing_decorators._PREINIT_INSTRUMENTATION_WARNING_EMITTED = ( + False # pylint: disable=protected-access + ) + + @patch( + "agenta.sdk.decorators.tracing.ag", + new=SimpleNamespace( + DEFAULT_AGENTA_SINGLETON_INSTANCE=SimpleNamespace(tracing=None), + ), + ) + def test_sync_function_warns_once_and_still_executes(self, recwarn): + @instrument() + def add(x, y): + return x + y + + assert add(1, 2) == 3 + assert add(2, 3) == 5 + + runtime_warnings = [ + w for w in recwarn if issubclass(w.category, RuntimeWarning) + ] + assert len(runtime_warnings) == 1 + message = str(runtime_warnings[0].message) + assert "called before `ag.init()`" in message + assert "Fix: call `ag.init()`" in message + + @pytest.mark.asyncio + @patch( + "agenta.sdk.decorators.tracing.ag", + new=SimpleNamespace( + DEFAULT_AGENTA_SINGLETON_INSTANCE=SimpleNamespace(tracing=None), + ), + ) + async def test_async_function_warns_once_and_still_executes(self, recwarn): + @instrument() + async def mul(x, y): + await asyncio.sleep(0.001) + return x * y + + assert await mul(2, 3) == 6 + assert await mul(3, 4) == 12 + + runtime_warnings = [ + w for w in recwarn if issubclass(w.category, RuntimeWarning) + ] + assert len(runtime_warnings) == 1 + + @patch( + "agenta.sdk.decorators.tracing.ag", + new=SimpleNamespace( + DEFAULT_AGENTA_SINGLETON_INSTANCE=SimpleNamespace(tracing=None), + ), + ) + def test_sync_generator_warns_once_and_still_executes(self, recwarn): + @instrument() + def gen(): + yield "a" + yield "b" + + assert list(gen()) == ["a", "b"] + assert list(gen()) == ["a", "b"] + + runtime_warnings = [ + w for w in recwarn if issubclass(w.category, RuntimeWarning) + ] + assert len(runtime_warnings) == 1 + + @pytest.mark.asyncio + @patch( + "agenta.sdk.decorators.tracing.ag", + new=SimpleNamespace( + DEFAULT_AGENTA_SINGLETON_INSTANCE=SimpleNamespace(tracing=None), + ), + ) + async def test_async_generator_warns_once_and_still_executes(self, recwarn): + @instrument() + async def gen(): + yield "a" + await asyncio.sleep(0.001) + yield "b" + + first = [] + async for item in gen(): + first.append(item) + assert first == ["a", "b"] + + second = [] + async for item in gen(): + second.append(item) + assert second == ["a", "b"] + + runtime_warnings = [ + w for w in recwarn if issubclass(w.category, RuntimeWarning) + ] + assert len(runtime_warnings) == 1