
Commit a51e1a3

feat:mos product api dev (#63)
## Description

Summary:
- API Development: added comprehensive product API functionality with a singleton pattern for memos and LLM integration, including user role management and persistent user handling.
- Bug Fixes & Optimization: resolved critical issues with UserRole errors, CI code, chat memory management, and suggestion memory functionality.
- Feature Enhancements: updated chat functionality, added UUID support, improved search data handling, and implemented bilingual support (Chinese/English) with memory optimization.

Fix: #(issue)
Reviewer: @(reviewer)

## Checklist:

- [ ] I have performed a self-review of my own code
- [ ] I have commented my code in hard-to-understand areas
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] I have linked the issue to this PR (if applicable)
- [ ] I have mentioned the person who will review this PR
2 parents: 5452668 + 87e4359

File tree: 13 files changed, +459 / -84 lines


poetry.lock

Lines changed: 17 additions & 0 deletions (generated file; diff not rendered)

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -26,6 +26,7 @@ fastapi = {extras = ["all"], version = "^0.115.12"}
 sentence-transformers = "^4.1.0"
 sqlalchemy = "^2.0.41"
 redis = "^6.2.0"
+pika = "^1.3.2"
 schedule = "^1.2.2"
 
 [tool.poetry.group.dev]
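
Note: `pika` is RabbitMQ's Python client; this PR only adds the dependency and the visible diffs do not show how memos uses it. A minimal, hypothetical publish sketch for orientation only (the broker address and queue name are illustrative, not taken from this PR):

```python
# Hypothetical sketch only: pika is the RabbitMQ client added above.
# The broker host and queue name are made up for illustration.
import pika

connection = pika.BlockingConnection(pika.ConnectionParameters(host="localhost"))
channel = connection.channel()
channel.queue_declare(queue="memos_demo")          # idempotent queue declaration
channel.basic_publish(exchange="", routing_key="memos_demo", body=b"hello memos")
connection.close()
```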

src/memos/api/config.py

Lines changed: 1 addition & 1 deletion

@@ -264,7 +264,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
             "user": neo4j_config["user"],
             "password": neo4j_config["password"],
             "db_name": os.getenv(
-                "NEO4J_DB_NAME", f"db{user_id.replace('-', '')}"
+                "NEO4J_DB_NAME", f"memos{user_id.replace('-', '')}"
             ),  # , replace with
             "auto_create": neo4j_config["auto_create"],
         },
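
For illustration, this is what the new default evaluates to when `NEO4J_DB_NAME` is unset; the UUID below is made up, only the `memos` prefix and the dash stripping come from the diff:

```python
import os

user_id = "3f2b8c1e-9d4a-4c5b-8e21-7a6f0d9c1b2e"  # example UUID, not from the PR
db_name = os.getenv("NEO4J_DB_NAME", f"memos{user_id.replace('-', '')}")
print(db_name)  # memos3f2b8c1e9d4a4c5b8e217a6f0d9c1b2e
```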

src/memos/api/product_api.py

Lines changed: 5 additions & 2 deletions

@@ -26,5 +26,8 @@
 
 if __name__ == "__main__":
     import uvicorn
-
-    uvicorn.run(app, host="0.0.0.0", port=8001)
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--port", type=int, default=8001)
+    args = parser.parse_args()
+    uvicorn.run(app, host="0.0.0.0", port=args.port)
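
A small sketch of what the new entrypoint flag does; the list passed to `parse_args` stands in for a command line such as `--port 8002`:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8001)

args = parser.parse_args(["--port", "8002"])  # simulates: python product_api.py --port 8002
print(args.port)                   # 8002
print(parser.parse_args([]).port)  # 8001 (the previous hard-coded default)
```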

src/memos/api/product_models.py

Lines changed: 7 additions & 0 deletions

@@ -150,3 +150,10 @@ class SearchRequest(BaseRequest):
     user_id: str = Field(..., description="User ID")
     query: str = Field(..., description="Search query")
     mem_cube_id: str | None = Field(None, description="Cube ID to search in")
+
+
+class SuggestionRequest(BaseRequest):
+    """Request model for getting suggestion queries."""
+
+    user_id: str = Field(..., description="User ID")
+    language: Literal["zh", "en"] = Field("zh", description="Language for suggestions")
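
A self-contained sketch of the request payload this model accepts; `BaseRequest` is simplified to pydantic's `BaseModel` here, and pydantic v2 is assumed (FastAPI ^0.115 ships with it):

```python
from typing import Literal

from pydantic import BaseModel, Field


class SuggestionRequest(BaseModel):  # BaseRequest simplified to BaseModel for this sketch
    """Request model for getting suggestion queries."""

    user_id: str = Field(..., description="User ID")
    language: Literal["zh", "en"] = Field("zh", description="Language for suggestions")


payload = SuggestionRequest(user_id="user-123")  # language defaults to "zh"
print(payload.model_dump())  # {'user_id': 'user-123', 'language': 'zh'}
```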

src/memos/api/routers/product_router.py

Lines changed: 28 additions & 4 deletions

@@ -15,6 +15,7 @@
     SearchRequest,
     SearchResponse,
     SimpleResponse,
+    SuggestionRequest,
     SuggestionResponse,
     UserRegisterRequest,
     UserRegisterResponse,
@@ -36,6 +37,7 @@ def get_mos_product_instance():
     global MOS_PRODUCT_INSTANCE
     if MOS_PRODUCT_INSTANCE is None:
         default_config = APIConfig.get_product_default_config()
+        print(default_config)
         from memos.configs.mem_os import MOSConfig
 
         mos_config = MOSConfig(**default_config)
@@ -85,7 +87,6 @@ async def register_user(user_req: UserRegisterRequest):
         logger.error(f"Failed to register user: {traceback.format_exc()}")
         raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
 
-
 @router.get(
     "/suggestions/{user_id}", summary="Get suggestion queries", response_model=SuggestionResponse
 )
@@ -104,6 +105,25 @@ async def get_suggestion_queries(user_id: str):
         raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
 
 
+@router.post("/suggestions", summary="Get suggestion queries with language", response_model=SuggestionResponse)
+async def get_suggestion_queries_post(suggestion_req: SuggestionRequest):
+    """Get suggestion queries for a specific user with language preference."""
+    try:
+        mos_product = get_mos_product_instance()
+        suggestions = mos_product.get_suggestion_query(
+            user_id=suggestion_req.user_id,
+            language=suggestion_req.language
+        )
+        return SuggestionResponse(
+            message="Suggestions retrieved successfully", data={"query": suggestions}
+        )
+    except ValueError as err:
+        raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
+    except Exception as err:
+        logger.error(f"Failed to get suggestions: {traceback.format_exc()}")
+        raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
+
+
 @router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse)
 async def get_all_memories(memory_req: GetMemoryRequest):
     """Get all memories for a specific user."""
@@ -177,15 +197,19 @@ async def chat(chat_req: ChatRequest):
     try:
         mos_product = get_mos_product_instance()
 
-        def generate_chat_response():
+        async def generate_chat_response():
             """Generate chat response as SSE stream."""
             try:
-                yield from mos_product.chat_with_references(
+                import asyncio
+
+                for chunk in mos_product.chat_with_references(
                     query=chat_req.query,
                     user_id=chat_req.user_id,
                     cube_id=chat_req.mem_cube_id,
                     history=chat_req.history,
-                )
+                ):
+                    yield chunk
+                    await asyncio.sleep(0.05)  # 50ms delay between chunks
             except Exception as e:
                 logger.error(f"Error in chat stream: {e}")
                 error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n"
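
The chat change wraps a synchronous chunk generator in an async generator and paces it with `asyncio.sleep`. A minimal stand-alone sketch of that pattern (the chunk source and SSE payloads below are placeholders, not the project's `chat_with_references`):

```python
import asyncio


def sync_chunks():
    """Placeholder for a synchronous streaming source such as chat_with_references."""
    yield 'data: {"type": "text", "content": "Hello"}\n\n'
    yield 'data: {"type": "text", "content": "world"}\n\n'


async def generate_chat_response():
    """Re-yield sync chunks from an async generator, pausing 50 ms between chunks."""
    for chunk in sync_chunks():
        yield chunk
        await asyncio.sleep(0.05)


async def main():
    async for chunk in generate_chat_response():
        print(chunk, end="")


asyncio.run(main())
```

The synchronous iteration still runs on the event loop thread; the `await asyncio.sleep(0.05)` between chunks is what hands control back to the loop so other requests are not starved for the duration of the stream.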

src/memos/llms/hf.py

Lines changed: 146 additions & 0 deletions

@@ -1,4 +1,5 @@
 import torch
+from collections.abc import Generator
 
 from transformers import (
     AutoModelForCausalLM,
@@ -71,6 +72,24 @@ def generate(self, messages: MessageList, past_key_values: DynamicCache | None =
         else:
             return self._generate_with_cache(prompt, past_key_values)
 
+    def generate_stream(self, messages: MessageList, past_key_values: DynamicCache | None = None) -> Generator[str, None, None]:
+        """
+        Generate a streaming response from the model.
+        Args:
+            messages (MessageList): Chat messages for prompt construction.
+            past_key_values (DynamicCache | None): Optional KV cache for fast generation.
+        Yields:
+            str: Streaming model response chunks.
+        """
+        prompt = self.tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=self.config.add_generation_prompt
+        )
+        logger.info(f"HFLLM streaming prompt: {prompt}")
+        if past_key_values is None:
+            yield from self._generate_full_stream(prompt)
+        else:
+            yield from self._generate_with_cache_stream(prompt, past_key_values)
+
     def _generate_full(self, prompt: str) -> str:
         """
         Generate output from scratch using the full prompt.
@@ -104,6 +123,71 @@ def _generate_full(self, prompt: str) -> str:
             else response
         )
 
+    def _generate_full_stream(self, prompt: str) -> Generator[str, None, None]:
+        """
+        Generate output from scratch using the full prompt with streaming.
+        Args:
+            prompt (str): The input prompt string.
+        Yields:
+            str: Streaming response chunks.
+        """
+        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)
+
+        # Get generation parameters
+        max_new_tokens = getattr(self.config, "max_tokens", 128)
+        do_sample = getattr(self.config, "do_sample", True)
+        remove_think_prefix = getattr(self.config, "remove_think_prefix", False)
+
+        # Manual streaming generation
+        input_length = inputs.input_ids.shape[1]
+        generated_ids = inputs.input_ids.clone()
+        accumulated_text = ""
+
+        for _ in range(max_new_tokens):
+            # Forward pass
+            with torch.no_grad():
+                outputs = self.model(
+                    input_ids=generated_ids,
+                    use_cache=True,
+                    return_dict=True,
+                )
+
+            # Get next token logits
+            next_token_logits = outputs.logits[:, -1, :]
+
+            # Apply logits processors if sampling
+            if do_sample:
+                batch_size, _ = next_token_logits.size()
+                dummy_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=next_token_logits.device)
+                filtered_logits = self.logits_processors(dummy_ids, next_token_logits)
+                probs = torch.softmax(filtered_logits, dim=-1)
+                next_token = torch.multinomial(probs, num_samples=1)
+            else:
+                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+
+            # Check for EOS token
+            if self._should_stop(next_token):
+                break
+
+            # Append new token
+            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
+
+            # Decode and yield the new token
+            new_token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
+            if new_token_text:  # Only yield non-empty tokens
+                accumulated_text += new_token_text
+
+                # Apply thinking tag removal if enabled
+                if remove_think_prefix:
+                    processed_text = remove_thinking_tags(accumulated_text)
+                    # Only yield the difference (new content)
+                    if len(processed_text) > len(accumulated_text) - len(new_token_text):
+                        yield processed_text[len(accumulated_text) - len(new_token_text):]
+                    else:
+                        yield new_token_text
+                else:
+                    yield new_token_text
+
     def _generate_with_cache(self, query: str, kv: DynamicCache) -> str:
         """
         Generate output incrementally using an existing KV cache.
@@ -137,6 +221,68 @@ def _generate_with_cache(self, query: str, kv: DynamicCache) -> str:
             else response
         )
 
+    def _generate_with_cache_stream(self, query: str, kv: DynamicCache) -> Generator[str, None, None]:
+        """
+        Generate output incrementally using an existing KV cache with streaming.
+        Args:
+            query (str): The new user query string.
+            kv (DynamicCache): The prefilled KV cache.
+        Yields:
+            str: Streaming response chunks.
+        """
+        query_ids = self.tokenizer(
+            query, return_tensors="pt", add_special_tokens=False
+        ).input_ids.to(self.model.device)
+
+        max_new_tokens = getattr(self.config, "max_tokens", 128)
+        do_sample = getattr(self.config, "do_sample", True)
+        remove_think_prefix = getattr(self.config, "remove_think_prefix", False)
+
+        # Initial forward pass
+        logits, kv = self._prefill(query_ids, kv)
+        next_token = self._select_next_token(logits)
+
+        # Yield first token
+        first_token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
+        accumulated_text = ""
+        if first_token_text:
+            accumulated_text += first_token_text
+            if remove_think_prefix:
+                processed_text = remove_thinking_tags(accumulated_text)
+                if len(processed_text) > len(accumulated_text) - len(first_token_text):
+                    yield processed_text[len(accumulated_text) - len(first_token_text):]
+                else:
+                    yield first_token_text
+            else:
+                yield first_token_text
+
+        generated = [next_token]
+
+        # Continue generation
+        for _ in range(max_new_tokens - 1):
+            if self._should_stop(next_token):
+                break
+            logits, kv = self._prefill(next_token, kv)
+            next_token = self._select_next_token(logits)
+
+            # Decode and yield the new token
+            new_token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
+            if new_token_text:
+                accumulated_text += new_token_text
+
+                # Apply thinking tag removal if enabled
+                if remove_think_prefix:
+                    processed_text = remove_thinking_tags(accumulated_text)
+                    # Only yield the difference (new content)
+                    if len(processed_text) > len(accumulated_text) - len(new_token_text):
+                        yield processed_text[len(accumulated_text) - len(new_token_text):]
+                    else:
+                        yield new_token_text
+                else:
+                    yield new_token_text
+
+            generated.append(next_token)
+
     @torch.no_grad()
     def _prefill(
         self, input_ids: torch.Tensor, kv: DynamicCache
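
Both streaming methods accumulate decoded text and, when `remove_think_prefix` is enabled, try to emit only the newly visible suffix after stripping thinking tags. A simplified, self-contained sketch of that idea (the `remove_thinking_tags` helper below is a stand-in for the project's utility, and the bookkeeping is deliberately simpler than the diff's):

```python
import re


def remove_thinking_tags(text: str) -> str:
    """Stand-in for the project's helper: drop completed <think>...</think> blocks."""
    return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL)


def stream_visible(token_texts, remove_think_prefix=True):
    """Yield only the not-yet-emitted part of the visible (tag-stripped) text."""
    accumulated = ""
    emitted = 0  # length of visible text already yielded
    for tok in token_texts:
        accumulated += tok
        visible = remove_thinking_tags(accumulated) if remove_think_prefix else accumulated
        if len(visible) > emitted:
            yield visible[emitted:]
            emitted = len(visible)


chunks = ["<think>plan the answer</think>", "Hel", "lo", " there"]
print("".join(stream_visible(chunks)))  # Hello there
```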
