diff --git a/CHANGELOG-loongsuite.md b/CHANGELOG-loongsuite.md index 54df46d5..6a0c63ba 100644 --- a/CHANGELOG-loongsuite.md +++ b/CHANGELOG-loongsuite.md @@ -21,5 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Added +- `loongsuite-instrumentation-mem0`: add hook extension + ([#95](https://github.com/alibaba/loongsuite-python-agent/pull/95)) + - `loongsuite-instrumentation-mem0`: add support for mem0 ([#67](https://github.com/alibaba/loongsuite-python-agent/pull/67)) diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-mem0/src/opentelemetry/instrumentation/mem0/__init__.py b/instrumentation-loongsuite/loongsuite-instrumentation-mem0/src/opentelemetry/instrumentation/mem0/__init__.py index 097db994..2be67006 100644 --- a/instrumentation-loongsuite/loongsuite-instrumentation-mem0/src/opentelemetry/instrumentation/mem0/__init__.py +++ b/instrumentation-loongsuite/loongsuite-instrumentation-mem0/src/opentelemetry/instrumentation/mem0/__init__.py @@ -23,6 +23,7 @@ VectorStoreWrapper, ) from opentelemetry.instrumentation.mem0.package import _instruments +from opentelemetry.instrumentation.mem0.types import set_memory_hooks from opentelemetry.instrumentation.mem0.version import __version__ from opentelemetry.instrumentation.utils import unwrap from opentelemetry.semconv.schemas import Schemas @@ -127,6 +128,12 @@ def _instrument(self, **kwargs: Any) -> None: # Optional: logger provider for GenAI events (util will no-op if not provided) logger_provider = kwargs.get("logger_provider") + # Optional hooks for extensions (e.g. commercial metrics). We only pass through. + memory_before_hook = kwargs.get("memory_before_hook") + memory_after_hook = kwargs.get("memory_after_hook") + inner_before_hook = kwargs.get("inner_before_hook") + inner_after_hook = kwargs.get("inner_after_hook") + # Create util GenAI handler (strong dependency, no fallback). # Avoid singleton here so tests (and multiple tracer providers) don't leak across runs. telemetry_handler = ExtendedTelemetryHandler( @@ -144,13 +151,33 @@ def _instrument(self, **kwargs: Any) -> None: ) # Execute instrumentation (traces only, metrics removed) - self._instrument_memory_operations(telemetry_handler) - self._instrument_memory_client_operations(telemetry_handler) + self._instrument_memory_operations( + telemetry_handler, + memory_before_hook=memory_before_hook, + memory_after_hook=memory_after_hook, + ) + self._instrument_memory_client_operations( + telemetry_handler, + memory_before_hook=memory_before_hook, + memory_after_hook=memory_after_hook, + ) # Sub-phases controlled by toggle, avoid binding wrapper when disabled to reduce overhead if mem0_config.is_internal_phases_enabled(): - self._instrument_vector_operations(tracer) - self._instrument_graph_operations(tracer) - self._instrument_reranker_operations(tracer) + self._instrument_vector_operations( + tracer, + inner_before_hook=inner_before_hook, + inner_after_hook=inner_after_hook, + ) + self._instrument_graph_operations( + tracer, + inner_before_hook=inner_before_hook, + inner_after_hook=inner_after_hook, + ) + self._instrument_reranker_operations( + tracer, + inner_before_hook=inner_before_hook, + inner_after_hook=inner_after_hook, + ) def _uninstrument(self, **kwargs: Any) -> None: """Remove instrumentation.""" @@ -392,7 +419,13 @@ def _wrapper(wrapped, instance, args, kwargs): return _wrapper - def _instrument_memory_operations(self, telemetry_handler): + def _instrument_memory_operations( + self, + telemetry_handler, + *, + memory_before_hook=None, + memory_after_hook=None, + ): """Instrument Memory and AsyncMemory operations.""" try: if ( @@ -407,6 +440,11 @@ def _instrument_memory_operations(self, telemetry_handler): return wrapper = MemoryOperationWrapper(telemetry_handler) + set_memory_hooks( + wrapper, + memory_before_hook=memory_before_hook, + memory_after_hook=memory_after_hook, + ) # Instrument Memory (sync) for method in self._public_methods_of( @@ -444,7 +482,13 @@ def _instrument_memory_operations(self, telemetry_handler): except Exception as e: logger.debug(f"Failed to instrument Memory operations: {e}") - def _instrument_memory_client_operations(self, telemetry_handler): + def _instrument_memory_client_operations( + self, + telemetry_handler, + *, + memory_before_hook=None, + memory_after_hook=None, + ): """Instrument MemoryClient and AsyncMemoryClient operations.""" try: if ( @@ -459,6 +503,11 @@ def _instrument_memory_client_operations(self, telemetry_handler): return wrapper = MemoryOperationWrapper(telemetry_handler) + set_memory_hooks( + wrapper, + memory_before_hook=memory_before_hook, + memory_after_hook=memory_after_hook, + ) # Instrument MemoryClient (sync) for method in self._public_methods_of( @@ -594,7 +643,13 @@ def _factory_wrapper(wrapped, instance, args, kwargs): except Exception as e: logger.debug(f"Failed to wrap {factory_class}.create: {e}") - def _instrument_vector_operations(self, tracer): + def _instrument_vector_operations( + self, + tracer, + *, + inner_before_hook=None, + inner_after_hook=None, + ): """Instrument VectorStore operations.""" try: # Require both VectorStoreBase and VectorStoreFactory to be available @@ -632,7 +687,11 @@ def _instrument_vector_operations(self, tracer): ] # Create VectorStoreWrapper instance (trace-only) - vector_wrapper = VectorStoreWrapper(tracer) + vector_wrapper = VectorStoreWrapper( + tracer, + inner_before_hook=inner_before_hook, + inner_after_hook=inner_after_hook, + ) # Use generic factory wrapping method self._wrap_factory_for_phase( @@ -647,7 +706,13 @@ def _instrument_vector_operations(self, tracer): except Exception as e: logger.debug(f"Failed to instrument vector store operations: {e}") - def _instrument_graph_operations(self, tracer): + def _instrument_graph_operations( + self, + tracer, + *, + inner_before_hook=None, + inner_after_hook=None, + ): """Instrument GraphStore operations.""" try: # If factories are unavailable, graph subphase instrumentation cannot be enabled @@ -687,7 +752,11 @@ def _instrument_graph_operations(self, tracer): ] # Create GraphStoreWrapper instance (trace-only) - graph_wrapper = GraphStoreWrapper(tracer) + graph_wrapper = GraphStoreWrapper( + tracer, + inner_before_hook=inner_before_hook, + inner_after_hook=inner_after_hook, + ) # Use generic factory wrapping method self._wrap_factory_for_phase( @@ -702,7 +771,13 @@ def _instrument_graph_operations(self, tracer): except Exception as e: logger.debug(f"Failed to instrument graph store operations: {e}") - def _instrument_reranker_operations(self, tracer): + def _instrument_reranker_operations( + self, + tracer, + *, + inner_before_hook=None, + inner_after_hook=None, + ): """Instrument Reranker operations.""" try: if not _FACTORIES_AVAILABLE or RerankerFactory is None: @@ -713,7 +788,11 @@ def _instrument_reranker_operations(self, tracer): return # Create RerankerWrapper instance (trace-only) - reranker_wrapper = RerankerWrapper(tracer) + reranker_wrapper = RerankerWrapper( + tracer, + inner_before_hook=inner_before_hook, + inner_after_hook=inner_after_hook, + ) # Use generic factory wrapping method self._wrap_factory_for_phase( diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-mem0/src/opentelemetry/instrumentation/mem0/internal/_wrapper.py b/instrumentation-loongsuite/loongsuite-instrumentation-mem0/src/opentelemetry/instrumentation/mem0/internal/_wrapper.py index a1dab4be..aacf3027 100644 --- a/instrumentation-loongsuite/loongsuite-instrumentation-mem0/src/opentelemetry/instrumentation/mem0/internal/_wrapper.py +++ b/instrumentation-loongsuite/loongsuite-instrumentation-mem0/src/opentelemetry/instrumentation/mem0/internal/_wrapper.py @@ -26,7 +26,21 @@ SemanticAttributes, SpanName, ) -from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer +from opentelemetry.instrumentation.mem0.types import ( + HookContext, + InnerAfterHook, + InnerBeforeHook, + MemoryAfterHook, + MemoryBeforeHook, + safe_call_hook, +) +from opentelemetry.trace import ( + SpanKind, + Status, + StatusCode, + Tracer, + get_current_span, +) from opentelemetry.util.genai._extended_memory import MemoryInvocation from opentelemetry.util.genai.extended_handler import ExtendedTelemetryHandler from opentelemetry.util.genai.types import Error @@ -149,6 +163,8 @@ def __init__(self, telemetry_handler: ExtendedTelemetryHandler): """ self.telemetry_handler = telemetry_handler self.extractor = MemoryOperationAttributeExtractor() + self._memory_before_hook: MemoryBeforeHook = None + self._memory_after_hook: MemoryAfterHook = None def wrap_operation( self, @@ -444,6 +460,18 @@ def _execute_with_handler( ) self.telemetry_handler.start_memory(invocation) + hook_context: HookContext = {} + # Read current span after util handler starts memory (span should exist in most cases) + span = get_current_span() + safe_call_hook( + self._memory_before_hook, + span, + operation_name, + instance, + args, + dict(kwargs), + hook_context, + ) try: result = func(*args, **kwargs) # Post-extract result attributes/content (must happen before stop_memory) @@ -456,9 +484,31 @@ def _execute_with_handler( extract_attributes_func=extract_attributes_func, is_memory_client=is_memory_client, ) + safe_call_hook( + self._memory_after_hook, + span, + operation_name, + instance, + args, + dict(kwargs), + hook_context, + result, + None, + ) self.telemetry_handler.stop_memory(invocation) return result except Exception as e: + safe_call_hook( + self._memory_after_hook, + span, + operation_name, + instance, + args, + dict(kwargs), + hook_context, + None, + e, + ) self.telemetry_handler.fail_memory( invocation, Error(message=str(e), type=type(e)) ) @@ -502,6 +552,17 @@ async def _execute_with_handler_async( ) self.telemetry_handler.start_memory(invocation) + hook_context: HookContext = {} + span = get_current_span() + safe_call_hook( + self._memory_before_hook, + span, + operation_name, + instance, + args, + dict(kwargs), + hook_context, + ) try: result = await func(*args, **kwargs) # Post-extract result attributes/content (must happen before stop_memory) @@ -514,9 +575,31 @@ async def _execute_with_handler_async( extract_attributes_func=extract_attributes_func, is_memory_client=is_memory_client, ) + safe_call_hook( + self._memory_after_hook, + span, + operation_name, + instance, + args, + dict(kwargs), + hook_context, + result, + None, + ) self.telemetry_handler.stop_memory(invocation) return result except Exception as e: + safe_call_hook( + self._memory_after_hook, + span, + operation_name, + instance, + args, + dict(kwargs), + hook_context, + None, + e, + ) self.telemetry_handler.fail_memory( invocation, Error(message=str(e), type=type(e)) ) @@ -526,7 +609,13 @@ async def _execute_with_handler_async( class VectorStoreWrapper: """Vector store subphase wrapper.""" - def __init__(self, tracer: Tracer): + def __init__( + self, + tracer: Tracer, + *, + inner_before_hook: InnerBeforeHook = None, + inner_after_hook: InnerAfterHook = None, + ): """ Initialize wrapper. @@ -535,6 +624,8 @@ def __init__(self, tracer: Tracer): """ self.tracer = tracer self.extractor = VectorOperationAttributeExtractor() + self._inner_before_hook = inner_before_hook + self._inner_after_hook = inner_after_hook def wrap_vector_operation(self, method_name: str) -> Callable: """ @@ -566,6 +657,17 @@ def wrapper( kind=SpanKind.CLIENT, ) as span: result = None + hook_context: HookContext = {} + safe_call_hook( + self._inner_before_hook, + span, + "vector", + method_name, + instance, + args, + dict(kwargs), + hook_context, + ) # Store extracted attributes (defined outside try for finally access) span_attrs = {} @@ -582,6 +684,18 @@ def wrapper( span.set_attribute(key, value) span.set_status(Status(StatusCode.OK)) + safe_call_hook( + self._inner_after_hook, + span, + "vector", + method_name, + instance, + args, + dict(kwargs), + hook_context, + result, + None, + ) return result except Exception as e: @@ -590,6 +704,18 @@ def wrapper( span.set_attribute( SemanticAttributes.ERROR_TYPE, get_exception_type(e) ) + safe_call_hook( + self._inner_after_hook, + span, + "vector", + method_name, + instance, + args, + dict(kwargs), + hook_context, + None, + e, + ) raise return wrapper @@ -602,7 +728,13 @@ def _get_span_name(self, method_name: str) -> str: class GraphStoreWrapper: """Graph store subphase wrapper.""" - def __init__(self, tracer: Tracer): + def __init__( + self, + tracer: Tracer, + *, + inner_before_hook: InnerBeforeHook = None, + inner_after_hook: InnerAfterHook = None, + ): """ Initialize wrapper. @@ -611,6 +743,8 @@ def __init__(self, tracer: Tracer): """ self.tracer = tracer self.extractor = GraphOperationAttributeExtractor() + self._inner_before_hook = inner_before_hook + self._inner_after_hook = inner_after_hook def wrap_graph_operation(self, method_name: str) -> Callable: """ @@ -638,6 +772,17 @@ def wrapper( kind=SpanKind.CLIENT, ) as span: result = None + hook_context: HookContext = {} + safe_call_hook( + self._inner_before_hook, + span, + "graph", + method_name, + instance, + args, + dict(kwargs), + hook_context, + ) # Store extracted attributes (defined outside try for finally access) span_attrs = {} @@ -654,6 +799,18 @@ def wrapper( span.set_attribute(key, value) span.set_status(Status(StatusCode.OK)) + safe_call_hook( + self._inner_after_hook, + span, + "graph", + method_name, + instance, + args, + dict(kwargs), + hook_context, + result, + None, + ) return result except Exception as e: @@ -662,6 +819,18 @@ def wrapper( span.set_attribute( SemanticAttributes.ERROR_TYPE, get_exception_type(e) ) + safe_call_hook( + self._inner_after_hook, + span, + "graph", + method_name, + instance, + args, + dict(kwargs), + hook_context, + None, + e, + ) raise return wrapper @@ -674,7 +843,13 @@ def _get_span_name(self, method_name: str) -> str: class RerankerWrapper: """Reranker subphase wrapper.""" - def __init__(self, tracer: Tracer): + def __init__( + self, + tracer: Tracer, + *, + inner_before_hook: InnerBeforeHook = None, + inner_after_hook: InnerAfterHook = None, + ): """ Initialize wrapper. @@ -683,6 +858,8 @@ def __init__(self, tracer: Tracer): """ self.tracer = tracer self.extractor = RerankerAttributeExtractor() + self._inner_before_hook = inner_before_hook + self._inner_after_hook = inner_after_hook def wrap_rerank(self) -> Callable: """ @@ -713,6 +890,17 @@ def wrapper( SpanName.get_subphase_span_name("reranker", "rerank"), kind=SpanKind.CLIENT, ) as span: + hook_context: HookContext = {} + safe_call_hook( + self._inner_before_hook, + span, + "rerank", + "rerank", + instance, + args, + dict(kwargs), + hook_context, + ) # Store extracted attributes (defined outside try for finally access) span_attrs = {} @@ -728,6 +916,18 @@ def wrapper( result = wrapped(*args, **kwargs) span.set_status(Status(StatusCode.OK)) + safe_call_hook( + self._inner_after_hook, + span, + "rerank", + "rerank", + instance, + args, + dict(kwargs), + hook_context, + result, + None, + ) return result except Exception as e: @@ -736,6 +936,18 @@ def wrapper( span.set_attribute( SemanticAttributes.ERROR_TYPE, get_exception_type(e) ) + safe_call_hook( + self._inner_after_hook, + span, + "rerank", + "rerank", + instance, + args, + dict(kwargs), + hook_context, + None, + e, + ) raise return wrapper diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-mem0/src/opentelemetry/instrumentation/mem0/types.py b/instrumentation-loongsuite/loongsuite-instrumentation-mem0/src/opentelemetry/instrumentation/mem0/types.py new file mode 100644 index 00000000..72305ccc --- /dev/null +++ b/instrumentation-loongsuite/loongsuite-instrumentation-mem0/src/opentelemetry/instrumentation/mem0/types.py @@ -0,0 +1,45 @@ +""" +Mem0 instrumentation public hook types and helpers. +""" + +from __future__ import annotations + +import logging +from typing import Any, Callable, Dict, Optional + +logger = logging.getLogger(__name__) + +# Per-call hook context. Instrumentation only creates and passes it through. +HookContext = Dict[str, Any] + +# Hook callables are kept intentionally loose: the open-source package only passes through +# parameters, and commercial extensions are responsible for extracting/recording data. +MemoryBeforeHook = Optional[Callable[..., Any]] +MemoryAfterHook = Optional[Callable[..., Any]] +InnerBeforeHook = Optional[Callable[..., Any]] +InnerAfterHook = Optional[Callable[..., Any]] + + +def safe_call_hook(hook: Optional[Callable[..., Any]], *args: Any) -> None: + """ + Call a hook defensively: swallow hook exceptions to avoid breaking user code. + """ + if not callable(hook): + return + try: + hook(*args) + except Exception as e: + logger.debug("mem0 hook raised and was swallowed: %s", e) + + +def set_memory_hooks( + wrapper: Any, + *, + memory_before_hook: MemoryBeforeHook = None, + memory_after_hook: MemoryAfterHook = None, +) -> None: + """ + Configure top-level memory hooks on a MemoryOperationWrapper instance. + """ + wrapper._memory_before_hook = memory_before_hook + wrapper._memory_after_hook = memory_after_hook diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-mem0/tests/test_extractors.py b/instrumentation-loongsuite/loongsuite-instrumentation-mem0/tests/test_extractors.py index 0f9c1edd..ae93f1a0 100644 --- a/instrumentation-loongsuite/loongsuite-instrumentation-mem0/tests/test_extractors.py +++ b/instrumentation-loongsuite/loongsuite-instrumentation-mem0/tests/test_extractors.py @@ -347,6 +347,67 @@ def test_extract_invocation_attributes(self): # update/batch_update input message coverage merged into test_extract_invocation_content_input_messages + def test_compact_exception_branches(self): + """ + Compact coverage test to hit rarely-triggered exception branches: + - _extract_input_content exception path + - batch size extraction exception path + - result_count extraction exception path + - op-specific extractor exception path + """ + + # 1) _extract_input_content exception: kwargs.get raises + class BadKwargs(dict): + def get(self, *a, **k): # type: ignore[override] + raise RuntimeError("boom") + + input_msg, _ = self.extractor.extract_invocation_content( + "search", BadKwargs(), None + ) + self.assertIsNone(input_msg) + + # 2) batch size extraction exception: list type but __len__ raises + class BadList(list): + def __len__(self): # pragma: no cover + raise RuntimeError("boom") + + # 3) result_count extraction exception + 4) op-specific extractor exception + from unittest.mock import patch # noqa: PLC0415 + + def bad_specific( + kwargs: Dict[str, Any], result: Any + ) -> Dict[str, Any]: + raise RuntimeError("boom") + + setattr( + MemoryOperationAttributeExtractor, + "extract_search_attributes", + staticmethod(bad_specific), + ) + try: + with patch( + "opentelemetry.instrumentation.mem0.internal._extractors.extract_result_count", + side_effect=RuntimeError("boom"), + ): + attrs = self.extractor.extract_invocation_attributes( + "search", + { + "memories": BadList([{"id": 1}]), + "infer": True, + "metadata": {"k": "v"}, + "filters": {"f": 1}, + "fields": ["a"], + "categories": ["c"], + }, + result={"results": [1, 2, 3]}, + ) + self.assertIsInstance(attrs, dict) + finally: + # cleanup dynamic method + delattr( + MemoryOperationAttributeExtractor, "extract_search_attributes" + ) + class TestVectorOperationAttributeExtractor(unittest.TestCase): """Tests Vector operation attribute extractor""" diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-mem0/tests/test_hooks.py b/instrumentation-loongsuite/loongsuite-instrumentation-mem0/tests/test_hooks.py new file mode 100644 index 00000000..f0199228 --- /dev/null +++ b/instrumentation-loongsuite/loongsuite-instrumentation-mem0/tests/test_hooks.py @@ -0,0 +1,342 @@ +# -*- coding: utf-8 -*- +# pyright: ignore +""" +Tests for Mem0 hooks (before/after) plumbing. + +These tests validate that: +- Hooks are called for top-level and inner operations +- hook_context is shared between before/after for a single call +- Hook exceptions are swallowed and do not break the wrapped call +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Tuple + +import pytest + +from opentelemetry.instrumentation.mem0.internal import _wrapper as wrapper_mod +from opentelemetry.instrumentation.mem0.internal._wrapper import ( + GraphStoreWrapper, + MemoryOperationWrapper, + RerankerWrapper, + VectorStoreWrapper, +) +from opentelemetry.instrumentation.mem0.types import set_memory_hooks +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.util.genai._extended_memory import MemoryInvocation + + +class _DummyTelemetryHandler: + def start_memory(self, invocation: Any) -> None: + return None + + def stop_memory(self, invocation: Any) -> None: + return None + + def fail_memory(self, invocation: Any, error: Any) -> None: + return None + + +@pytest.mark.asyncio +async def test_memory_hooks_sync_async_and_exception_paths(): + calls: List[Tuple[str, Dict[str, Any]]] = [] + seen_exc: Dict[str, Any] = {"exc": None} + + def before(span, operation, instance, args, kwargs, hook_context): + hook_context["start_time"] = 1 + calls.append(("before", hook_context)) + + def after( + span, operation, instance, args, kwargs, hook_context, result, exc + ): + # Validate context is shared and exception is surfaced + assert hook_context.get("start_time") == 1 + if exc is not None: + seen_exc["exc"] = exc + calls.append(("after", hook_context)) + + w = MemoryOperationWrapper(_DummyTelemetryHandler()) + set_memory_hooks(w, memory_before_hook=before, memory_after_hook=after) + + # Cover helper: _normalize_call_parameters with positional args mapping + def _sig(self, memory_id, data, *, user_id=None): # noqa: ARG001 + return None + + normalized = wrapper_mod._normalize_call_parameters( # type: ignore[attr-defined] + _sig, args=("mid", "payload"), kwargs={"user_id": "u1"} + ) + assert normalized["memory_id"] == "mid" + assert normalized["data"] == "payload" + assert normalized["user_id"] == "u1" + + # Cover helper: _apply_custom_extractor_output_to_invocation and leftover attribute mapping + inv = MemoryInvocation(operation="add") + MemoryOperationWrapper._apply_custom_extractor_output_to_invocation( + inv, + { + "user_id": "u1", + "limit": "3", + "threshold": "0.7", + "rerank": False, + "server_address": "example.com", + "server_port": "443", + "attributes": {"k1": "v1"}, + "custom_k": "custom_v", + }, + ) + assert inv.user_id == "u1" + assert inv.limit == 3 + assert inv.threshold == 0.7 + assert inv.rerank is False + assert inv.server_address == "example.com" + assert inv.server_port == 443 + assert inv.attributes.get("k1") == "v1" + assert inv.attributes.get("custom_k") == "custom_v" + + # sync success + def _fn_sync(*a, **k): + return "ok" + + res_sync = w._execute_with_handler( + _fn_sync, + instance=object(), + args=(), + kwargs={"k": "v"}, + operation_name="add", + extract_attributes_func=None, + is_memory_client=False, + ) + assert res_sync == "ok" + + # cover extract_server_info exception branch (host property raises) + class _BadHost: + @property + def host(self): # pragma: no cover + raise RuntimeError("boom") + + _ = w._execute_with_handler( + _fn_sync, + instance=_BadHost(), + args=(), + kwargs={}, + operation_name="get_all", + extract_attributes_func=None, + is_memory_client=True, + ) + + # cover custom extractor function path (extract_attributes_func provided) + inv2 = MemoryInvocation(operation="add") + w._apply_extracted_attrs_to_invocation( + inv2, + instance=object(), + normalized_kwargs={}, + operation_name="add", + result={"results": []}, + extract_attributes_func=lambda kwargs, result: { + "user_id": "u2", + "attributes": {"x": 1}, + }, + is_memory_client=False, + ) + assert inv2.user_id == "u2" + assert inv2.attributes.get("x") == 1 + + # async success + async def _fn_async(*a, **k): + return "ok2" + + res_async = await w._execute_with_handler_async( + _fn_async, + instance=object(), + args=(), + kwargs={"k": "v"}, + operation_name="search", + extract_attributes_func=None, + is_memory_client=False, + ) + assert res_async == "ok2" + + # hook exceptions swallowed (both before/after) + def before_boom(span, operation, instance, args, kwargs, hook_context): + raise RuntimeError("boom") + + def after_boom( + span, operation, instance, args, kwargs, hook_context, result, exc + ): + raise RuntimeError("boom2") + + set_memory_hooks( + w, memory_before_hook=before_boom, memory_after_hook=after_boom + ) + assert ( + w._execute_with_handler( + _fn_sync, + instance=object(), + args=(), + kwargs={}, + operation_name="get", + extract_attributes_func=None, + is_memory_client=False, + ) + == "ok" + ) + + # exception path calls after_hook with exception + set_memory_hooks(w, memory_before_hook=before, memory_after_hook=after) + + def _fn_raises(*a, **k): + raise ValueError("nope") + + with pytest.raises(ValueError): + w._execute_with_handler( + _fn_raises, + instance=object(), + args=(), + kwargs={}, + operation_name="delete", + extract_attributes_func=None, + is_memory_client=False, + ) + + assert isinstance(seen_exc["exc"], ValueError) + # For at least the first sync call, context is shared between before/after + assert calls[0][1] is calls[1][1] + + +def test_inner_hooks_vector_graph_rerank_and_exception_paths(monkeypatch): + # Cover both disabled and enabled internal phases paths + monkeypatch.setattr( + wrapper_mod, "is_internal_phases_enabled", lambda: False + ) + tracer = TracerProvider().get_tracer("test") + vw_disabled = VectorStoreWrapper(tracer) + # when disabled, wrapper returns original result without span/hook + assert vw_disabled.wrap_vector_operation("search")( + lambda *a, **k: {"ok": True}, object(), (), {"limit": 1} + ) == {"ok": True} + + # Enable internal phases for hook paths below + monkeypatch.setattr( + wrapper_mod, "is_internal_phases_enabled", lambda: True + ) + + calls: List[Tuple[str, str, str, Dict[str, Any]]] = [] + seen_exc: Dict[str, Any] = {"exc": None} + + def before( + span, inner_name, operation, instance, args, kwargs, hook_context + ): + hook_context["start_time"] = 1 + calls.append(("before", inner_name, operation, hook_context)) + + def after( + span, + inner_name, + operation, + instance, + args, + kwargs, + hook_context, + result, + exc, + ): + assert hook_context.get("start_time") == 1 + if exc is not None: + seen_exc["exc"] = exc + calls.append(("after", inner_name, operation, hook_context)) + + vw = VectorStoreWrapper( + tracer, inner_before_hook=before, inner_after_hook=after + ) + gw = GraphStoreWrapper( + tracer, inner_before_hook=before, inner_after_hook=after + ) + rw = RerankerWrapper( + tracer, inner_before_hook=before, inner_after_hook=after + ) + + def wrapped(*a, **k): + return {"results": [1]} + + # vector success + res = vw.wrap_vector_operation("search")( + wrapped, object(), (), {"limit": 1} + ) + assert res == {"results": [1]} + + # Cover skip branch: mem0migrations collection should bypass span/hook + class _Migrations: + collection_name = "mem0migrations" + + calls_before = len(calls) + assert vw.wrap_vector_operation("search")( + wrapped, _Migrations(), (), {"limit": 1} + ) == {"results": [1]} + assert len(calls) == calls_before # no hooks recorded + + def graph_wrapped(*a, **k): + return {"nodes": [1]} + + def rerank_wrapped(*a, **k): + return [{"id": 1}] + + # graph success + assert gw.wrap_graph_operation("add")(graph_wrapped, object(), (), {}) == { + "nodes": [1] + } + # rerank success + assert rw.wrap_rerank()( + rerank_wrapped, object(), (), {"query": "q", "documents": []} + ) == [{"id": 1}] + + # after hook exception swallowed (vector) + def after_boom( + span, + inner_name, + operation, + instance, + args, + kwargs, + hook_context, + result, + exc, + ): + raise RuntimeError("boom") + + vw2 = VectorStoreWrapper( + tracer, inner_before_hook=before, inner_after_hook=after_boom + ) + assert vw2.wrap_vector_operation("search")( + wrapped, object(), (), {"limit": 1} + ) == {"results": [1]} + + # after hook receives exception (vector) + def wrapped_raises(*a, **k): + raise ValueError("nope") + + with pytest.raises(ValueError): + vw.wrap_vector_operation("search")( + wrapped_raises, object(), (), {"limit": 1} + ) + assert isinstance(seen_exc["exc"], ValueError) + + # graph exception path + def graph_raises(*a, **k): + raise ValueError("nope") + + with pytest.raises(ValueError): + gw.wrap_graph_operation("add")(graph_raises, object(), (), {}) + + # rerank exception path + def rerank_raises(*a, **k): + raise ValueError("nope") + + with pytest.raises(ValueError): + rw.wrap_rerank()( + rerank_raises, object(), (), {"query": "q", "documents": []} + ) + + # Sanity: at least one before/after pair shares context object (vector success) + assert calls[0][0] == "before" and calls[1][0] == "after" + assert calls[0][3] is calls[1][3] diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-mem0/tests/test_instrumentor.py b/instrumentation-loongsuite/loongsuite-instrumentation-mem0/tests/test_instrumentor.py index 9ca126fb..23828a2d 100644 --- a/instrumentation-loongsuite/loongsuite-instrumentation-mem0/tests/test_instrumentor.py +++ b/instrumentation-loongsuite/loongsuite-instrumentation-mem0/tests/test_instrumentor.py @@ -297,6 +297,292 @@ def _private(self): finally: sys.modules.pop("test_module", None) + @patch( + "opentelemetry.instrumentation.mem0.config.is_internal_phases_enabled" + ) + def test_branchy_paths_compact(self, mock_internal_enabled): + """ + Compact branch-coverage test: + - early-return paths in _instrument/_uninstrument + - _public_methods_of import failure + - _public_methods_of_cls attribute access failure + - _wrap_factory_for_phase: check_enabled short-circuit + one-time wrapping + __otel_mem0_original_config__ + """ + inst = Mem0Instrumentor() + + # _instrument early return + inst._is_instrumented = True + inst._instrument() + inst._is_instrumented = False + + # _uninstrument early return + inst._is_instrumented = False + inst._uninstrument() + + # _public_methods_of import error path + with patch("builtins.__import__", side_effect=ImportError("nope")): + self.assertEqual(inst._public_methods_of("x.y", "Z"), []) + + # _public_methods_of_cls getattr error path + class Weird: + @property + def bad(self): # pragma: no cover + raise RuntimeError("boom") + + def ok(self): + return 1 + + methods = inst._public_methods_of_cls(Weird) + self.assertIn("ok", methods) + self.assertNotIn("bad", methods) + + # _unwrap_class_methods: inner unwrap failure + outer exception path + with patch( + "opentelemetry.instrumentation.mem0.unwrap", + side_effect=RuntimeError("unwrap boom"), + ): + with patch.object( + inst, "_public_methods_of_cls", return_value=["add"] + ): + # allowed method should attempt unwrap and swallow exception + inst._unwrap_class_methods(Weird, "Weird") + with patch.object( + inst, + "_public_methods_of_cls", + side_effect=RuntimeError("boom"), + ): + # outer exception is swallowed too + inst._unwrap_class_methods(Weird, "Weird") + + # _unwrap_factory: early return + unwrap exception swallowing + with patch( + "opentelemetry.instrumentation.mem0._FACTORIES_AVAILABLE", False + ): + inst._unwrap_factory(object, "X") # no-op + with patch( + "opentelemetry.instrumentation.mem0._FACTORIES_AVAILABLE", True + ): + with patch( + "opentelemetry.instrumentation.mem0.unwrap", + side_effect=RuntimeError("unwrap boom"), + ): + inst._unwrap_factory(type("F", (), {}), "F") + + # _get_base_methods: exception path fallback + class BadMeta(type): + def __getattribute__(cls, name): # noqa: N805 + if name == "__dict__": + raise RuntimeError("no dict") + return super().__getattribute__(name) + + class BadBase(metaclass=BadMeta): + pass + + defaults = ["a", "b"] + self.assertEqual(inst._get_base_methods(None, "X", defaults), defaults) + self.assertEqual( + inst._get_base_methods(BadBase, "BadBase", defaults), defaults + ) + + # _unwrap_dynamic_classes: import error + unwrap error paths + inst._instrumented_vector_classes.add("nonexistent.mod.ClassX") + with patch( + "opentelemetry.instrumentation.mem0.unwrap", + side_effect=RuntimeError("unwrap boom"), + ): + inst._unwrap_dynamic_classes( + inst._instrumented_vector_classes, ["search"] + ) + self.assertEqual(len(inst._instrumented_vector_classes), 0) + + # _wrap_factory_for_phase: capture factory wrapper closure + captured = {} + + def capture_factory_wrapper(module, name, wrapper): + captured["wrapper"] = wrapper + + with patch( + "opentelemetry.instrumentation.mem0.wrap_function_wrapper", + side_effect=capture_factory_wrapper, + ): + inst._wrap_factory_for_phase( + factory_module="mem0.utils.factory", + factory_class="VectorStoreFactory", + phase_name="vector", + methods=["search"], + wrapper_instance=type( + "W", + (), + { + "wrap_vector_operation": lambda self, m: ( + lambda *a, **k: None + ) + }, + )(), + instrumented_classes_set=set(), + check_enabled_func=lambda: False, + ) + + # Disabled: should return wrapped() directly + dummy = object() + + def wrapped(*a, **k): + return dummy + + self.assertIs( + captured["wrapper"](wrapped, None, ("p", {"url": "x"}), {}), + dummy, + ) + + # Enabled: should set __otel_mem0_original_config__ and wrap method once + wrapped_calls = [] + captured2 = {} + + def capture_factory_wrapper2(module, name, wrapper): + captured2["wrapper"] = wrapper + + def record_wrap(module, name, wrapper): + wrapped_calls.append((module, name)) + + with patch( + "opentelemetry.instrumentation.mem0.wrap_function_wrapper", + side_effect=capture_factory_wrapper2, + ): + inst._wrap_factory_for_phase( + factory_module="mem0.utils.factory", + factory_class="VectorStoreFactory", + phase_name="vector", + methods=["search"], + wrapper_instance=type( + "W", + (), + { + "wrap_vector_operation": lambda self, m: ( + lambda *a, **k: None + ) + }, + )(), + instrumented_classes_set=set(), + check_enabled_func=lambda: True, + ) + + class DummyVector: + def search(self, **kwargs): + return {"ok": True} + + def factory_create(provider, config): + return DummyVector() + + with patch( + "opentelemetry.instrumentation.mem0.wrap_function_wrapper", + side_effect=record_wrap, + ): + obj = captured2["wrapper"]( + factory_create, None, ("p", {"url": "x"}), {} + ) + self.assertTrue(hasattr(obj, "__otel_mem0_original_config__")) + self.assertTrue( + any(n.endswith("DummyVector.search") for _, n in wrapped_calls) + ) + + # Second time should not wrap again for same fqcn + wrapped_calls.clear() + _ = captured2["wrapper"]( + factory_create, None, ("p", {"url": "y"}), {} + ) + self.assertEqual(wrapped_calls, []) + + # Ensure we touched internal-phase gate path at least once + mock_internal_enabled.return_value = False + + # _wrap_factory_for_phase: unknown phase branch + wrap_function_wrapper failure paths + captured3 = {} + + def capture_factory_wrapper3(module, name, wrapper): + captured3["wrapper"] = wrapper + + with patch( + "opentelemetry.instrumentation.mem0.wrap_function_wrapper", + side_effect=capture_factory_wrapper3, + ): + inst._wrap_factory_for_phase( + factory_module="mem0.utils.factory", + factory_class="VectorStoreFactory", + phase_name="unknown", + methods=["search"], + wrapper_instance=type("W", (), {})(), + instrumented_classes_set=set(), + check_enabled_func=lambda: True, + ) + + # When phase_name is unknown, it should just skip wrapping methods. + class DummyUnknown: + def search(self, **kwargs): + return {"ok": True} + + def factory_create_unknown(provider, config): + return DummyUnknown() + + _ = captured3["wrapper"]( + factory_create_unknown, None, ("p", {"url": "x"}), {} + ) + + # Outer wrap_function_wrapper failure when wrapping create + with patch( + "opentelemetry.instrumentation.mem0.wrap_function_wrapper", + side_effect=RuntimeError("wrap boom"), + ): + inst._wrap_factory_for_phase( + factory_module="mem0.utils.factory", + factory_class="VectorStoreFactory", + phase_name="vector", + methods=["search"], + wrapper_instance=type( + "W", + (), + { + "wrap_vector_operation": lambda self, m: ( + lambda *a, **k: None + ) + }, + )(), + instrumented_classes_set=set(), + check_enabled_func=lambda: True, + ) + + # _instrument_* skip paths (types/factories unavailable) and method list fallbacks + with patch.multiple( + "opentelemetry.instrumentation.mem0", + _MEM0_CORE_AVAILABLE=False, + _FACTORIES_AVAILABLE=False, + VectorStoreBase=None, + VectorStoreFactory=None, + GraphStoreFactory=None, + RerankerFactory=None, + ): + inst._instrument_vector_operations(object()) + inst._instrument_graph_operations(object()) + inst._instrument_reranker_operations(object()) + + # method list fallback branches (VectorStoreBase.__dict__ and MemoryGraph.__dict__ raising) + with patch.multiple( + "opentelemetry.instrumentation.mem0", + _MEM0_CORE_AVAILABLE=True, + _FACTORIES_AVAILABLE=True, + VectorStoreFactory=type("VSF", (), {}), + GraphStoreFactory=type("GSF", (), {}), + RerankerFactory=type("RRF", (), {}), + VectorStoreBase=BadBase, + _MEMORY_GRAPH_AVAILABLE=True, + MemoryGraph=BadBase, + ): + with patch.object( + inst, "_wrap_factory_for_phase", return_value=None + ): + inst._instrument_vector_operations(object()) + inst._instrument_graph_operations(object()) + inst._instrument_reranker_operations(object()) + if __name__ == "__main__": unittest.main()