Skip to content

Commit afe4477

Browse files
authored
Python: generate-title + conversations count (#5302 Day 3) (#5382)
## Summary Sub-PR A for #5302 Day 3: Python backend endpoints needed by desktop app. **Endpoints added:** - `POST /v2/chat/generate-title` — LLM-powered chat session title generation with graceful fallback - `GET /v1/conversations/count` — Firestore aggregation count with stream fallback **Database layer:** - `count_conversations()` — Firestore count() aggregation - `stream_conversations()` — generator for unbounded counting fallback ## Test plan - 23 unit tests across 2 test files (12 generate-title + 11 conversations-count) - Both files in `test.sh` - Coverage: success paths, auth, LLM fallback, boundary truncation (50/100/500 chars), status parsing/normalization, aggregation fallback parity, validation (empty, too many statuses) - All 23 tests passing ## Review cycle - CP7 reviewer: approved (3 rounds — fixed mutable default arg, statuses validation, stream fallback) - CP8 tester: approved (4 boundary tests added for coverage gaps) Closes part of #5302
2 parents c693ecc + 52213d8 commit afe4477

File tree

6 files changed

+446
-0
lines changed

6 files changed

+446
-0
lines changed

backend/database/conversations.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,30 @@ def get_conversations(
219219
return conversations
220220

221221

222+
def count_conversations(uid: str, statuses: Optional[List[str]] = None) -> int:
223+
"""Count conversations matching status filters without fetching full documents."""
224+
if statuses is None:
225+
statuses = []
226+
conversations_ref = db.collection('users').document(uid).collection(conversations_collection)
227+
conversations_ref = conversations_ref.where(filter=FieldFilter('discarded', '==', False))
228+
if statuses:
229+
conversations_ref = conversations_ref.where(filter=FieldFilter('status', 'in', statuses))
230+
count_query = conversations_ref.count()
231+
results = count_query.get()
232+
return results[0][0].value
233+
234+
235+
def stream_conversations(uid: str, statuses: Optional[List[str]] = None):
236+
"""Yield conversation docs as a stream for counting without loading all into memory."""
237+
if statuses is None:
238+
statuses = []
239+
conversations_ref = db.collection('users').document(uid).collection(conversations_collection)
240+
conversations_ref = conversations_ref.where(filter=FieldFilter('discarded', '==', False))
241+
if statuses:
242+
conversations_ref = conversations_ref.where(filter=FieldFilter('status', 'in', statuses))
243+
yield from conversations_ref.stream()
244+
245+
222246
@prepare_for_read(decrypt_func=_prepare_conversation_for_read)
223247
def get_conversations_without_photos(
224248
uid: str,

backend/routers/chat.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
resolve_voice_message_language,
3434
transcribe_voice_message_segment,
3535
)
36+
from utils.llm.clients import llm_mini
3637
from utils.llm.persona import initial_persona_chat_message
3738
from utils.llm.chat import initial_chat_message
3839
from utils.llm.goals import extract_and_update_goal_progress
@@ -695,6 +696,50 @@ def rate_message(
695696
return StatusResponse(status='ok')
696697

697698

699+
class TitleMessageInput(BaseModel):
700+
text: str
701+
sender: str
702+
703+
704+
class GenerateTitleRequest(BaseModel):
705+
session_id: str
706+
messages: List[TitleMessageInput]
707+
708+
709+
class GenerateTitleResponse(BaseModel):
710+
title: str
711+
712+
713+
@router.post('/v2/chat/generate-title', response_model=GenerateTitleResponse, tags=['chat'])
714+
def generate_chat_title(
715+
request: GenerateTitleRequest,
716+
uid: str = Depends(auth.get_current_user_uid),
717+
):
718+
"""Desktop: generate a short title for a chat session from its messages."""
719+
if not request.messages:
720+
raise HTTPException(status_code=400, detail="messages list cannot be empty")
721+
722+
transcript = '\n'.join(f'{m.sender}: {m.text[:500]}' for m in request.messages[:10])
723+
prompt = (
724+
'Generate a short chat session title (max 6 words) summarising this conversation. '
725+
'Return ONLY the title text, no quotes or punctuation.\n\n' + transcript
726+
)
727+
try:
728+
result = llm_mini.invoke(prompt)
729+
title = result.content.strip().strip('"\'')[:100]
730+
except Exception as e:
731+
logger.warning(f'generate_chat_title LLM failed: {e}')
732+
title = request.messages[0].text[:50]
733+
734+
# Update session title if session exists
735+
try:
736+
chat_db.update_chat_session(uid, request.session_id, {'title': title, 'updated_at': datetime.now(timezone.utc)})
737+
except Exception as e:
738+
logger.warning(f'generate_chat_title update session failed: {e}')
739+
740+
return GenerateTitleResponse(title=title)
741+
742+
698743
# CLEANUP: Remove after new app goes to prod ----------------------------------------------------------
699744

700745

backend/routers/conversations.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,23 @@ def get_conversations(
249249
return conversations
250250

251251

252+
@router.get('/v1/conversations/count', tags=['conversations'])
253+
def get_conversations_count(
254+
statuses: Optional[str] = Query("processing,completed"),
255+
uid: str = Depends(auth.get_current_user_uid),
256+
):
257+
"""Count conversations matching optional status filters."""
258+
status_list = [s.strip() for s in statuses.split(',') if s.strip()] if statuses else []
259+
if len(status_list) > 10:
260+
raise HTTPException(status_code=400, detail="Too many status values (max 10)")
261+
try:
262+
count = conversations_db.count_conversations(uid, statuses=status_list)
263+
except Exception as e:
264+
logger.warning(f'count_conversations aggregation fallback: {e}')
265+
count = sum(1 for _ in conversations_db.stream_conversations(uid, statuses=status_list))
266+
return {'count': count}
267+
268+
252269
@router.get("/v1/conversations/{conversation_id}", response_model=Conversation, tags=['conversations'])
253270
def get_conversation_by_id(conversation_id: str, uid: str = Depends(auth.get_current_user_uid)):
254271
logger.info(f'get_conversation_by_id {uid} {conversation_id}')

backend/test.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,5 @@ pytest tests/unit/test_assistant_settings_ai_profile.py -v
4242
pytest tests/unit/test_focus_sessions.py -v
4343
pytest tests/unit/test_advice.py -v
4444
pytest tests/unit/test_staged_tasks.py -v
45+
pytest tests/unit/test_chat_generate_title.py -v
46+
pytest tests/unit/test_conversations_count.py -v
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
import sys
2+
from datetime import datetime, timezone
3+
from unittest.mock import patch, MagicMock
4+
5+
import pytest
6+
7+
for mod_name in [
8+
'firebase_admin',
9+
'firebase_admin.auth',
10+
'firebase_admin.firestore',
11+
'firebase_admin.messaging',
12+
'google.cloud',
13+
'google.cloud.exceptions',
14+
'google.cloud.firestore',
15+
'google.cloud.firestore_v1',
16+
'google.cloud.firestore_v1.base_query',
17+
'google.cloud.firestore_v1.query',
18+
'google.cloud.storage',
19+
'google.cloud.storage.blob',
20+
'google.cloud.storage.bucket',
21+
'google.auth',
22+
'google.auth.transport',
23+
'google.auth.transport.requests',
24+
'google.oauth2',
25+
'google.oauth2.service_account',
26+
'pinecone',
27+
'typesense',
28+
'openai',
29+
'langchain_openai',
30+
]:
31+
sys.modules.setdefault(mod_name, MagicMock())
32+
33+
# Mock llm_mini before importing the router
34+
mock_llm = MagicMock()
35+
mock_llm.invoke.return_value = MagicMock(content='Project Discussion')
36+
sys.modules.setdefault('utils.llm.clients', MagicMock(llm_mini=mock_llm))
37+
38+
from routers.chat import router
39+
40+
from fastapi import FastAPI, HTTPException
41+
from fastapi.testclient import TestClient
42+
43+
44+
@pytest.fixture
45+
def client():
46+
with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'):
47+
app = FastAPI()
48+
app.include_router(router)
49+
with TestClient(app) as c:
50+
yield c
51+
52+
53+
@pytest.fixture
54+
def client_no_auth():
55+
app = FastAPI()
56+
app.include_router(router)
57+
with TestClient(app) as c:
58+
yield c
59+
60+
61+
AUTH = {"Authorization": "Bearer 123testuser"}
62+
63+
64+
class TestGenerateChatTitle:
65+
def test_generate_title_success(self, client):
66+
data = {
67+
"session_id": "sess-1",
68+
"messages": [
69+
{"text": "How do I deploy to production?", "sender": "human"},
70+
{"text": "You can use the CI/CD pipeline.", "sender": "ai"},
71+
],
72+
}
73+
with patch('routers.chat.llm_mini') as mock_llm:
74+
mock_llm.invoke.return_value = MagicMock(content='Production Deployment')
75+
with patch('routers.chat.chat_db.update_chat_session'):
76+
resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
77+
assert resp.status_code == 200
78+
assert resp.json()["title"] == "Production Deployment"
79+
80+
def test_generate_title_strips_quotes(self, client):
81+
data = {
82+
"session_id": "sess-1",
83+
"messages": [{"text": "Hello", "sender": "human"}],
84+
}
85+
with patch('routers.chat.llm_mini') as mock_llm:
86+
mock_llm.invoke.return_value = MagicMock(content='"Greeting Chat"')
87+
with patch('routers.chat.chat_db.update_chat_session'):
88+
resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
89+
assert resp.status_code == 200
90+
assert resp.json()["title"] == "Greeting Chat"
91+
92+
def test_generate_title_empty_messages_returns_400(self, client):
93+
data = {"session_id": "sess-1", "messages": []}
94+
resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
95+
assert resp.status_code == 400
96+
97+
def test_generate_title_no_messages_field_returns_422(self, client):
98+
data = {"session_id": "sess-1"}
99+
resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
100+
assert resp.status_code == 422
101+
102+
def test_generate_title_llm_fallback(self, client):
103+
data = {
104+
"session_id": "sess-1",
105+
"messages": [{"text": "What about the budget proposal?", "sender": "human"}],
106+
}
107+
with patch('routers.chat.llm_mini') as mock_llm:
108+
mock_llm.invoke.side_effect = Exception("LLM down")
109+
with patch('routers.chat.chat_db.update_chat_session'):
110+
resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
111+
assert resp.status_code == 200
112+
assert resp.json()["title"] == "What about the budget proposal?"
113+
114+
def test_generate_title_updates_session(self, client):
115+
data = {
116+
"session_id": "sess-1",
117+
"messages": [{"text": "Hello", "sender": "human"}],
118+
}
119+
with patch('routers.chat.llm_mini') as mock_llm:
120+
mock_llm.invoke.return_value = MagicMock(content='Greeting')
121+
with patch('routers.chat.chat_db.update_chat_session') as mock_update:
122+
resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
123+
assert resp.status_code == 200
124+
mock_update.assert_called_once()
125+
call_args = mock_update.call_args[0]
126+
assert call_args[1] == 'sess-1'
127+
assert call_args[2]['title'] == 'Greeting'
128+
129+
def test_generate_title_session_update_failure_still_returns(self, client):
130+
data = {
131+
"session_id": "sess-1",
132+
"messages": [{"text": "Hello", "sender": "human"}],
133+
}
134+
with patch('routers.chat.llm_mini') as mock_llm:
135+
mock_llm.invoke.return_value = MagicMock(content='Greeting')
136+
with patch('routers.chat.chat_db.update_chat_session', side_effect=Exception("DB err")):
137+
resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
138+
assert resp.status_code == 200
139+
assert resp.json()["title"] == "Greeting"
140+
141+
def test_generate_title_truncates_long_title(self, client):
142+
data = {
143+
"session_id": "sess-1",
144+
"messages": [{"text": "Hello", "sender": "human"}],
145+
}
146+
with patch('routers.chat.llm_mini') as mock_llm:
147+
mock_llm.invoke.return_value = MagicMock(content='A' * 200)
148+
with patch('routers.chat.chat_db.update_chat_session'):
149+
resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
150+
assert resp.status_code == 200
151+
assert len(resp.json()["title"]) <= 100
152+
153+
def test_generate_title_no_auth_returns_401(self, client_no_auth):
154+
data = {
155+
"session_id": "sess-1",
156+
"messages": [{"text": "Hello", "sender": "human"}],
157+
}
158+
with patch(
159+
'routers.chat.auth.get_current_user_uid',
160+
side_effect=HTTPException(status_code=401, detail='Not authenticated'),
161+
):
162+
resp = client_no_auth.post("/v2/chat/generate-title", json=data)
163+
assert resp.status_code == 401
164+
165+
def test_generate_title_limits_messages(self, client):
166+
"""Only first 10 messages should be sent to LLM."""
167+
data = {
168+
"session_id": "sess-1",
169+
"messages": [{"text": f"Message {i}", "sender": "human"} for i in range(20)],
170+
}
171+
with patch('routers.chat.llm_mini') as mock_llm:
172+
mock_llm.invoke.return_value = MagicMock(content='Long Chat')
173+
with patch('routers.chat.chat_db.update_chat_session'):
174+
resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
175+
assert resp.status_code == 200
176+
prompt = mock_llm.invoke.call_args[0][0]
177+
assert 'Message 9' in prompt
178+
assert 'Message 10' not in prompt
179+
180+
def test_generate_title_fallback_truncates_to_50_chars(self, client):
181+
"""When LLM fails, fallback title is truncated to 50 chars."""
182+
long_text = 'A' * 100
183+
data = {
184+
"session_id": "sess-1",
185+
"messages": [{"text": long_text, "sender": "human"}],
186+
}
187+
with patch('routers.chat.llm_mini') as mock_llm:
188+
mock_llm.invoke.side_effect = Exception("LLM down")
189+
with patch('routers.chat.chat_db.update_chat_session'):
190+
resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
191+
assert resp.status_code == 200
192+
assert len(resp.json()["title"]) == 50
193+
194+
def test_generate_title_truncates_message_text_to_500_chars(self, client):
195+
"""Each message text is truncated to 500 chars in the transcript sent to LLM."""
196+
long_text = 'B' * 1000
197+
data = {
198+
"session_id": "sess-1",
199+
"messages": [{"text": long_text, "sender": "human"}],
200+
}
201+
with patch('routers.chat.llm_mini') as mock_llm:
202+
mock_llm.invoke.return_value = MagicMock(content='Title')
203+
with patch('routers.chat.chat_db.update_chat_session'):
204+
resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
205+
assert resp.status_code == 200
206+
prompt = mock_llm.invoke.call_args[0][0]
207+
# The transcript line should contain exactly 500 B's, not 1000
208+
assert 'B' * 500 in prompt
209+
assert 'B' * 501 not in prompt

0 commit comments

Comments
 (0)