Skip to content

Commit 316864b

Browse files
CarltonXiang, harvey_xiang, CaralHsi, fridayL
authored
Feat/time status (#608)
* feat: timer add log args * feat: timer add log args * feat: timer add log args * feat: add openai model log * feat: add timed_with_status * feat: add openai model log * fix: conflict --------- Co-authored-by: harvey_xiang <[email protected]> Co-authored-by: CaralHsi <[email protected]> Co-authored-by: chunyu li <[email protected]>
1 parent 1727070 commit 316864b

File tree

4 files changed

+123
-84
lines changed

4 files changed

+123
-84
lines changed

src/memos/embedders/universal_api.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from memos.configs.embedder import UniversalAPIEmbedderConfig
55
from memos.embedders.base import BaseEmbedder
66
from memos.log import get_logger
7-
from memos.utils import timed
7+
from memos.utils import timed_with_status
88

99

1010
logger = get_logger(__name__)
@@ -30,8 +30,7 @@ def __init__(self, config: UniversalAPIEmbedderConfig):
3030
else:
3131
raise ValueError(f"Embeddings unsupported provider: {self.provider}")
3232

33-
@timed(
34-
log=True,
33+
@timed_with_status(
3534
log_prefix="model_timed_embedding",
3635
log_extra_args={"model_name_or_path": "text-embedding-3-large"},
3736
)

src/memos/llms/openai.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from memos.llms.utils import remove_thinking_tags
1313
from memos.log import get_logger
1414
from memos.types import MessageList
15-
from memos.utils import timed
15+
from memos.utils import timed_with_status
1616

1717

1818
logger = get_logger(__name__)
@@ -28,7 +28,7 @@ def __init__(self, config: OpenAILLMConfig):
2828
)
2929
logger.info("OpenAI LLM instance initialized")
3030

31-
@timed(log=True, log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
31+
@timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
3232
def generate(self, messages: MessageList, **kwargs) -> str:
3333
"""Generate a response from OpenAI LLM, optionally overriding generation params."""
3434
response = self.client.chat.completions.create(
@@ -55,7 +55,7 @@ def generate(self, messages: MessageList, **kwargs) -> str:
5555
return reasoning_content + response_content
5656
return response_content
5757

58-
@timed(log=True, log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
58+
@timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
5959
def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
6060
"""Stream response from OpenAI LLM with optional reasoning support."""
6161
if kwargs.get("tools"):

src/memos/reranker/http_bge.py

Lines changed: 51 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import requests
1010

1111
from memos.log import get_logger
12-
from memos.utils import timed
12+
from memos.utils import timed_with_status
1313

1414
from .base import BaseReranker
1515
from .concat import concat_original_source
@@ -119,8 +119,12 @@ def __init__(
119119
self.warn_unknown_filter_keys = bool(warn_unknown_filter_keys)
120120
self._warned_missing_keys: set[str] = set()
121121

122-
@timed(
123-
log=True, log_prefix="model_timed_rerank", log_extra_args={"model_name_or_path": "reranker"}
122+
@timed_with_status(
123+
log_prefix="model_timed_rerank",
124+
log_extra_args={"model_name_or_path": "reranker"},
125+
fallback=lambda exc, self, query, graph_results, top_k, *a, **kw: [
126+
(item, 0.0) for item in graph_results[:top_k]
127+
],
124128
)
125129
def rerank(
126130
self,
@@ -150,6 +154,7 @@ def rerank(
150154
list[tuple[TextualMemoryItem, float]]
151155
Re-ranked items with scores, sorted descending by score.
152156
"""
157+
153158
if not graph_results:
154159
return []
155160

@@ -173,63 +178,54 @@ def rerank(
173178
headers = {"Content-Type": "application/json", **self.headers_extra}
174179
payload = {"model": self.model, "query": query, "documents": documents}
175180

176-
try:
177-
# Make the HTTP request to the reranker service
178-
resp = requests.post(
179-
self.reranker_url, headers=headers, json=payload, timeout=self.timeout
180-
)
181-
resp.raise_for_status()
182-
data = resp.json()
183-
184-
scored_items: list[tuple[TextualMemoryItem, float]] = []
185-
186-
if "results" in data:
187-
# Format:
188-
# dict("results": [{"index": int, "relevance_score": float},
189-
# ...])
190-
rows = data.get("results", [])
191-
for r in rows:
192-
idx = r.get("index")
193-
# The returned index refers to 'documents' (i.e., our 'pairs' order),
194-
# so we must map it back to the original graph_results index.
195-
if isinstance(idx, int) and 0 <= idx < len(graph_results):
196-
raw_score = float(r.get("relevance_score", r.get("score", 0.0)))
197-
item = graph_results[idx]
198-
# generic boost
199-
score = self._apply_boost_generic(item, raw_score, search_priority)
200-
scored_items.append((item, score))
201-
202-
scored_items.sort(key=lambda x: x[1], reverse=True)
203-
return scored_items[: min(top_k, len(scored_items))]
204-
205-
elif "data" in data:
206-
# Format: {"data": [{"score": float}, ...]} aligned by list order
207-
rows = data.get("data", [])
208-
# Build a list of scores aligned with our 'documents' (pairs)
209-
score_list = [float(r.get("score", 0.0)) for r in rows]
210-
211-
if len(score_list) < len(graph_results):
212-
score_list += [0.0] * (len(graph_results) - len(score_list))
213-
elif len(score_list) > len(graph_results):
214-
score_list = score_list[: len(graph_results)]
215-
216-
scored_items = []
217-
for item, raw_score in zip(graph_results, score_list, strict=False):
181+
# Make the HTTP request to the reranker service
182+
resp = requests.post(self.reranker_url, headers=headers, json=payload, timeout=self.timeout)
183+
resp.raise_for_status()
184+
data = resp.json()
185+
186+
scored_items: list[tuple[TextualMemoryItem, float]] = []
187+
188+
if "results" in data:
189+
# Format:
190+
# dict("results": [{"index": int, "relevance_score": float},
191+
# ...])
192+
rows = data.get("results", [])
193+
for r in rows:
194+
idx = r.get("index")
195+
# The returned index refers to 'documents' (i.e., our 'pairs' order),
196+
# so we must map it back to the original graph_results index.
197+
if isinstance(idx, int) and 0 <= idx < len(graph_results):
198+
raw_score = float(r.get("relevance_score", r.get("score", 0.0)))
199+
item = graph_results[idx]
200+
# generic boost
218201
score = self._apply_boost_generic(item, raw_score, search_priority)
219202
scored_items.append((item, score))
220203

221-
scored_items.sort(key=lambda x: x[1], reverse=True)
222-
return scored_items[: min(top_k, len(scored_items))]
204+
scored_items.sort(key=lambda x: x[1], reverse=True)
205+
return scored_items[: min(top_k, len(scored_items))]
206+
207+
elif "data" in data:
208+
# Format: {"data": [{"score": float}, ...]} aligned by list order
209+
rows = data.get("data", [])
210+
# Build a list of scores aligned with our 'documents' (pairs)
211+
score_list = [float(r.get("score", 0.0)) for r in rows]
212+
213+
if len(score_list) < len(graph_results):
214+
score_list += [0.0] * (len(graph_results) - len(score_list))
215+
elif len(score_list) > len(graph_results):
216+
score_list = score_list[: len(graph_results)]
223217

224-
else:
225-
# Unexpected response schema: return a 0.0-scored fallback of the first top_k valid docs
226-
# Note: we use 'pairs' to keep alignment with valid (string) docs.
227-
return [(item, 0.0) for item in graph_results[:top_k]]
218+
scored_items = []
219+
for item, raw_score in zip(graph_results, score_list, strict=False):
220+
score = self._apply_boost_generic(item, raw_score, search_priority)
221+
scored_items.append((item, score))
228222

229-
except Exception as e:
230-
# Network error, timeout, JSON decode error, etc.
231-
# Degrade gracefully by returning first top_k valid docs with 0.0 score.
232-
logger.error(f"[HTTPBGEReranker] request failed: {e}")
223+
scored_items.sort(key=lambda x: x[1], reverse=True)
224+
return scored_items[: min(top_k, len(scored_items))]
225+
226+
else:
227+
# Unexpected response schema: return a 0.0-scored fallback of the first top_k valid docs
228+
# Note: we use 'pairs' to keep alignment with valid (string) docs.
233229
return [(item, 0.0) for item in graph_results[:top_k]]
234230

235231
def _get_attr_or_key(self, obj: Any, key: str) -> Any:

src/memos/utils.py

Lines changed: 67 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import time
23

34
from memos.log import get_logger
@@ -6,47 +7,90 @@
67
logger = get_logger(__name__)
78

89

9-
def timed(func=None, *, log=True, log_prefix="", log_args=None, log_extra_args=None):
10+
def timed_with_status(
11+
func=None,
12+
*,
13+
log_prefix="",
14+
log_args=None,
15+
log_extra_args=None,
16+
fallback=None,
17+
):
1018
"""
1119
Parameters:
1220
- log: enable timing logs (default True)
1321
- log_prefix: prefix; falls back to function name
1422
- log_args: names to include in logs (str or list/tuple of str).
15-
Value priority: kwargs → args[0].config.<name> (if available).
16-
Non-string items are ignored.
17-
18-
Examples:
19-
- @timed(log=True, log_prefix="OpenAI LLM", log_args=["model_name_or_path", "temperature"])
20-
- @timed(log=True, log_prefix="OpenAI LLM", log_args=["temperature"])
21-
- @timed() # defaults
23+
- log_extra_args: extra arguments to include in logs (dict).
2224
"""
2325

26+
if isinstance(log_args, str):
27+
effective_log_args = [log_args]
28+
else:
29+
effective_log_args = list(log_args) if log_args else []
30+
2431
def decorator(fn):
32+
@functools.wraps(fn)
2533
def wrapper(*args, **kwargs):
2634
start = time.perf_counter()
27-
result = fn(*args, **kwargs)
28-
elapsed_ms = (time.perf_counter() - start) * 1000.0
29-
ctx_str = ""
30-
ctx_parts = []
35+
exc_type = None
36+
result = None
37+
success_flag = False
3138

32-
if log is not True:
39+
try:
40+
result = fn(*args, **kwargs)
41+
success_flag = True
3342
return result
43+
except Exception as e:
44+
exc_type = type(e)
45+
success_flag = False
46+
47+
if fallback is not None and callable(fallback):
48+
result = fallback(e, *args, **kwargs)
49+
return result
50+
finally:
51+
elapsed_ms = (time.perf_counter() - start) * 1000.0
3452

35-
if log_args:
36-
for key in log_args:
53+
ctx_parts = []
54+
for key in effective_log_args:
3755
val = kwargs.get(key)
3856
ctx_parts.append(f"{key}={val}")
39-
ctx_str = f" [{', '.join(ctx_parts)}]"
4057

41-
if log_extra_args:
42-
ctx_parts.extend([f"{key}={val}" for key, val in log_extra_args.items()])
58+
if log_extra_args:
59+
ctx_parts.extend(f"{key}={val}" for key, val in log_extra_args.items())
60+
61+
ctx_str = f" [{', '.join(ctx_parts)}]" if ctx_parts else ""
62+
63+
status = "SUCCESS" if success_flag else "FAILED"
64+
status_info = f", status: {status}"
65+
66+
if not success_flag and exc_type is not None:
67+
status_info += f", error: {exc_type.__name__}"
68+
69+
msg = (
70+
f"[TIMER_WITH_STATUS] {log_prefix or fn.__name__} "
71+
f"took {elapsed_ms:.0f} ms{status_info}, args: {ctx_str}"
72+
)
73+
74+
logger.info(msg)
75+
76+
return wrapper
77+
78+
if func is None:
79+
return decorator
80+
return decorator(func)
4381

44-
if ctx_parts:
45-
ctx_str = f" [{', '.join(ctx_parts)}]"
4682

47-
logger.info(
48-
f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms, args: {ctx_str}"
49-
)
83+
def timed(func=None, *, log=True, log_prefix=""):
84+
def decorator(fn):
85+
def wrapper(*args, **kwargs):
86+
start = time.perf_counter()
87+
result = fn(*args, **kwargs)
88+
elapsed_ms = (time.perf_counter() - start) * 1000.0
89+
90+
if log is not True:
91+
return result
92+
93+
logger.info(f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms")
5094

5195
return result
5296

0 commit comments

Comments (0)