Skip to content

Commit 1f645a4

Browse files
author
harvey_xiang
committed
feat: add model_name
1 parent e513763 commit 1f645a4

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

src/memos/llms/openai.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@ def __init__(self, config: OpenAILLMConfig):
2828
)
2929
logger.info("OpenAI LLM instance initialized")
3030

31-
@timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
31+
@timed_with_status(
32+
log_prefix="OpenAI LLM",
33+
log_extra_args=lambda self, messages, **kwargs: {
34+
"model_name_or_path": kwargs.get("model_name_or_path", self.config.model_name_or_path)
35+
},
36+
)
3237
def generate(self, messages: MessageList, **kwargs) -> str:
3338
"""Generate a response from OpenAI LLM, optionally overriding generation params."""
34-
logger.info(
35-
f"LLM Model: {self.config.model_name_or_path} {kwargs.get('model_name_or_path')}"
36-
)
3739
response = self.client.chat.completions.create(
3840
model=kwargs.get("model_name_or_path", self.config.model_name_or_path),
3941
messages=messages,
@@ -58,7 +60,12 @@ def generate(self, messages: MessageList, **kwargs) -> str:
5860
return reasoning_content + response_content
5961
return response_content
6062

61-
@timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
63+
@timed_with_status(
64+
log_prefix="OpenAI LLM",
65+
log_extra_args=lambda self, messages, **kwargs: {
66+
"model_name_or_path": self.config.model_name_or_path
67+
},
68+
)
6269
def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
6370
"""Stream response from OpenAI LLM with optional reasoning support."""
6471
if kwargs.get("tools"):

src/memos/utils.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ def timed_with_status(
1919
Parameters:
2020
- log: enable timing logs (default True)
2121
- log_prefix: prefix; falls back to function name
22-
- log_args: names to include in logs (str or list/tuple of str).
23-
- log_extra_args: extra arguments to include in logs (dict).
22+
- log_args: names to include in logs (str or list/tuple of str), values are taken from kwargs by name.
23+
- log_extra_args:
24+
- can be a dict: fixed contextual fields that are always attached to logs;
25+
- or a callable: like `fn(*args, **kwargs) -> dict`, used to dynamically generate contextual fields at runtime.
2426
"""
2527

2628
if isinstance(log_args, str):
@@ -51,12 +53,24 @@ def wrapper(*args, **kwargs):
5153
elapsed_ms = (time.perf_counter() - start) * 1000.0
5254

5355
ctx_parts = []
56+
# 1) Collect parameters from kwargs by name
5457
for key in effective_log_args:
5558
val = kwargs.get(key)
5659
ctx_parts.append(f"{key}={val}")
5760

58-
if log_extra_args:
59-
ctx_parts.extend(f"{key}={val}" for key, val in log_extra_args.items())
61+
# 2) Support log_extra_args as dict or callable, so we can dynamically
62+
# extract values from self or other runtime context
63+
extra_items = {}
64+
try:
65+
if callable(log_extra_args):
66+
extra_items = log_extra_args(*args, **kwargs) or {}
67+
elif isinstance(log_extra_args, dict):
68+
extra_items = log_extra_args
69+
except Exception as e:
70+
logger.warning(f"[TIMER_WITH_STATUS] log_extra_args callback error: {e!r}")
71+
72+
if extra_items:
73+
ctx_parts.extend(f"{key}={val}" for key, val in extra_items.items())
6074

6175
ctx_str = f" [{', '.join(ctx_parts)}]" if ctx_parts else ""
6276

0 commit comments

Comments
 (0)