Skip to content

Commit 3ae8446

Browse files
committed
feat:mem0 add hook extension
Change-Id: I9696dfbe580630ed6a317e851882bc652385ac1d
1 parent 075339d commit 3ae8446

File tree

5 files changed

+83
-60
lines changed

5 files changed

+83
-60
lines changed

CHANGELOG-loongsuite.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313

1414
### Fixed
1515

16-
- `loongsuite-instrumentation-mem0`: add hook extension
17-
([#95](https://github.com/alibaba/loongsuite-python-agent/pull/95))
18-
1916
- `loongsuite-instrumentation-mem0`: use memory handler
2017
([#89](https://github.com/alibaba/loongsuite-python-agent/pull/89))
2118

@@ -24,5 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2421

2522
# Added
2623

24+
- `loongsuite-instrumentation-mem0`: add hook extension
25+
([#95](https://github.com/alibaba/loongsuite-python-agent/pull/95))
26+
2727
- `loongsuite-instrumentation-mem0`: add support for mem0
2828
([#67](https://github.com/alibaba/loongsuite-python-agent/pull/67))

instrumentation-loongsuite/loongsuite-instrumentation-mem0/src/opentelemetry/instrumentation/mem0/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
VectorStoreWrapper,
2424
)
2525
from opentelemetry.instrumentation.mem0.package import _instruments
26+
from opentelemetry.instrumentation.mem0.types import set_memory_hooks
2627
from opentelemetry.instrumentation.mem0.version import __version__
2728
from opentelemetry.instrumentation.utils import unwrap
2829
from opentelemetry.semconv.schemas import Schemas
@@ -439,7 +440,8 @@ def _instrument_memory_operations(
439440
return
440441

441442
wrapper = MemoryOperationWrapper(telemetry_handler)
442-
wrapper.set_hooks(
443+
set_memory_hooks(
444+
wrapper,
443445
memory_before_hook=memory_before_hook,
444446
memory_after_hook=memory_after_hook,
445447
)
@@ -501,7 +503,8 @@ def _instrument_memory_client_operations(
501503
return
502504

503505
wrapper = MemoryOperationWrapper(telemetry_handler)
504-
wrapper.set_hooks(
506+
set_memory_hooks(
507+
wrapper,
505508
memory_before_hook=memory_before_hook,
506509
memory_after_hook=memory_after_hook,
507510
)

instrumentation-loongsuite/loongsuite-instrumentation-mem0/src/opentelemetry/instrumentation/mem0/internal/_wrapper.py

Lines changed: 24 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import inspect
77
import logging
88
from functools import wraps
9-
from typing import Any, Callable, Dict, Optional
9+
from typing import Any, Callable, Optional
1010

1111
from opentelemetry.instrumentation.mem0.config import (
1212
is_internal_phases_enabled,
@@ -26,6 +26,14 @@
2626
SemanticAttributes,
2727
SpanName,
2828
)
29+
from opentelemetry.instrumentation.mem0.types import (
30+
HookContext,
31+
InnerAfterHook,
32+
InnerBeforeHook,
33+
MemoryAfterHook,
34+
MemoryBeforeHook,
35+
safe_call_hook,
36+
)
2937
from opentelemetry.trace import (
3038
SpanKind,
3139
Status,
@@ -39,28 +47,6 @@
3947

4048
logger = logging.getLogger(__name__)
4149

42-
# Per-call hook context. Instrumentation only creates and passes it through.
43-
HookContext = Dict[str, Any]
44-
45-
# Hook types are intentionally kept loose here to avoid coupling this package's runtime
46-
# to optional extension modules / type-checking configuration. Hooks are pure pass-through.
47-
MemoryBeforeHook = Optional[Callable[..., Any]]
48-
MemoryAfterHook = Optional[Callable[..., Any]]
49-
InnerBeforeHook = Optional[Callable[..., Any]]
50-
InnerAfterHook = Optional[Callable[..., Any]]
51-
52-
53-
def _safe_call_hook(hook: Optional[Callable], *args: Any) -> None:
54-
"""
55-
Call a hook defensively: swallow hook exceptions to avoid breaking user code.
56-
"""
57-
if not callable(hook):
58-
return
59-
try:
60-
hook(*args)
61-
except Exception as e: # pragma: no cover - defensive
62-
logger.debug("mem0 hook raised and was swallowed: %s", e)
63-
6450

6551
def _get_field(payload: dict, field_name: str) -> Any:
6652
"""
@@ -180,20 +166,6 @@ def __init__(self, telemetry_handler: ExtendedTelemetryHandler):
180166
self._memory_before_hook: MemoryBeforeHook = None
181167
self._memory_after_hook: MemoryAfterHook = None
182168

183-
def set_hooks(
184-
self,
185-
*,
186-
memory_before_hook: MemoryBeforeHook = None,
187-
memory_after_hook: MemoryAfterHook = None,
188-
) -> None:
189-
"""
190-
Set optional hooks for top-level memory operations.
191-
192-
Hooks are stored on the wrapper instance to avoid changing wrapt wrapper signatures.
193-
"""
194-
self._memory_before_hook = memory_before_hook
195-
self._memory_after_hook = memory_after_hook
196-
197169
def wrap_operation(
198170
self,
199171
operation_name: str,
@@ -491,7 +463,7 @@ def _execute_with_handler(
491463
hook_context: HookContext = {}
492464
# Read current span after util handler starts memory (span should exist in most cases)
493465
span = get_current_span()
494-
_safe_call_hook(
466+
safe_call_hook(
495467
self._memory_before_hook,
496468
span,
497469
operation_name,
@@ -512,7 +484,7 @@ def _execute_with_handler(
512484
extract_attributes_func=extract_attributes_func,
513485
is_memory_client=is_memory_client,
514486
)
515-
_safe_call_hook(
487+
safe_call_hook(
516488
self._memory_after_hook,
517489
span,
518490
operation_name,
@@ -526,7 +498,7 @@ def _execute_with_handler(
526498
self.telemetry_handler.stop_memory(invocation)
527499
return result
528500
except Exception as e:
529-
_safe_call_hook(
501+
safe_call_hook(
530502
self._memory_after_hook,
531503
span,
532504
operation_name,
@@ -582,7 +554,7 @@ async def _execute_with_handler_async(
582554
self.telemetry_handler.start_memory(invocation)
583555
hook_context: HookContext = {}
584556
span = get_current_span()
585-
_safe_call_hook(
557+
safe_call_hook(
586558
self._memory_before_hook,
587559
span,
588560
operation_name,
@@ -603,7 +575,7 @@ async def _execute_with_handler_async(
603575
extract_attributes_func=extract_attributes_func,
604576
is_memory_client=is_memory_client,
605577
)
606-
_safe_call_hook(
578+
safe_call_hook(
607579
self._memory_after_hook,
608580
span,
609581
operation_name,
@@ -617,7 +589,7 @@ async def _execute_with_handler_async(
617589
self.telemetry_handler.stop_memory(invocation)
618590
return result
619591
except Exception as e:
620-
_safe_call_hook(
592+
safe_call_hook(
621593
self._memory_after_hook,
622594
span,
623595
operation_name,
@@ -686,7 +658,7 @@ def wrapper(
686658
) as span:
687659
result = None
688660
hook_context: HookContext = {}
689-
_safe_call_hook(
661+
safe_call_hook(
690662
self._inner_before_hook,
691663
span,
692664
"vector",
@@ -712,7 +684,7 @@ def wrapper(
712684
span.set_attribute(key, value)
713685

714686
span.set_status(Status(StatusCode.OK))
715-
_safe_call_hook(
687+
safe_call_hook(
716688
self._inner_after_hook,
717689
span,
718690
"vector",
@@ -732,7 +704,7 @@ def wrapper(
732704
span.set_attribute(
733705
SemanticAttributes.ERROR_TYPE, get_exception_type(e)
734706
)
735-
_safe_call_hook(
707+
safe_call_hook(
736708
self._inner_after_hook,
737709
span,
738710
"vector",
@@ -801,7 +773,7 @@ def wrapper(
801773
) as span:
802774
result = None
803775
hook_context: HookContext = {}
804-
_safe_call_hook(
776+
safe_call_hook(
805777
self._inner_before_hook,
806778
span,
807779
"graph",
@@ -827,7 +799,7 @@ def wrapper(
827799
span.set_attribute(key, value)
828800

829801
span.set_status(Status(StatusCode.OK))
830-
_safe_call_hook(
802+
safe_call_hook(
831803
self._inner_after_hook,
832804
span,
833805
"graph",
@@ -847,7 +819,7 @@ def wrapper(
847819
span.set_attribute(
848820
SemanticAttributes.ERROR_TYPE, get_exception_type(e)
849821
)
850-
_safe_call_hook(
822+
safe_call_hook(
851823
self._inner_after_hook,
852824
span,
853825
"graph",
@@ -919,7 +891,7 @@ def wrapper(
919891
kind=SpanKind.CLIENT,
920892
) as span:
921893
hook_context: HookContext = {}
922-
_safe_call_hook(
894+
safe_call_hook(
923895
self._inner_before_hook,
924896
span,
925897
"rerank",
@@ -944,7 +916,7 @@ def wrapper(
944916
result = wrapped(*args, **kwargs)
945917

946918
span.set_status(Status(StatusCode.OK))
947-
_safe_call_hook(
919+
safe_call_hook(
948920
self._inner_after_hook,
949921
span,
950922
"rerank",
@@ -964,7 +936,7 @@ def wrapper(
964936
span.set_attribute(
965937
SemanticAttributes.ERROR_TYPE, get_exception_type(e)
966938
)
967-
_safe_call_hook(
939+
safe_call_hook(
968940
self._inner_after_hook,
969941
span,
970942
"rerank",
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""
2+
Mem0 instrumentation public hook types and helpers.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
import logging
8+
from typing import Any, Callable, Dict, Optional
9+
10+
logger = logging.getLogger(__name__)
11+
12+
# Per-call hook context. Instrumentation only creates and passes it through.
13+
HookContext = Dict[str, Any]
14+
15+
# Hook callables are kept intentionally loose: the open-source package only passes through
16+
# parameters, and commercial extensions are responsible for extracting/recording data.
17+
MemoryBeforeHook = Optional[Callable[..., Any]]
18+
MemoryAfterHook = Optional[Callable[..., Any]]
19+
InnerBeforeHook = Optional[Callable[..., Any]]
20+
InnerAfterHook = Optional[Callable[..., Any]]
21+
22+
23+
def safe_call_hook(hook: Optional[Callable[..., Any]], *args: Any) -> None:
24+
"""
25+
Call a hook defensively: swallow hook exceptions to avoid breaking user code.
26+
"""
27+
if not callable(hook):
28+
return
29+
try:
30+
hook(*args)
31+
except Exception as e:
32+
logger.debug("mem0 hook raised and was swallowed: %s", e)
33+
34+
35+
def set_memory_hooks(
36+
wrapper: Any,
37+
*,
38+
memory_before_hook: MemoryBeforeHook = None,
39+
memory_after_hook: MemoryAfterHook = None,
40+
) -> None:
41+
"""
42+
Configure top-level memory hooks on a MemoryOperationWrapper instance.
43+
"""
44+
wrapper._memory_before_hook = memory_before_hook
45+
wrapper._memory_after_hook = memory_after_hook

instrumentation-loongsuite/loongsuite-instrumentation-mem0/tests/test_hooks.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
RerankerWrapper,
2323
VectorStoreWrapper,
2424
)
25+
from opentelemetry.instrumentation.mem0.types import set_memory_hooks
2526
from opentelemetry.sdk.trace import TracerProvider
2627
from opentelemetry.util.genai._extended_memory import MemoryInvocation
2728

@@ -56,7 +57,7 @@ def after(
5657
calls.append(("after", hook_context))
5758

5859
w = MemoryOperationWrapper(_DummyTelemetryHandler())
59-
w.set_hooks(memory_before_hook=before, memory_after_hook=after)
60+
set_memory_hooks(w, memory_before_hook=before, memory_after_hook=after)
6061

6162
# Cover helper: _normalize_call_parameters with positional args mapping
6263
def _sig(self, memory_id, data, *, user_id=None): # noqa: ARG001
@@ -165,7 +166,9 @@ def after_boom(
165166
):
166167
raise RuntimeError("boom2")
167168

168-
w.set_hooks(memory_before_hook=before_boom, memory_after_hook=after_boom)
169+
set_memory_hooks(
170+
w, memory_before_hook=before_boom, memory_after_hook=after_boom
171+
)
169172
assert (
170173
w._execute_with_handler(
171174
_fn_sync,
@@ -180,7 +183,7 @@ def after_boom(
180183
)
181184

182185
# exception path calls after_hook with exception
183-
w.set_hooks(memory_before_hook=before, memory_after_hook=after)
186+
set_memory_hooks(w, memory_before_hook=before, memory_after_hook=after)
184187

185188
def _fn_raises(*a, **k):
186189
raise ValueError("nope")

0 commit comments

Comments
 (0)