Skip to content

Commit 623da17

Browse files
author
harvey_xiang
committed
feat: timer add log args
1 parent ac9af5f commit 623da17

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

src/memos/llms/openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, config: OpenAILLMConfig):
2828
)
2929
logger.info("OpenAI LLM instance initialized")
3030

31-
@timed(log=True, log_prefix="OpenAI LLM")
31+
@timed(log=True, log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
3232
def generate(self, messages: MessageList, **kwargs) -> str:
3333
"""Generate a response from OpenAI LLM, optionally overriding generation params."""
3434
response = self.client.chat.completions.create(

src/memos/utils.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import time
23

34
from memos.log import get_logger
@@ -6,20 +7,42 @@
67
logger = get_logger(__name__)
78

89

9-
def timed(func=None, *, log=True, log_prefix=""):
10-
"""Decorator to measure and optionally log time of retrieval steps.
11-
12-
Can be used as @timed or @timed(log=True)
10+
def timed(func=None, *, log=True, log_prefix="", log_args=None):
11+
"""
12+
Parameters:
13+
- log: enable timing logs (default True)
14+
- log_prefix: prefix; falls back to function name
15+
- log_args: names to include in logs (str or list/tuple of str).
16+
Value priority: kwargs → args[0].config.<name> (if available).
17+
Non-string items are ignored.
18+
19+
Examples:
20+
- @timed(log=True, log_prefix="OpenAI LLM", log_args=["model_name_or_path", "temperature"])
21+
- @timed(log=True, log_prefix="OpenAI LLM", log_args=["temperature"])
22+
- @timed() # defaults
1323
"""
1424

1525
def decorator(fn):
26+
@functools.wraps(fn)
1627
def wrapper(*args, **kwargs):
1728
start = time.perf_counter()
1829
result = fn(*args, **kwargs)
19-
elapsed = time.perf_counter() - start
20-
elapsed_ms = elapsed * 1000.0
21-
if log:
22-
logger.info(f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms")
30+
elapsed_ms = (time.perf_counter() - start) * 1000.0
31+
ctx_str = ""
32+
33+
if log is not True:
34+
return result
35+
36+
if log_args:
37+
ctx_parts = []
38+
for key in log_args:
39+
val = kwargs.get(key)
40+
ctx_parts.append(f"{key}={val}")
41+
ctx_str = f" [{', '.join(ctx_parts)}]"
42+
logger.info(
43+
f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms, args: {ctx_str}"
44+
)
45+
2346
return result
2447

2548
return wrapper

0 commit comments

Comments
 (0)