diff --git a/backend/database/advice.py b/backend/database/advice.py
new file mode 100644
index 0000000000..5043c9cb72
--- /dev/null
+++ b/backend/database/advice.py
@@ -0,0 +1,115 @@
+import logging
+import uuid
+from datetime import datetime, timezone
+from typing import List, Dict, Any, Optional
+
+from google.cloud import firestore
+
+from ._client import db
+
+logger = logging.getLogger(__name__)
+
+USERS_COLLECTION = 'users'
+ADVICE_SUBCOLLECTION = 'advice'
+
+
+def _collection_ref(uid: str):
+ return db.collection(USERS_COLLECTION).document(uid).collection(ADVICE_SUBCOLLECTION)
+
+
+def create_advice(uid: str, data: Dict[str, Any]) -> Dict[str, Any]:
+ """Create a new advice document. Returns the created document with id."""
+ advice_id = str(uuid.uuid4())
+ now = datetime.now(timezone.utc)
+
+ doc_data = {
+ 'content': data['content'],
+ 'category': data.get('category', 'other'),
+ 'confidence': data.get('confidence', 0.5),
+ 'is_read': False,
+ 'is_dismissed': False,
+ 'created_at': now,
+ }
+ for optional_field in ('reasoning', 'source_app', 'context_summary', 'current_activity'):
+ if data.get(optional_field) is not None:
+ doc_data[optional_field] = data[optional_field]
+
+ _collection_ref(uid).document(advice_id).set(doc_data)
+
+ doc_data['id'] = advice_id
+ return doc_data
+
+
+def get_advice(
+ uid: str,
+ limit: int = 100,
+ offset: int = 0,
+ category: Optional[str] = None,
+ include_dismissed: bool = False,
+) -> List[Dict[str, Any]]:
+ """Query advice, ordered by created_at DESC."""
+ query = _collection_ref(uid).order_by('created_at', direction=firestore.Query.DESCENDING)
+
+ if not include_dismissed:
+ query = query.where(filter=firestore.FieldFilter('is_dismissed', '==', False))
+ if category:
+ query = query.where(filter=firestore.FieldFilter('category', '==', category))
+
+ query = query.offset(offset).limit(limit)
+
+ results = []
+ for doc in query.stream():
+ data = doc.to_dict()
+ data['id'] = doc.id
+ results.append(data)
+ return results
+
+
+def update_advice(uid: str, advice_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+ """Update an advice document (is_read, is_dismissed). Returns updated doc."""
+ doc_ref = _collection_ref(uid).document(advice_id)
+
+ update_data = {'updated_at': datetime.now(timezone.utc)}
+ if 'is_read' in data:
+ update_data['is_read'] = data['is_read']
+ if 'is_dismissed' in data:
+ update_data['is_dismissed'] = data['is_dismissed']
+
+ try:
+ doc_ref.update(update_data)
+ except Exception as e:
+ if hasattr(e, 'code') and e.code == 404:
+ return None
+ raise
+
+ doc = doc_ref.get()
+ if doc.exists:
+ result = doc.to_dict()
+ result['id'] = doc.id
+ return result
+ return None
+
+
+def delete_advice(uid: str, advice_id: str) -> bool:
+ """Delete an advice document. Returns True on success."""
+ _collection_ref(uid).document(advice_id).delete()
+ return True
+
+
+def mark_all_advice_read(uid: str) -> int:
+ """Mark all unread, non-dismissed advice as read. Returns count of marked items."""
+ query = _collection_ref(uid).where(
+ filter=firestore.FieldFilter('is_dismissed', '==', False)
+ ).where(
+ filter=firestore.FieldFilter('is_read', '==', False)
+ ).limit(1000)
+
+ count = 0
+ now = datetime.now(timezone.utc)
+ for doc in query.stream():
+ try:
+ doc.reference.update({'is_read': True, 'updated_at': now})
+ count += 1
+ except Exception:
+ logger.warning('Failed to mark advice %s as read for uid=%s', doc.id, uid)
+ return count
diff --git a/backend/database/chat.py b/backend/database/chat.py
index 68ebcf51fc..0cd1b887a2 100644
--- a/backend/database/chat.py
+++ b/backend/database/chat.py
@@ -468,10 +468,69 @@ def delete_chat_session(uid, chat_session_id):
session_ref.delete()
-def add_message_to_chat_session(uid: str, chat_session_id: str, message_id: str):
+def get_chat_sessions(
+ uid: str, app_id: Optional[str] = None, limit: int = 50, offset: int = 0, starred: Optional[bool] = None
+):
+ """List chat sessions with optional filters.
+
+ Note: Client-side sort + slice because Firestore composite indexes would be
+ needed for every filter combination. Acceptable for desktop users (low session
+ counts). Revisit with server-side ordering if session counts grow large.
+ """
+ sessions_ref = db.collection('users').document(uid).collection('chat_sessions')
+ if app_id is not None:
+ sessions_ref = sessions_ref.where(filter=FieldFilter('plugin_id', '==', app_id))
+ if starred is not None:
+ sessions_ref = sessions_ref.where(filter=FieldFilter('starred', '==', starred))
+ sessions = [doc.to_dict() for doc in sessions_ref.stream()]
+ sessions.sort(key=lambda s: s.get('updated_at', s.get('created_at', datetime.min)), reverse=True)
+ return sessions[offset : offset + limit]
+
+
+def update_chat_session(uid: str, chat_session_id: str, update_data: dict):
+ """Partial update of a chat session."""
+ user_ref = db.collection('users').document(uid)
+ session_ref = user_ref.collection('chat_sessions').document(chat_session_id)
+ session_ref.update(update_data)
+
+
+@set_data_protection_level(data_arg_name='message_data')
+@prepare_for_write(data_arg_name='message_data', prepare_func=_prepare_data_for_write)
+def save_message(uid: str, message_data: dict):
+ """Save a message directly by document ID (for desktop CRUD)."""
+ user_ref = db.collection('users').document(uid)
+ user_ref.collection('messages').document(message_data['id']).set(message_data)
+ return message_data
+
+
+def delete_chat_session_messages(uid: str, chat_session_id: str):
+ """Delete all messages belonging to a chat session."""
+ user_ref = db.collection('users').document(uid)
+ messages_ref = user_ref.collection('messages').where(filter=FieldFilter('chat_session_id', '==', chat_session_id))
+ batch = db.batch()
+ count = 0
+ for doc in messages_ref.stream():
+ batch.delete(doc.reference)
+ count += 1
+ if count % 400 == 0:
+ batch.commit()
+ batch = db.batch()
+ if count % 400 != 0:
+ batch.commit()
+ logger.info(f"Deleted {count} messages for session {chat_session_id}")
+
+
+def add_message_to_chat_session(uid: str, chat_session_id: str, message_id: str, preview: str = None):
user_ref = db.collection('users').document(uid)
session_ref = user_ref.collection('chat_sessions').document(chat_session_id)
- session_ref.update({"message_ids": firestore.ArrayUnion([message_id])})
+ update_data = {
+ "message_ids": firestore.ArrayUnion([message_id]),
+ "updated_at": datetime.now(timezone.utc),
+ "message_count": firestore.Increment(1),
+ }
+ if preview:
+ update_data["preview"] = preview[:200]
+ session_ref.update(update_data)
def add_files_to_chat_session(uid: str, chat_session_id: str, file_ids: List[str]):
diff --git a/backend/database/conversations.py b/backend/database/conversations.py
index 1d9896ff44..8d4cce69b9 100644
--- a/backend/database/conversations.py
+++ b/backend/database/conversations.py
@@ -220,6 +220,30 @@ def get_conversations(
return conversations
+def count_conversations(uid: str, statuses: Optional[List[str]] = None) -> int:
+ """Count conversations matching status filters without fetching full documents."""
+ if statuses is None:
+ statuses = []
+ conversations_ref = db.collection('users').document(uid).collection(conversations_collection)
+ conversations_ref = conversations_ref.where(filter=FieldFilter('discarded', '==', False))
+ if statuses:
+ conversations_ref = conversations_ref.where(filter=FieldFilter('status', 'in', statuses))
+ count_query = conversations_ref.count()
+ results = count_query.get()
+ return results[0][0].value
+
+
+def stream_conversations(uid: str, statuses: Optional[List[str]] = None):
+ """Yield conversation docs as a stream for counting without loading all into memory."""
+ if statuses is None:
+ statuses = []
+ conversations_ref = db.collection('users').document(uid).collection(conversations_collection)
+ conversations_ref = conversations_ref.where(filter=FieldFilter('discarded', '==', False))
+ if statuses:
+ conversations_ref = conversations_ref.where(filter=FieldFilter('status', 'in', statuses))
+ yield from conversations_ref.stream()
+
+
@prepare_for_read(decrypt_func=_prepare_conversation_for_read)
def get_conversations_without_photos(
uid: str,
diff --git a/backend/database/focus_sessions.py b/backend/database/focus_sessions.py
new file mode 100644
index 0000000000..c21af103ab
--- /dev/null
+++ b/backend/database/focus_sessions.py
@@ -0,0 +1,75 @@
+import logging
+import uuid
+from datetime import datetime, timedelta, timezone
+from typing import List, Dict, Any, Optional
+
+from google.cloud import firestore
+
+from ._client import db
+
+logger = logging.getLogger(__name__)
+
+USERS_COLLECTION = 'users'
+FOCUS_SESSIONS_SUBCOLLECTION = 'focus_sessions'
+
+
+def _collection_ref(uid: str):
+ return db.collection(USERS_COLLECTION).document(uid).collection(FOCUS_SESSIONS_SUBCOLLECTION)
+
+
+def create_focus_session(uid: str, data: Dict[str, Any]) -> Dict[str, Any]:
+ """Create a new focus session document. Returns the created document with id."""
+ session_id = str(uuid.uuid4())
+ now = datetime.now(timezone.utc)
+
+ doc_data = {
+ 'status': data['status'],
+ 'app_or_site': data['app_or_site'],
+ 'description': data['description'],
+ 'created_at': now,
+ }
+ if data.get('message') is not None:
+ doc_data['message'] = data['message']
+ if data.get('duration_seconds') is not None:
+ doc_data['duration_seconds'] = data['duration_seconds']
+
+ _collection_ref(uid).document(session_id).set(doc_data)
+
+ doc_data['id'] = session_id
+ return doc_data
+
+
+def get_focus_sessions(
+ uid: str,
+ limit: int = 100,
+ offset: int = 0,
+ date: Optional[str] = None,
+) -> List[Dict[str, Any]]:
+ """Query focus sessions, ordered by created_at DESC. Optional date filter (YYYY-MM-DD)."""
+ query = _collection_ref(uid).order_by('created_at', direction=firestore.Query.DESCENDING)
+
+ if date:
+ day_start = datetime.strptime(date, '%Y-%m-%d').replace(tzinfo=timezone.utc)
+ next_day_start = day_start + timedelta(days=1)
+ query = query.where(filter=firestore.FieldFilter('created_at', '>=', day_start))
+ query = query.where(filter=firestore.FieldFilter('created_at', '<', next_day_start))
+
+ query = query.offset(offset).limit(limit)
+
+ results = []
+ for doc in query.stream():
+ data = doc.to_dict()
+ data['id'] = doc.id
+ results.append(data)
+ return results
+
+
+def delete_focus_session(uid: str, session_id: str) -> bool:
+ """Delete a focus session document. Returns True on success."""
+ _collection_ref(uid).document(session_id).delete()
+ return True
+
+
+def get_focus_sessions_for_stats(uid: str, date: str) -> List[Dict[str, Any]]:
+ """Get up to 1000 sessions for a date, for stats computation."""
+ return get_focus_sessions(uid, limit=1000, offset=0, date=date)
diff --git a/backend/database/staged_tasks.py b/backend/database/staged_tasks.py
new file mode 100644
index 0000000000..28f8c35c18
--- /dev/null
+++ b/backend/database/staged_tasks.py
@@ -0,0 +1,260 @@
+"""Database operations for desktop staged tasks (users/{uid}/staged_tasks)."""
+
+from datetime import datetime, timezone
+from typing import Optional, List, Tuple
+
+from google.cloud import firestore
+
+from ._client import db
+import logging
+
+logger = logging.getLogger(__name__)
+
+COLLECTION = 'staged_tasks'
+
+
+def _prepare_for_read(data: dict) -> dict:
+ """Convert Firestore timestamps to Python datetimes."""
+ for field in ['created_at', 'updated_at', 'due_at', 'completed_at', 'deleted_at']:
+ if field in data and data[field] and hasattr(data[field], 'timestamp'):
+ data[field] = datetime.fromtimestamp(data[field].timestamp(), tz=timezone.utc)
+ return data
+
+
+# --- CREATE ---
+
+
+def create_staged_task(uid: str, data: dict) -> dict:
+ """Create a staged task with dedup. Returns existing item if description matches (case-insensitive)."""
+ description = data.get('description', '').strip()
+ if not description:
+ raise ValueError('description must not be empty')
+
+ ref = db.collection('users').document(uid).collection(COLLECTION)
+
+ # Dedup: check for existing task with same description (case-insensitive)
+ normalized = description.lower()
+ for doc in ref.stream():
+ existing = doc.to_dict()
+ if existing.get('deleted'):
+ continue
+ if existing.get('description', '').strip().lower() == normalized:
+ existing['id'] = doc.id
+ return _prepare_for_read(existing)
+
+ now = datetime.now(timezone.utc)
+ data['description'] = description
+ data.setdefault('created_at', now)
+ data.setdefault('updated_at', now)
+ data.setdefault('completed', False)
+
+ _, doc_ref = ref.add(data)
+ result = data.copy()
+ result['id'] = doc_ref.id
+ return result
+
+
+# --- READ ---
+
+
+def get_staged_tasks(uid: str, limit: int = 100, offset: int = 0) -> Tuple[List[dict], bool]:
+ """List staged tasks ordered by relevance_score ASC, filtering out completed/deleted.
+
+ Matches Rust behavior: completed=false filter, skip deleted, tie-break by created_at DESC.
+ Returns (items, has_more).
+ """
+ ref = db.collection('users').document(uid).collection(COLLECTION)
+ query = (
+ ref.where(filter=firestore.FieldFilter('completed', '==', False))
+ .order_by('relevance_score', direction=firestore.Query.ASCENDING)
+ .order_by('created_at', direction=firestore.Query.DESCENDING)
+ )
+
+ # Fetch more than needed to account for deleted items being filtered client-side
+ fetch_limit = (limit + 1) * 2
+ if offset > 0:
+ query = query.offset(offset)
+ query = query.limit(fetch_limit)
+
+ docs = list(query.stream())
+ items = []
+ for doc in docs:
+ data = doc.to_dict()
+ # Skip soft-deleted
+ if data.get('deleted'):
+ continue
+ data['id'] = doc.id
+ items.append(_prepare_for_read(data))
+
+ has_more = len(items) > limit
+ if has_more:
+ items = items[:limit]
+ return items, has_more
+
+
+def get_staged_task(uid: str, task_id: str) -> Optional[dict]:
+ """Get a single staged task by ID."""
+ doc = db.collection('users').document(uid).collection(COLLECTION).document(task_id).get()
+ if not doc.exists:
+ return None
+ data = doc.to_dict()
+ data['id'] = doc.id
+ return _prepare_for_read(data)
+
+
+# --- UPDATE ---
+
+
+def batch_update_scores(uid: str, scores: List[dict]) -> None:
+ """Batch update relevance_score for multiple staged tasks.
+
+ Args:
+ scores: List of {"id": str, "relevance_score": int}
+ """
+ if not scores:
+ return
+ batch = db.batch()
+ ref = db.collection('users').document(uid).collection(COLLECTION)
+ now = datetime.now(timezone.utc)
+ for item in scores:
+ doc_ref = ref.document(item['id'])
+ batch.update(doc_ref, {'relevance_score': item['relevance_score'], 'updated_at': now})
+ batch.commit()
+
+
+# --- DELETE ---
+
+
+def delete_staged_task(uid: str, task_id: str) -> None:
+ """Hard-delete a staged task. Idempotent — no error if not found (matches Rust behavior)."""
+ doc_ref = db.collection('users').document(uid).collection(COLLECTION).document(task_id)
+ doc_ref.delete()
+
+
+def delete_staged_tasks_batch(uid: str, task_ids: List[str]) -> int:
+ """Hard-delete multiple staged tasks. Returns count deleted."""
+ if not task_ids:
+ return 0
+ batch = db.batch()
+ ref = db.collection('users').document(uid).collection(COLLECTION)
+ for task_id in task_ids:
+ batch.delete(ref.document(task_id))
+ batch.commit()
+ return len(task_ids)
+
+
+# --- PROMOTE ---
+
+
+def get_active_ai_action_items(uid: str) -> List[dict]:
+ """Get active action items that were promoted from staged (from_staged=true, not completed, not deleted)."""
+ ref = db.collection('users').document(uid).collection('action_items')
+ query = ref.where(filter=firestore.FieldFilter('from_staged', '==', True)).where(
+ filter=firestore.FieldFilter('completed', '==', False)
+ )
+ items = []
+ for doc in query.stream():
+ data = doc.to_dict()
+ # Skip soft-deleted
+ if data.get('deleted'):
+ continue
+ data['id'] = doc.id
+ items.append(_prepare_for_read(data))
+ return items
+
+
+def promote_staged_task(uid: str, staged_task: dict) -> dict:
+ """Create an action item from a staged task (from_staged=true). Returns created action item."""
+ now = datetime.now(timezone.utc)
+ action_item_data = {
+ 'description': staged_task['description'],
+ 'completed': False,
+ 'created_at': now,
+ 'updated_at': now,
+ 'from_staged': True,
+ 'source': staged_task.get('source'),
+ 'priority': staged_task.get('priority'),
+ 'metadata': staged_task.get('metadata'),
+ 'category': staged_task.get('category'),
+ 'relevance_score': staged_task.get('relevance_score'),
+ }
+ if staged_task.get('due_at'):
+ action_item_data['due_at'] = staged_task['due_at']
+
+ ref = db.collection('users').document(uid).collection('action_items')
+ _, doc_ref = ref.add(action_item_data)
+ action_item_data['id'] = doc_ref.id
+ return action_item_data
+
+
+# --- SCORES (daily/weekly/overall) ---
+
+
+def get_action_items_for_daily_score(uid: str, due_start: str, due_end: str) -> Tuple[int, int]:
+ """Count completed vs total action items due on a specific day.
+
+ Returns (completed_count, total_count).
+ """
+ ref = db.collection('users').document(uid).collection('action_items')
+ start_dt = datetime.fromisoformat(due_start.replace('Z', '+00:00'))
+ end_dt = datetime.fromisoformat(due_end.replace('Z', '+00:00'))
+
+ query = ref.where(filter=firestore.FieldFilter('due_at', '>=', start_dt)).where(
+ filter=firestore.FieldFilter('due_at', '<=', end_dt)
+ )
+
+ completed = 0
+ total = 0
+ for doc in query.stream():
+ data = doc.to_dict()
+ if data.get('deleted'):
+ continue
+ total += 1
+ if data.get('completed'):
+ completed += 1
+ return completed, total
+
+
+def get_action_items_for_weekly_score(uid: str, week_start: str, week_end: str) -> Tuple[int, int]:
+ """Count completed vs total action items created in a 7-day window.
+
+ Uses created_at range (not due_at) to match Rust weekly score behavior.
+ Returns (completed_count, total_count).
+ """
+ ref = db.collection('users').document(uid).collection('action_items')
+ start_dt = datetime.fromisoformat(week_start.replace('Z', '+00:00'))
+ end_dt = datetime.fromisoformat(week_end.replace('Z', '+00:00'))
+
+ query = ref.where(filter=firestore.FieldFilter('created_at', '>=', start_dt)).where(
+ filter=firestore.FieldFilter('created_at', '<=', end_dt)
+ )
+
+ completed = 0
+ total = 0
+ for doc in query.stream():
+ data = doc.to_dict()
+ if data.get('deleted'):
+ continue
+ total += 1
+ if data.get('completed'):
+ completed += 1
+ return completed, total
+
+
+def get_action_items_for_overall_score(uid: str) -> Tuple[int, int]:
+ """Count completed vs total action items (all time, not deleted).
+
+ Returns (completed_count, total_count).
+ """
+ ref = db.collection('users').document(uid).collection('action_items')
+
+ completed = 0
+ total = 0
+ for doc in ref.stream():
+ data = doc.to_dict()
+ if data.get('deleted'):
+ continue
+ total += 1
+ if data.get('completed'):
+ completed += 1
+ return completed, total
diff --git a/backend/database/users.py b/backend/database/users.py
index 0944b9eb0a..5c53c54f45 100644
--- a/backend/database/users.py
+++ b/backend/database/users.py
@@ -1050,3 +1050,79 @@ def set_user_transcription_preferences(uid: str, single_language_mode: bool = No
if update_data:
user_ref.update(update_data)
+
+
+# **************************************
+# ****** Assistant Settings ************
+# **************************************
+
+
+def get_assistant_settings(uid: str) -> dict:
+ """Get the user's assistant_settings map from their user document."""
+ user_ref = db.collection('users').document(uid)
+ user_doc = user_ref.get()
+ if user_doc.exists:
+ user_data = user_doc.to_dict()
+ settings = user_data.get('assistant_settings', {})
+ # update_channel is a top-level field, not inside assistant_settings
+ update_channel = user_data.get('update_channel')
+ if update_channel is not None:
+ settings['update_channel'] = update_channel
+ return settings
+ return {}
+
+
+def update_assistant_settings(uid: str, data: dict) -> dict:
+ """Merge-update the user's assistant_settings map. Returns merged state.
+
+ Uses per-section set(merge=True) to avoid overwriting sibling sections.
+ Each non-empty section is written individually so that e.g. patching
+ only 'focus' does not wipe 'shared' or 'task'.
+ """
+ user_ref = db.collection('users').document(uid)
+
+ # Separate update_channel (top-level) from assistant_settings sub-map
+ update_channel = data.pop('update_channel', None)
+
+ # Write each section individually with merge to preserve siblings
+ for section_key, section_val in data.items():
+ if isinstance(section_val, dict) and section_val:
+ user_ref.set({'assistant_settings': {section_key: section_val}}, merge=True)
+
+ if update_channel is not None:
+ user_ref.set({'update_channel': update_channel}, merge=True)
+
+ return get_assistant_settings(uid)
+
+
+# **************************************
+# ******** AI User Profile *************
+# **************************************
+
+
+def get_ai_user_profile(uid: str) -> Optional[dict]:
+ """Get the user's ai_user_profile map from their user document."""
+ user_ref = db.collection('users').document(uid)
+ user_doc = user_ref.get()
+ if user_doc.exists:
+ user_data = user_doc.to_dict()
+ return user_data.get('ai_user_profile')
+ return None
+
+
+def update_ai_user_profile(uid: str, data: dict) -> dict:
+ """Full-replace the user's ai_user_profile map. Returns new state.
+
+ Uses update() for true field replacement (removes stale nested keys).
+ Falls back to set(merge=True) if document doesn't exist yet.
+ """
+ user_ref = db.collection('users').document(uid)
+ try:
+ user_ref.update({'ai_user_profile': data})
+ except Exception as e:
+ # Only fall back on not-found (code 404); re-raise other errors
+ if hasattr(e, 'code') and e.code == 404:
+ user_ref.set({'ai_user_profile': data}, merge=True)
+ else:
+ raise
+ return get_ai_user_profile(uid)
diff --git a/backend/main.py b/backend/main.py
index 827cd4d34b..04be0b585b 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -45,6 +45,10 @@
announcements,
phone_calls,
agent_tools,
+ screen_activity,
+ focus_sessions,
+ advice,
+ staged_tasks,
)
from utils.other.timeout import TimeoutMiddleware
@@ -104,6 +108,10 @@
app.include_router(announcements.router)
app.include_router(phone_calls.router)
app.include_router(agent_tools.router)
+app.include_router(screen_activity.router)
+app.include_router(focus_sessions.router)
+app.include_router(advice.router)
+app.include_router(staged_tasks.router)
methods_timeout = {
diff --git a/backend/models/message_event.py b/backend/models/message_event.py
index bddbeb2a27..b422c767b1 100644
--- a/backend/models/message_event.py
+++ b/backend/models/message_event.py
@@ -181,3 +181,102 @@ def to_json(self):
j["type"] = self.event_type
del j["event_type"]
return j
+
+
+# Desktop proactive AI events (Phase 2 — #5396)
+
+
+class FocusResultEvent(MessageEvent):
+ event_type: str = "focus_result"
+ frame_id: str
+ status: str
+ app_or_site: str
+ description: str
+ message: Optional[str] = None
+
+ def to_json(self):
+ j = self.model_dump(mode="json")
+ j["type"] = self.event_type
+ del j["event_type"]
+ return j
+
+
+class TasksExtractedEvent(MessageEvent):
+ event_type: str = "tasks_extracted"
+ frame_id: str
+ tasks: List = []
+
+ def to_json(self):
+ j = self.model_dump(mode="json")
+ j["type"] = self.event_type
+ del j["event_type"]
+ return j
+
+
+class MemoriesExtractedEvent(MessageEvent):
+ event_type: str = "memories_extracted"
+ frame_id: str
+ memories: List = []
+
+ def to_json(self):
+ j = self.model_dump(mode="json")
+ j["type"] = self.event_type
+ del j["event_type"]
+ return j
+
+
+class AdviceExtractedEvent(MessageEvent):
+ event_type: str = "advice_extracted"
+ frame_id: str
+ advice: Optional[Any] = None
+
+ def to_json(self):
+ j = self.model_dump(mode="json")
+ j["type"] = self.event_type
+ del j["event_type"]
+ return j
+
+
+class LiveNoteEvent(MessageEvent):
+ event_type: str = "live_note"
+ text: str
+
+ def to_json(self):
+ j = self.model_dump(mode="json")
+ j["type"] = self.event_type
+ del j["event_type"]
+ return j
+
+
+class ProfileUpdatedEvent(MessageEvent):
+ event_type: str = "profile_updated"
+ profile_text: str
+
+ def to_json(self):
+ j = self.model_dump(mode="json")
+ j["type"] = self.event_type
+ del j["event_type"]
+ return j
+
+
+class RerankCompleteEvent(MessageEvent):
+ event_type: str = "rerank_complete"
+ updated_tasks: List = []
+
+ def to_json(self):
+ j = self.model_dump(mode="json")
+ j["type"] = self.event_type
+ del j["event_type"]
+ return j
+
+
+class DedupCompleteEvent(MessageEvent):
+ event_type: str = "dedup_complete"
+ deleted_ids: List = []
+ reason: str = ""
+
+ def to_json(self):
+ j = self.model_dump(mode="json")
+ j["type"] = self.event_type
+ del j["event_type"]
+ return j
diff --git a/backend/routers/advice.py b/backend/routers/advice.py
new file mode 100644
index 0000000000..72d671c84d
--- /dev/null
+++ b/backend/routers/advice.py
@@ -0,0 +1,138 @@
+import logging
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from pydantic import BaseModel, Field
+
+import database.advice as advice_db
+from utils.other import endpoints as auth
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+VALID_CATEGORIES = ('productivity', 'health', 'communication', 'learning', 'other')
+
+
+class CreateAdviceRequest(BaseModel):
+ content: str = Field(description="Advice content text")
+ category: Optional[str] = Field(default=None, description="Category: productivity, health, communication, learning, other")
+ reasoning: Optional[str] = Field(default=None, description="Reasoning behind the advice")
+ source_app: Optional[str] = Field(default=None, description="App where context was observed")
+ confidence: Optional[float] = Field(default=None, description="Confidence score 0.0-1.0")
+ context_summary: Optional[str] = Field(default=None, description="Context summary")
+ current_activity: Optional[str] = Field(default=None, description="User's current activity")
+
+
+class UpdateAdviceRequest(BaseModel):
+ is_read: Optional[bool] = None
+ is_dismissed: Optional[bool] = None
+
+
+class AdviceResponse(BaseModel):
+ id: str
+ content: str
+ category: str = 'other'
+ reasoning: Optional[str] = None
+ source_app: Optional[str] = None
+ confidence: float = 0.5
+ context_summary: Optional[str] = None
+ current_activity: Optional[str] = None
+ created_at: object = None
+ updated_at: object = None
+ is_read: bool = False
+ is_dismissed: bool = False
+
+
+class AdviceStatusResponse(BaseModel):
+ status: str
+
+
+def _validate_category(category: Optional[str]):
+ if category and category not in VALID_CATEGORIES:
+ raise HTTPException(
+ status_code=400,
+ detail=f"category must be one of: {', '.join(VALID_CATEGORIES)}"
+ )
+
+
+def _validate_confidence(confidence: Optional[float]):
+ if confidence is not None and not (0.0 <= confidence <= 1.0):
+ raise HTTPException(status_code=400, detail="confidence must be between 0.0 and 1.0")
+
+
+@router.post('/v1/advice', tags=['advice'])
+def create_advice(
+ request: CreateAdviceRequest,
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ _validate_category(request.category)
+ _validate_confidence(request.confidence)
+ try:
+ return advice_db.create_advice(uid, request.model_dump(exclude_none=True))
+ except Exception:
+ logger.exception('Failed to create advice for uid=%s', uid)
+ raise HTTPException(status_code=500, detail="Failed to create advice")
+
+
+@router.get('/v1/advice', tags=['advice'])
+def get_advice(
+ limit: int = Query(default=100, ge=1, le=1000),
+ offset: int = Query(default=0, ge=0),
+ category: Optional[str] = Query(default=None),
+ include_dismissed: bool = Query(default=False),
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ if category and category not in VALID_CATEGORIES:
+ category = None # Skip unknown category filter (match Rust behavior)
+ try:
+ return advice_db.get_advice(
+ uid, limit=limit, offset=offset, category=category, include_dismissed=include_dismissed,
+ )
+ except Exception:
+ logger.exception('Failed to get advice for uid=%s', uid)
+ return []
+
+
+@router.patch('/v1/advice/{advice_id}', tags=['advice'])
+def update_advice(
+ advice_id: str,
+ request: UpdateAdviceRequest,
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ update_data = request.model_dump(exclude_none=True)
+ try:
+ result = advice_db.update_advice(uid, advice_id, update_data)
+ if result is None:
+ raise HTTPException(status_code=500, detail="Failed to update advice")
+ return result
+ except HTTPException:
+ raise
+ except Exception:
+ logger.exception('Failed to update advice %s for uid=%s', advice_id, uid)
+ raise HTTPException(status_code=500, detail="Failed to update advice")
+
+
+@router.delete('/v1/advice/{advice_id}', response_model=AdviceStatusResponse, tags=['advice'])
+def delete_advice(
+ advice_id: str,
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ try:
+ advice_db.delete_advice(uid, advice_id)
+ return AdviceStatusResponse(status="ok")
+ except Exception:
+ logger.exception('Failed to delete advice %s for uid=%s', advice_id, uid)
+ raise HTTPException(status_code=500, detail="Failed to delete advice")
+
+
+@router.post('/v1/advice/mark-all-read', response_model=AdviceStatusResponse, tags=['advice'])
+def mark_all_advice_read(
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ try:
+ count = advice_db.mark_all_advice_read(uid)
+ return AdviceStatusResponse(status=f"marked {count} as read")
+ except Exception:
+ logger.exception('Failed to mark all advice read for uid=%s', uid)
+ raise HTTPException(status_code=500, detail="Failed to mark advice as read")
diff --git a/backend/routers/agent_tools.py b/backend/routers/agent_tools.py
index 0fcc3b7f55..805e98c4f2 100644
--- a/backend/routers/agent_tools.py
+++ b/backend/routers/agent_tools.py
@@ -4,13 +4,16 @@
Endpoints:
- GET /v1/agent/tools — returns tool definitions (name, description, parameters)
- POST /v1/agent/execute-tool — executes a named tool and returns the result
-- GET /v1/agent/vm-status — returns basic VM status from Firestore
-- POST /v1/agent/vm-ensure — checks VM status, restarts if stopped, returns current state
+- GET /v1/agent/vm-status — returns VM status from Firestore (with restart if stopped)
+- POST /v1/agent/vm-ensure — ensures user has a VM: creates if missing, restarts if stopped
- POST /v1/agent/keepalive — pings the VM to reset its idle auto-stop timer
"""
import asyncio
import logging
+import os
+import time
+import uuid
from datetime import datetime, timezone
import google.auth
@@ -19,6 +22,7 @@
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel
+from database._client import db as firestore_db
from database.users import get_agent_vm
from utils.other.endpoints import get_current_user_uid
from utils.retrieval.agentic import agent_config_context, CORE_TOOLS
@@ -29,7 +33,10 @@
router = APIRouter()
-GCE_PROJECT = "based-hardware"
+GCE_PROJECT = os.environ.get("GCE_PROJECT_ID", os.environ.get("GOOGLE_CLOUD_PROJECT", "based-hardware"))
+GCE_ZONE = "us-central1-a"
+GCE_SOURCE_IMAGE = os.environ.get("GCE_SOURCE_IMAGE", f"projects/{GCE_PROJECT}/global/images/family/omi-agent")
+AGENT_GCS_BUCKET = os.environ.get("AGENT_GCS_BUCKET", "based-hardware-agent")
# --------------- GCE helpers ---------------
@@ -56,8 +63,6 @@ async def _check_gce_status(vm_name: str, zone: str) -> str:
async def _start_vm_and_wait(vm_name: str, zone: str) -> str:
"""Start a stopped/terminated GCE VM and wait for it to get an IP. Returns the new IP."""
- import time
-
t0 = time.monotonic()
token = _get_gce_access_token()
start_url = (
@@ -120,14 +125,124 @@ async def _start_vm_and_wait(vm_name: str, zone: str) -> str:
def _update_firestore_vm(uid: str, ip: str | None, status: str):
"""Update the user's agentVm fields in Firestore."""
- from database.users import db as firestore_db
-
update = {"agentVm.status": status}
if ip:
update["agentVm.ip"] = ip
firestore_db.collection('users').document(uid).update(update)
+def _set_firestore_vm(uid: str, vm_name: str, zone: str, ip: str | None, status: str, auth_token: str):
+ """Write the full agentVm document to Firestore (for initial provisioning)."""
+ now = datetime.now(timezone.utc).isoformat()
+ vm_data = {
+ "vmName": vm_name,
+ "zone": zone,
+ "status": status,
+ "authToken": auth_token,
+ "createdAt": now,
+ }
+ if ip:
+ vm_data["ip"] = ip
+ firestore_db.collection('users').document(uid).set({"agentVm": vm_data}, merge=True)
+
+
+async def _create_gce_vm(vm_name: str, auth_token: str) -> str:
+ """Create a GCE VM from the omi-agent image family. Returns the external IP."""
+ zone = GCE_ZONE
+ startup_script = (
+ f"#!/bin/bash\ncurl -sf https://storage.googleapis.com/{AGENT_GCS_BUCKET}/startup.sh"
+ f" -o /tmp/omi-startup.sh && bash /tmp/omi-startup.sh\n"
+ )
+
+ url = f"https://compute.googleapis.com/compute/v1/projects/{GCE_PROJECT}/zones/{zone}/instances"
+ body = {
+ "name": vm_name,
+ "machineType": f"zones/{zone}/machineTypes/e2-small",
+ "disks": [
+ {
+ "boot": True,
+ "autoDelete": True,
+ "initializeParams": {
+ "sourceImage": GCE_SOURCE_IMAGE,
+ "diskSizeGb": "50",
+ "diskType": f"zones/{zone}/diskTypes/pd-ssd",
+ },
+ }
+ ],
+ "networkInterfaces": [
+ {
+ "network": "global/networks/default",
+ "accessConfigs": [{"type": "ONE_TO_ONE_NAT", "name": "External NAT"}],
+ }
+ ],
+ "tags": {"items": ["omi-agent-vm"]},
+ "metadata": {
+ "items": [
+ {"key": "startup-script", "value": startup_script},
+ {"key": "auth-token", "value": auth_token},
+ ]
+ },
+ }
+
+ token = _get_gce_access_token()
+ async with httpx.AsyncClient(timeout=180) as client:
+ resp = await client.post(url, headers={"Authorization": f"Bearer {token}"}, json=body)
+ if resp.status_code not in (200, 204):
+ raise Exception(f"GCE insert failed: {resp.status_code} {sanitize(resp.text)}")
+
+ op_name = resp.json().get("name")
+ if not op_name:
+ raise Exception("Missing operation name in GCE insert response")
+
+ # Poll operation until done (max ~2 minutes)
+ op_url = f"https://compute.googleapis.com/compute/v1/projects/{GCE_PROJECT}/zones/{zone}/operations/{op_name}"
+ op_done = False
+ for i in range(24):
+ await asyncio.sleep(5)
+ token = _get_gce_access_token()
+ status_resp = await client.get(op_url, headers={"Authorization": f"Bearer {token}"})
+ op_status = status_resp.json()
+ if op_status.get("status") == "DONE":
+ if "error" in op_status:
+ raise Exception(f"GCE insert operation failed: {op_status['error']}")
+ logger.info(f"[vm-create] {vm_name} operation done after {i + 1} polls")
+ op_done = True
+ break
+ if not op_done:
+ raise Exception(f"GCE insert timed out after 120s for {vm_name}")
+
+ # Get external IP
+ instance_url = (
+ f"https://compute.googleapis.com/compute/v1/projects/{GCE_PROJECT}/zones/{zone}/instances/{vm_name}"
+ )
+ for attempt in range(6):
+ token = _get_gce_access_token()
+ inst_resp = await client.get(instance_url, headers={"Authorization": f"Bearer {token}"})
+ instance = inst_resp.json()
+ try:
+ candidate = instance["networkInterfaces"][0]["accessConfigs"][0]["natIP"]
+ if candidate:
+ logger.info(f"[vm-create] {vm_name} got IP {candidate} on attempt {attempt + 1}")
+ return candidate
+ except (KeyError, IndexError):
+ pass
+ if attempt < 5:
+ await asyncio.sleep(3)
+
+ raise Exception(f"Failed to get external IP for {vm_name} after 6 attempts")
+
+
+async def _provision_vm_background(uid: str, vm_name: str, auth_token: str):
+ """Background task: create a new GCE VM, update Firestore when ready."""
+ try:
+ ip = await _create_gce_vm(vm_name, auth_token)
+ _set_firestore_vm(uid, vm_name, GCE_ZONE, ip, "ready", auth_token)
+ logger.info(f"[vm-ensure] VM {vm_name} created, ip={ip}")
+ except Exception as e:
+ logger.error(f"[vm-ensure] Failed to create VM {vm_name}: {e}")
+ _update_firestore_vm(uid, None, "error")
+
+
async def _restart_vm_background(uid: str, vm_name: str, zone: str):
"""Background task: start stopped VM, update Firestore with new IP when ready."""
try:
@@ -142,53 +257,106 @@ async def _restart_vm_background(uid: str, vm_name: str, zone: str):
# --------------- endpoints ---------------
-@router.get("/v1/agent/vm-status")
-def get_vm_status(uid: str = Depends(get_current_user_uid)):
- """Return the user's agent VM info from Firestore."""
- vm = get_agent_vm(uid)
- logger.info(f"[vm-status] uid={uid} vm={sanitize(vm)}")
- if not vm or vm.get("status") != "ready":
- return {"has_vm": False}
+def _vm_response(vm: dict, status_override: str | None = None) -> dict:
+ """Build a standard VM response dict with all fields the desktop expects."""
return {
"has_vm": True,
- "status": vm.get("status"),
+ "status": status_override or vm.get("status"),
+ "vm_name": vm.get("vmName"),
+ "ip": vm.get("ip"),
+ "auth_token": vm.get("authToken"),
+ "zone": vm.get("zone", GCE_ZONE),
+ "created_at": vm.get("createdAt"),
+ "last_query_at": vm.get("lastQueryAt"),
}
+@router.get("/v1/agent/vm-status")
+async def get_vm_status(background_tasks: BackgroundTasks, uid: str = Depends(get_current_user_uid)):
+ """Return the user's agent VM info from Firestore. Restarts stopped VMs."""
+ vm = get_agent_vm(uid)
+ if not vm:
+ return {"has_vm": False}
+
+ fs_status = vm.get("status", "")
+ vm_name = vm.get("vmName")
+ zone = vm.get("zone", GCE_ZONE)
+
+ # For ready/error/stopped VMs, verify actual GCE status and restart if needed
+ if fs_status in ("ready", "error", "stopped") and vm_name:
+ try:
+ gce_status = await _check_gce_status(vm_name, zone)
+ except Exception as e:
+ logger.warning(f"[vm-status] GCE status check failed for {vm_name}: {e}")
+ return _vm_response(vm)
+
+ if gce_status in ("TERMINATED", "STOPPED"):
+ logger.info(f"[vm-status] VM {vm_name} is {gce_status}, restarting...")
+ _update_firestore_vm(uid, None, "provisioning")
+ background_tasks.add_task(_restart_vm_background, uid, vm_name, zone)
+ return _vm_response(vm, status_override="provisioning")
+
+ if gce_status == "RUNNING" and fs_status != "ready":
+ _update_firestore_vm(uid, vm.get("ip"), "ready")
+ return _vm_response(vm, status_override="ready")
+
+ return _vm_response(vm)
+
+
@router.post("/v1/agent/vm-ensure")
async def ensure_vm(background_tasks: BackgroundTasks, uid: str = Depends(get_current_user_uid)):
- """Check VM status; if stopped/terminated, restart it in the background."""
+ """Ensure user has a VM: create if missing, restart if stopped."""
vm = get_agent_vm(uid)
+
+ # No VM exists — provision a new one
if not vm:
- return {"has_vm": False}
+ uid_prefix = uid[:12].lower() if len(uid) > 12 else uid.lower()
+ vm_name = f"omi-agent-{uid_prefix}"
+ auth_token = f"omi-{uuid.uuid4()}"
+
+ # Claim the slot in Firestore before spawning background creation
+ _set_firestore_vm(uid, vm_name, GCE_ZONE, None, "provisioning", auth_token)
+ background_tasks.add_task(_provision_vm_background, uid, vm_name, auth_token)
+ logger.info(f"[vm-ensure] Provisioning new VM {vm_name} for uid={uid[:8]}...")
+
+ return {
+ "has_vm": True,
+ "status": "provisioning",
+ "vm_name": vm_name,
+ "ip": None,
+ "auth_token": auth_token,
+ "zone": GCE_ZONE,
+ "created_at": datetime.now(timezone.utc).isoformat(),
+ "last_query_at": None,
+ }
vm_name = vm.get("vmName")
- zone = vm.get("zone", "us-central1-a")
+ zone = vm.get("zone", GCE_ZONE)
fs_status = vm.get("status", "")
# If Firestore already says provisioning, don't double-start
if fs_status == "provisioning":
- return {"has_vm": True, "status": "provisioning"}
+ return _vm_response(vm)
# Check actual GCE status for ready/error/stopped VMs
- if fs_status in ("ready", "error", "stopped"):
+ if fs_status in ("ready", "error", "stopped") and vm_name:
try:
gce_status = await _check_gce_status(vm_name, zone)
except Exception as e:
logger.error(f"[vm-ensure] GCE status check failed: {e}")
- return {"has_vm": True, "status": fs_status}
+ return _vm_response(vm)
if gce_status in ("TERMINATED", "STOPPED"):
logger.info(f"[vm-ensure] VM {vm_name} is {gce_status}, restarting...")
_update_firestore_vm(uid, None, "provisioning")
background_tasks.add_task(_restart_vm_background, uid, vm_name, zone)
- return {"has_vm": True, "status": "provisioning"}
+ return _vm_response(vm, status_override="provisioning")
if gce_status == "RUNNING" and fs_status != "ready":
_update_firestore_vm(uid, vm.get("ip"), "ready")
- return {"has_vm": True, "status": "ready"}
+ return _vm_response(vm, status_override="ready")
- return {"has_vm": True, "status": fs_status}
+ return _vm_response(vm)
@router.post("/v1/agent/keepalive")
diff --git a/backend/routers/auth.py b/backend/routers/auth.py
index 3e81cbd93b..aa626aa145 100644
--- a/backend/routers/auth.py
+++ b/backend/routers/auth.py
@@ -3,6 +3,7 @@
import json
import hashlib
import time
+import base64
import requests
import jwt
from typing import Optional
@@ -15,7 +16,7 @@
import pathlib
import firebase_admin.auth
from database.redis_db import set_auth_session, get_auth_session, set_auth_code, get_auth_code, delete_auth_code
-from utils.log_sanitizer import sanitize
+from utils.log_sanitizer import sanitize, sanitize_pii
import logging
logger = logging.getLogger(__name__)
@@ -44,6 +45,11 @@ async def auth_authorize(
if provider not in ['google', 'apple']:
raise HTTPException(status_code=400, detail="Unsupported provider")
+ # Validate redirect_uri against allowed app URL schemes
+ ALLOWED_REDIRECT_SCHEMES = ('omi://', 'omi-computer://', 'omi-computer-dev://')
+ if not redirect_uri or not any(redirect_uri.startswith(s) for s in ALLOWED_REDIRECT_SCHEMES):
+ raise HTTPException(status_code=400, detail="Invalid redirect_uri: must use an allowed app URL scheme")
+
# Store session for auth flow
session_id = str(uuid.uuid4())
session_data = {
@@ -96,6 +102,7 @@ async def auth_callback_google(
"request": request,
"code": auth_code,
"state": session_data['state'] or '',
+ "redirect_uri": session_data.get('redirect_uri') or 'omi://auth/callback',
},
)
@@ -134,6 +141,7 @@ async def auth_callback_apple_post(
"request": request,
"code": auth_code,
"state": session_data['state'] or '',
+ "redirect_uri": session_data.get('redirect_uri') or 'omi://auth/callback',
},
)
@@ -379,47 +387,63 @@ async def _generate_custom_token(provider: str, id_token: str, access_token: str
Works with any bundle ID - perfect for multiple developers
"""
try:
- # Get Firebase API Key from environment
- firebase_api_key = os.getenv('FIREBASE_API_KEY')
- if not firebase_api_key:
- raise Exception("FIREBASE_API_KEY not configured")
-
- # Sign in with OAuth credential using Firebase Auth REST API
- sign_in_url = f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithIdp?key={firebase_api_key}"
-
- # Prepare the postBody based on provider
- if provider == 'google':
- post_body = f'id_token={id_token}&providerId=google.com'
- if access_token:
- post_body += f'&access_token={access_token}'
- elif provider == 'apple':
- post_body = f'id_token={id_token}&providerId=apple.com'
- if access_token:
- post_body += f'&access_token={access_token}'
- else:
- raise Exception(f"Unsupported provider: {provider}")
-
- payload = {
- 'postBody': post_body,
- 'requestUri': 'http://localhost',
- 'returnIdpCredential': True,
- 'returnSecureToken': True,
- }
+ firebase_uid = None
- # Call Firebase Auth REST API to sign in
- response = requests.post(sign_in_url, json=payload)
-
- if response.status_code != 200:
- logger.error(f"Firebase sign-in failed: {sanitize(response.text)}")
- raise Exception(f"Firebase sign-in failed: status={response.status_code}")
+ # Try REST API first (works when FIREBASE_API_KEY has no app restrictions)
+ firebase_api_key = os.getenv('FIREBASE_API_KEY')
+ if firebase_api_key:
+ sign_in_url = f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithIdp?key={firebase_api_key}"
+
+ if provider == 'google':
+ post_body = f'id_token={id_token}&providerId=google.com'
+ if access_token:
+ post_body += f'&access_token={access_token}'
+ elif provider == 'apple':
+ post_body = f'id_token={id_token}&providerId=apple.com'
+ if access_token:
+ post_body += f'&access_token={access_token}'
+ else:
+ raise Exception(f"Unsupported provider: {provider}")
+
+ payload = {
+ 'postBody': post_body,
+ 'requestUri': 'http://localhost',
+ 'returnIdpCredential': True,
+ 'returnSecureToken': True,
+ }
+
+ response = requests.post(sign_in_url, json=payload)
+
+ if response.status_code == 200:
+ result = response.json()
+ firebase_uid = result.get('localId')
+ if firebase_uid:
+ logger.info(f"Firebase sign-in successful for {provider}, UID: {firebase_uid}")
+ else:
+ logger.warning(
+ f"Firebase REST API sign-in failed (status={response.status_code}), falling back to Admin SDK"
+ )
+
+ # Fallback: verify id_token via Admin SDK and look up/create user
+ if not firebase_uid:
+ verified_token = firebase_admin.auth.verify_id_token(id_token)
+ email = verified_token.get('email')
+ if not email:
+ raise Exception("No email in verified id_token")
- result = response.json()
- firebase_uid = result.get('localId')
+ # Look up existing Firebase user by email
+ try:
+ user = firebase_admin.auth.get_user_by_email(email)
+ firebase_uid = user.uid
+ logger.info(f"Found existing Firebase user for {sanitize_pii(email)}, UID: {firebase_uid}")
+ except firebase_admin.auth.UserNotFoundError:
+ # Create new Firebase user
+ user = firebase_admin.auth.create_user(email=email, email_verified=True)
+ firebase_uid = user.uid
+ logger.info(f"Created new Firebase user for {sanitize_pii(email)}, UID: {firebase_uid}")
if not firebase_uid:
- raise Exception("No Firebase UID returned from sign-in")
-
- logger.info(f"Firebase sign-in successful for {provider}, UID: {firebase_uid}")
+ raise Exception("No Firebase UID obtained")
# Create custom token for this UID
custom_token = firebase_admin.auth.create_custom_token(firebase_uid)
diff --git a/backend/routers/chat.py b/backend/routers/chat.py
index a4a748ef4d..15d34e920c 100644
--- a/backend/routers/chat.py
+++ b/backend/routers/chat.py
@@ -6,8 +6,9 @@
from typing import List, Optional
from pathlib import Path
-from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
+from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile, File, Form
from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
from multipart.multipart import shutil
import database.chat as chat_db
@@ -32,6 +33,7 @@
resolve_voice_message_language,
transcribe_voice_message_segment,
)
+from utils.llm.clients import llm_mini
from utils.llm.persona import initial_persona_chat_message
from utils.llm.chat import initial_chat_message
from utils.llm.goals import extract_and_update_goal_progress
@@ -498,6 +500,246 @@ def upload_file_chat(files: List[UploadFile] = File(...), uid: str = Depends(aut
return response
+# ---------------------------------------------------------------------------
+# Desktop: session management, message persistence, and rating
+# The desktop app manages sessions explicitly (vs mobile's implicit sessions)
+# and persists messages without triggering the LLM pipeline — AI responses
+# come from the local ACP Bridge, not the backend.
+# ---------------------------------------------------------------------------
+
+
+class CreateChatSessionRequest(BaseModel):
+ title: Optional[str] = None
+ app_id: Optional[str] = None
+
+
+class UpdateChatSessionRequest(BaseModel):
+ title: Optional[str] = None
+ starred: Optional[bool] = None
+
+
+class ChatSessionResponse(BaseModel):
+ id: str
+ title: str
+ preview: Optional[str] = None
+ created_at: datetime
+ updated_at: datetime
+ app_id: Optional[str] = None
+ message_count: int = 0
+ starred: bool = False
+
+
+class SaveMessageRequest(BaseModel):
+ text: str
+ sender: str
+ app_id: Optional[str] = None
+ session_id: Optional[str] = None
+ metadata: Optional[str] = None
+
+
+class SaveMessageResponse(BaseModel):
+ id: str
+ created_at: datetime
+
+
+class RateMessageRequest(BaseModel):
+ rating: Optional[int] = None
+
+
+class StatusResponse(BaseModel):
+ status: str
+
+
+@router.get('/v2/chat-sessions', response_model=List[ChatSessionResponse], tags=['chat'])
+def list_chat_sessions(
+ app_id: Optional[str] = Query(None),
+ limit: int = Query(50, ge=1, le=200),
+ offset: int = Query(0, ge=0),
+ starred: Optional[bool] = Query(None),
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ """Desktop: list chat sessions with optional filtering."""
+ sessions = chat_db.get_chat_sessions(uid, app_id=app_id, limit=limit, offset=offset, starred=starred)
+ return sessions
+
+
+@router.post('/v2/chat-sessions', response_model=ChatSessionResponse, tags=['chat'])
+def create_chat_session(
+ request: CreateChatSessionRequest,
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ """Desktop: explicitly create a named chat session."""
+ now = datetime.now(timezone.utc)
+ session_data = {
+ 'id': str(uuid.uuid4()),
+ 'title': request.title or 'New Chat',
+ 'preview': None,
+ 'created_at': now,
+ 'updated_at': now,
+ 'app_id': request.app_id,
+ 'plugin_id': request.app_id,
+ 'message_count': 0,
+ 'starred': False,
+ }
+ chat_db.add_chat_session(uid, session_data)
+ return session_data
+
+
+@router.get('/v2/chat-sessions/{session_id}', response_model=ChatSessionResponse, tags=['chat'])
+def get_chat_session_by_id(
+ session_id: str,
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ """Desktop: get a single chat session by ID."""
+ session = chat_db.get_chat_session_by_id(uid, session_id)
+ if not session:
+ raise HTTPException(status_code=404, detail="Chat session not found")
+ return session
+
+
+@router.patch('/v2/chat-sessions/{session_id}', response_model=ChatSessionResponse, tags=['chat'])
+def update_chat_session(
+ session_id: str,
+ request: UpdateChatSessionRequest,
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ """Desktop: update session title or starred status."""
+ session = chat_db.get_chat_session_by_id(uid, session_id)
+ if not session:
+ raise HTTPException(status_code=404, detail="Chat session not found")
+
+ update_data = {}
+ if request.title is not None:
+ update_data['title'] = request.title
+ if request.starred is not None:
+ update_data['starred'] = request.starred
+ if update_data:
+ update_data['updated_at'] = datetime.now(timezone.utc)
+ chat_db.update_chat_session(uid, session_id, update_data)
+ session.update(update_data)
+
+ return session
+
+
+@router.delete('/v2/chat-sessions/{session_id}', response_model=StatusResponse, tags=['chat'])
+def delete_chat_session(
+ session_id: str,
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ """Desktop: delete a chat session and cascade-delete its messages."""
+ session = chat_db.get_chat_session_by_id(uid, session_id)
+ if not session:
+ raise HTTPException(status_code=404, detail="Chat session not found")
+
+ chat_db.delete_chat_session_messages(uid, session_id)
+ chat_db.delete_chat_session(uid, session_id)
+ return StatusResponse(status='ok')
+
+
+@router.post('/v2/messages/save', response_model=SaveMessageResponse, tags=['chat'])
+def save_message(
+ request: SaveMessageRequest,
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ """Desktop: persist a message without triggering LLM pipeline.
+
+ The desktop app runs AI locally via ACP Bridge and only calls this
+ endpoint to sync human + AI messages to Firestore.
+ """
+ if not request.text or not request.text.strip():
+ raise HTTPException(status_code=422, detail="Message text cannot be empty")
+ if request.sender not in ('human', 'ai'):
+ raise HTTPException(status_code=422, detail="sender must be 'human' or 'ai'")
+
+ now = datetime.now(timezone.utc)
+ message_id = str(uuid.uuid4())
+ message_data = {
+ 'id': message_id,
+ 'text': request.text,
+ 'created_at': now,
+ 'sender': request.sender,
+ 'app_id': request.app_id,
+ 'plugin_id': request.app_id,
+ 'session_id': request.session_id,
+ 'chat_session_id': request.session_id,
+ 'rating': None,
+ 'reported': False,
+ 'type': 'text',
+ 'memories_id': [],
+ 'from_external_integration': False,
+ 'metadata': request.metadata,
+ }
+ chat_db.save_message(uid, message_data)
+
+ if request.session_id:
+ try:
+ chat_db.add_message_to_chat_session(uid, request.session_id, message_id, preview=request.text[:200])
+ except Exception as e:
+ logger.warning(f"Failed to link message to session {request.session_id}: {e}")
+
+ return SaveMessageResponse(id=message_id, created_at=now)
+
+
+@router.patch('/v2/messages/{message_id}/rating', response_model=StatusResponse, tags=['chat'])
+def rate_message(
+ message_id: str,
+ request: RateMessageRequest,
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ """Desktop: rate a message (1 = thumbs up, -1 = thumbs down, null = clear)."""
+ if request.rating is not None and request.rating not in (1, -1):
+ raise HTTPException(status_code=422, detail="rating must be 1, -1, or null")
+
+ success = chat_db.update_message_rating(uid, message_id, request.rating)
+ if not success:
+ raise HTTPException(status_code=404, detail="Message not found")
+ return StatusResponse(status='ok')
+
+
+class TitleMessageInput(BaseModel):
+ text: str
+ sender: str
+
+
+class GenerateTitleRequest(BaseModel):
+ session_id: str
+ messages: List[TitleMessageInput]
+
+
+class GenerateTitleResponse(BaseModel):
+ title: str
+
+
+@router.post('/v2/chat/generate-title', response_model=GenerateTitleResponse, tags=['chat'])
+def generate_chat_title(
+ request: GenerateTitleRequest,
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ """Desktop: generate a short title for a chat session from its messages."""
+ if not request.messages:
+ raise HTTPException(status_code=400, detail="messages list cannot be empty")
+
+ transcript = '\n'.join(f'{m.sender}: {m.text[:500]}' for m in request.messages[:10])
+ prompt = (
+ 'Generate a short chat session title (max 6 words) summarising this conversation. '
+ 'Return ONLY the title text, no quotes or punctuation.\n\n' + transcript
+ )
+ try:
+ result = llm_mini.invoke(prompt)
+ title = result.content.strip().strip('"\'')[:100]
+ except Exception as e:
+ logger.warning(f'generate_chat_title LLM failed: {e}')
+ title = request.messages[0].text[:50]
+
+ # Update session title if session exists
+ try:
+ chat_db.update_chat_session(uid, request.session_id, {'title': title, 'updated_at': datetime.now(timezone.utc)})
+ except Exception as e:
+ logger.warning(f'generate_chat_title update session failed: {e}')
+
+ return GenerateTitleResponse(title=title)
+
+
# CLEANUP: Remove after new app goes to prod ----------------------------------------------------------
diff --git a/backend/routers/conversations.py b/backend/routers/conversations.py
index 73394980f3..f58fe6e386 100644
--- a/backend/routers/conversations.py
+++ b/backend/routers/conversations.py
@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
from typing import Optional, List
-from datetime import datetime, timezone
+from datetime import datetime, timezone, timedelta
import database.conversations as conversations_db
import database.action_items as action_items_db
@@ -12,8 +12,10 @@
CalendarMeetingContext,
Conversation,
ConversationPhoto,
+ ConversationSource,
ConversationStatus,
ConversationVisibility,
+ CreateConversation,
CreateConversationResponse,
Geolocation,
MergeConversationsRequest,
@@ -90,6 +92,98 @@ def process_in_progress_conversation(
return CreateConversationResponse(conversation=conversation, messages=messages)
+class FromSegmentsTranscriptSegment(BaseModel):
+ text: str
+ speaker: Optional[str] = 'SPEAKER_00'
+ speaker_id: Optional[int] = None
+ is_user: bool = False
+ person_id: Optional[str] = None
+ start: float
+ end: float
+
+
+class CreateConversationFromSegmentsRequest(BaseModel):
+ transcript_segments: List[FromSegmentsTranscriptSegment]
+ source: Optional[ConversationSource] = ConversationSource.desktop
+ started_at: Optional[datetime] = None
+ finished_at: Optional[datetime] = None
+ language: Optional[str] = 'en'
+ geolocation: Optional[Geolocation] = None
+
+
+class FromSegmentsResponse(BaseModel):
+ id: str
+ status: str
+ discarded: bool
+
+
+@router.post("/v1/conversations/from-segments", response_model=FromSegmentsResponse, tags=['conversations'])
+def create_conversation_from_segments(
+ request: CreateConversationFromSegmentsRequest,
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ if not request.transcript_segments:
+ raise HTTPException(status_code=422, detail="transcript_segments cannot be empty")
+
+ if len(request.transcript_segments) > 500:
+ raise HTTPException(status_code=422, detail="Maximum 500 transcript segments allowed")
+
+ for idx, segment in enumerate(request.transcript_segments):
+ if segment.end <= segment.start:
+ raise HTTPException(status_code=422, detail=f"Segment {idx}: end time must be after start time")
+ if segment.start < 0:
+ raise HTTPException(status_code=422, detail=f"Segment {idx}: start time cannot be negative")
+ if not segment.text or len(segment.text.strip()) == 0:
+ raise HTTPException(status_code=422, detail=f"Segment {idx}: text cannot be empty")
+
+ transcript_segments = [
+ TranscriptSegment(
+ text=seg.text.strip(),
+ speaker=seg.speaker or 'SPEAKER_00',
+ speaker_id=seg.speaker_id,
+ is_user=seg.is_user,
+ person_id=seg.person_id,
+ start=seg.start,
+ end=seg.end,
+ )
+ for seg in request.transcript_segments
+ ]
+
+ started_at = request.started_at or datetime.now(timezone.utc)
+ if request.finished_at is not None:
+ finished_at = request.finished_at
+ else:
+ last_segment = request.transcript_segments[-1]
+ finished_at = started_at + timedelta(seconds=last_segment.end)
+
+ if finished_at <= started_at:
+ raise HTTPException(status_code=422, detail="finished_at must be after started_at")
+
+ geolocation = request.geolocation
+ if geolocation and not geolocation.google_place_id:
+ try:
+ geolocation = get_google_maps_location(geolocation.latitude, geolocation.longitude)
+ except Exception as e:
+ logger.error(f"Error enriching geolocation: {e}")
+
+ create_conversation_obj = CreateConversation(
+ transcript_segments=transcript_segments,
+ started_at=started_at,
+ finished_at=finished_at,
+ language=request.language or 'en',
+ geolocation=geolocation,
+ source=request.source or ConversationSource.desktop,
+ )
+
+ conversation = process_conversation(uid, request.language or 'en', create_conversation_obj)
+
+ return FromSegmentsResponse(
+ id=conversation.id,
+ status=conversation.status.value if conversation.status else 'completed',
+ discarded=conversation.discarded,
+ )
+
+
@router.post('/v1/conversations/{conversation_id}/reprocess', response_model=Conversation, tags=['conversations'])
def reprocess_conversation(
conversation_id: str,
@@ -155,6 +249,23 @@ def get_conversations(
return conversations
+@router.get('/v1/conversations/count', tags=['conversations'])
+def get_conversations_count(
+ statuses: Optional[str] = Query("processing,completed"),
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ """Count conversations matching optional status filters."""
+ status_list = [s.strip() for s in statuses.split(',') if s.strip()] if statuses else []
+ if len(status_list) > 10:
+ raise HTTPException(status_code=400, detail="Too many status values (max 10)")
+ try:
+ count = conversations_db.count_conversations(uid, statuses=status_list)
+ except Exception as e:
+ logger.warning(f'count_conversations aggregation fallback: {e}')
+ count = sum(1 for _ in conversations_db.stream_conversations(uid, statuses=status_list))
+ return {'count': count}
+
+
@router.get("/v1/conversations/{conversation_id}", response_model=Conversation, tags=['conversations'])
def get_conversation_by_id(conversation_id: str, uid: str = Depends(auth.get_current_user_uid)):
logger.info(f'get_conversation_by_id {uid} {conversation_id}')
diff --git a/backend/routers/focus_sessions.py b/backend/routers/focus_sessions.py
new file mode 100644
index 0000000000..db19622e86
--- /dev/null
+++ b/backend/routers/focus_sessions.py
@@ -0,0 +1,155 @@
+import logging
+from collections import defaultdict
+from datetime import datetime, timezone
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from pydantic import BaseModel, Field
+
+import database.focus_sessions as focus_sessions_db
+from utils.other import endpoints as auth
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+
+class CreateFocusSessionRequest(BaseModel):
+ status: str = Field(description="'focused' or 'distracted'")
+ app_or_site: str = Field(description="App or website name")
+ description: str = Field(description="Brief description of the session")
+ message: Optional[str] = Field(default=None, description="Optional coaching message")
+ duration_seconds: Optional[int] = Field(default=None, description="Optional session duration in seconds")
+
+
+class FocusSessionResponse(BaseModel):
+ id: str
+ status: str
+ app_or_site: str
+ description: str
+ message: Optional[str] = None
+ created_at: datetime
+ duration_seconds: Optional[int] = None
+
+
+class FocusSessionStatusResponse(BaseModel):
+ status: str
+
+
+class DistractionEntry(BaseModel):
+ app_or_site: str
+ total_seconds: int
+ count: int
+
+
+class FocusStatsResponse(BaseModel):
+ date: str
+ focused_minutes: int
+ distracted_minutes: int
+ session_count: int
+ focused_count: int
+ distracted_count: int
+ top_distractions: List[DistractionEntry]
+
+
+def _validate_focus_status(status: str):
+ if status not in ('focused', 'distracted'):
+ raise HTTPException(status_code=400, detail="status must be 'focused' or 'distracted'")
+
+
+@router.post('/v1/focus-sessions', response_model=FocusSessionResponse, tags=['focus-sessions'])
+def create_focus_session(
+ request: CreateFocusSessionRequest,
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ _validate_focus_status(request.status)
+ try:
+ session = focus_sessions_db.create_focus_session(uid, request.model_dump())
+ return session
+ except Exception:
+ logger.exception('Failed to create focus session for uid=%s', uid)
+ raise HTTPException(status_code=500, detail="Failed to create focus session")
+
+
+@router.get('/v1/focus-sessions', response_model=List[FocusSessionResponse], tags=['focus-sessions'])
+def get_focus_sessions(
+ limit: int = Query(default=100, ge=1, le=1000),
+ offset: int = Query(default=0, ge=0),
+ date: Optional[str] = Query(default=None, description="Filter by date (YYYY-MM-DD)"),
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ if date:
+ try:
+ datetime.strptime(date, '%Y-%m-%d')
+ except ValueError:
+ date = None # Skip invalid date filter (match Rust behavior)
+ try:
+ return focus_sessions_db.get_focus_sessions(uid, limit=limit, offset=offset, date=date)
+ except Exception:
+ logger.exception('Failed to get focus sessions for uid=%s', uid)
+ return []
+
+
+@router.delete('/v1/focus-sessions/{session_id}', response_model=FocusSessionStatusResponse, tags=['focus-sessions'])
+def delete_focus_session(
+ session_id: str,
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ try:
+ focus_sessions_db.delete_focus_session(uid, session_id)
+ return FocusSessionStatusResponse(status="ok")
+ except Exception:
+ logger.exception('Failed to delete focus session %s for uid=%s', session_id, uid)
+ raise HTTPException(status_code=500, detail="Failed to delete focus session")
+
+
+@router.get('/v1/focus-stats', response_model=FocusStatsResponse, tags=['focus-sessions'])
+def get_focus_stats(
+ date: Optional[str] = Query(default=None, description="Date for stats (YYYY-MM-DD), defaults to today"),
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ if date:
+ try:
+ datetime.strptime(date, '%Y-%m-%d')
+ except ValueError:
+ date = None # Skip invalid date filter (match Rust behavior)
+ if not date:
+ date = datetime.now(timezone.utc).strftime('%Y-%m-%d')
+
+ try:
+ sessions = focus_sessions_db.get_focus_sessions_for_stats(uid, date)
+ except Exception:
+ logger.exception('Failed to get focus stats for uid=%s', uid)
+ raise HTTPException(status_code=500, detail="Failed to get focus stats")
+
+ focused_count = 0
+ distracted_count = 0
+ distraction_map = defaultdict(lambda: {'total_seconds': 0, 'count': 0})
+
+ for s in sessions:
+ status = s.get('status', '')
+ if status == 'focused':
+ focused_count += 1
+ elif status == 'distracted':
+ distracted_count += 1
+ app = s.get('app_or_site', 'Unknown')
+ raw_duration = s.get('duration_seconds')
+ duration = raw_duration if raw_duration is not None else 60
+ distraction_map[app]['total_seconds'] += duration
+ distraction_map[app]['count'] += 1
+
+ top_distractions = sorted(
+ [DistractionEntry(app_or_site=app, **vals) for app, vals in distraction_map.items()],
+ key=lambda d: d.total_seconds,
+ reverse=True,
+ )[:5]
+
+ return FocusStatsResponse(
+ date=date,
+ focused_minutes=focused_count,
+ distracted_minutes=distracted_count,
+ session_count=focused_count + distracted_count,
+ focused_count=focused_count,
+ distracted_count=distracted_count,
+ top_distractions=top_distractions,
+ )
diff --git a/backend/routers/screen_activity.py b/backend/routers/screen_activity.py
new file mode 100644
index 0000000000..81105bdddc
--- /dev/null
+++ b/backend/routers/screen_activity.py
@@ -0,0 +1,74 @@
+import logging
+import threading
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException
+from pydantic import BaseModel, Field
+
+import database.screen_activity as screen_activity_db
+import database.vector_db as vector_db
+from utils.other import endpoints as auth
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+
+class ScreenActivityRow(BaseModel):
+ id: int = Field(description="Screenshot ID (used as Firestore document ID)")
+ timestamp: str = Field(description="Timestamp in RFC3339 or 'YYYY-MM-DD HH:MM:SS' format")
+ appName: str = Field(default='', description="Application name")
+ windowTitle: str = Field(default='', description="Window title")
+ ocrText: str = Field(default='', description="OCR text from screenshot (truncated to 1000 chars)")
+ embedding: Optional[List[float]] = Field(default=None, description="Optional vector embedding (3072-dim Gemini)")
+
+
+class ScreenActivitySyncRequest(BaseModel):
+ rows: List[ScreenActivityRow]
+
+
+class ScreenActivitySyncResponse(BaseModel):
+ synced: int = Field(description="Number of rows written to Firestore")
+ last_id: int = Field(description="Maximum row ID from the batch")
+
+
+@router.post('/v1/screen-activity/sync', response_model=ScreenActivitySyncResponse, tags=['screen-activity'])
+def sync_screen_activity(
+ request: ScreenActivitySyncRequest,
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ if len(request.rows) > 100:
+ raise HTTPException(status_code=400, detail="Maximum 100 rows per batch")
+
+ if not request.rows:
+ return ScreenActivitySyncResponse(synced=0, last_id=0)
+
+ # Convert Pydantic models to dicts for database layer
+ rows_data = [row.model_dump() for row in request.rows]
+
+ # Firestore upsert (synchronous — blocks response until written)
+ try:
+ synced = screen_activity_db.upsert_screen_activity(uid, rows_data)
+ except Exception:
+ logger.exception('Firestore upsert failed for uid=%s', uid)
+ raise HTTPException(status_code=500, detail="Failed to sync screen activity")
+
+ # Pinecone vector upsert (fire-and-forget background thread)
+ rows_with_embeddings = [r for r in rows_data if r.get('embedding')]
+ if rows_with_embeddings:
+ thread = threading.Thread(
+ target=_upsert_vectors_background,
+ args=(uid, rows_with_embeddings),
+ daemon=True,
+ )
+ thread.start()
+
+ last_id = max(row.id for row in request.rows)
+ return ScreenActivitySyncResponse(synced=synced, last_id=last_id)
+
+
+def _upsert_vectors_background(uid: str, rows: list):
+ try:
+ vector_db.upsert_screen_activity_vectors(uid, rows)
+ except Exception:
+ logger.exception('Failed to upsert screen activity vectors for uid=%s', uid)
diff --git a/backend/routers/staged_tasks.py b/backend/routers/staged_tasks.py
new file mode 100644
index 0000000000..9078187bc7
--- /dev/null
+++ b/backend/routers/staged_tasks.py
@@ -0,0 +1,295 @@
+"""Desktop staged tasks endpoints.
+
+Staged tasks are AI-extracted action items ranked by relevance_score.
+The top-ranked task can be promoted to action_items (max 5 active AI tasks).
+Deduplication prevents promoting tasks whose description already exists in active action_items.
+"""
+
+import logging
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from pydantic import BaseModel, Field, field_validator
+from typing import Optional, List
+from datetime import datetime, timedelta
+
+import database.staged_tasks as staged_tasks_db
+from utils.other import endpoints as auth
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+
+# --- Models ---
+
+
+class CreateStagedTaskRequest(BaseModel):
+ description: str = Field(..., min_length=1, max_length=2000)
+ due_at: Optional[datetime] = None
+ source: Optional[str] = None
+ priority: Optional[str] = None
+ metadata: Optional[str] = None
+ category: Optional[str] = None
+ relevance_score: Optional[int] = None
+
+ @field_validator('description')
+ @classmethod
+ def description_not_blank(cls, v):
+ if not v.strip():
+ raise ValueError('description must not be blank')
+ return v
+
+
+class StagedTaskResponse(BaseModel):
+ id: str
+ description: str
+ completed: bool = False
+ created_at: Optional[datetime] = None
+ updated_at: Optional[datetime] = None
+ due_at: Optional[datetime] = None
+ source: Optional[str] = None
+ priority: Optional[str] = None
+ metadata: Optional[str] = None
+ category: Optional[str] = None
+ relevance_score: Optional[int] = None
+
+
+class StagedTasksListResponse(BaseModel):
+ items: List[StagedTaskResponse]
+ has_more: bool
+
+
+class StatusResponse(BaseModel):
+ status: str
+
+
+class ScoreUpdate(BaseModel):
+ id: str
+ relevance_score: int
+
+
+class BatchUpdateScoresRequest(BaseModel):
+ scores: List[ScoreUpdate] = Field(..., min_length=1, max_length=500)
+
+
+class PromoteResponse(BaseModel):
+ promoted: bool
+ reason: Optional[str] = None
+ promoted_task: Optional[StagedTaskResponse] = None
+
+
+# --- Endpoints ---
+
+
+# --- Desktop staged tasks ---
+
+
+@router.post('/v1/staged-tasks', response_model=StagedTaskResponse, tags=['staged-tasks'])
+def create_staged_task(request: CreateStagedTaskRequest, uid: str = Depends(auth.get_current_user_uid)):
+ """Create a new staged task."""
+ data = {
+ 'description': request.description.strip(),
+ 'source': request.source,
+ 'priority': request.priority,
+ 'metadata': request.metadata,
+ 'category': request.category,
+ 'relevance_score': request.relevance_score,
+ }
+ if request.due_at:
+ data['due_at'] = request.due_at
+
+ result = staged_tasks_db.create_staged_task(uid, data)
+ return StagedTaskResponse(**result)
+
+
+@router.get('/v1/staged-tasks', response_model=StagedTasksListResponse, tags=['staged-tasks'])
+def get_staged_tasks(
+ limit: int = Query(default=100, ge=1, le=500),
+ offset: int = Query(default=0, ge=0),
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ """List staged tasks ordered by relevance_score ASC (best ranked first)."""
+ items, has_more = staged_tasks_db.get_staged_tasks(uid, limit=limit, offset=offset)
+ return StagedTasksListResponse(
+ items=[StagedTaskResponse(**item) for item in items],
+ has_more=has_more,
+ )
+
+
+@router.delete('/v1/staged-tasks/{task_id}', response_model=StatusResponse, tags=['staged-tasks'])
+def delete_staged_task(task_id: str, uid: str = Depends(auth.get_current_user_uid)):
+ """Hard-delete a staged task. Idempotent — returns ok even if not found (matches Rust)."""
+ staged_tasks_db.delete_staged_task(uid, task_id)
+ return StatusResponse(status='ok')
+
+
+@router.patch('/v1/staged-tasks/batch-scores', response_model=StatusResponse, tags=['staged-tasks'])
+def batch_update_scores(request: BatchUpdateScoresRequest, uid: str = Depends(auth.get_current_user_uid)):
+ """Batch update relevance scores for staged tasks."""
+ scores = [{'id': s.id, 'relevance_score': s.relevance_score} for s in request.scores]
+ staged_tasks_db.batch_update_scores(uid, scores)
+ return StatusResponse(status='ok')
+
+
+@router.post('/v1/staged-tasks/promote', response_model=PromoteResponse, tags=['staged-tasks'])
+def promote_staged_task(uid: str = Depends(auth.get_current_user_uid)):
+ """Promote the top-ranked staged task to action_items.
+
+ Rules:
+ - Max 5 active AI tasks (from_staged=true, not completed, not deleted)
+ - Skips duplicates (case-insensitive description match, strips [screen] prefix/suffix)
+ - Deletes duplicate staged tasks found during scan
+ - Hard-deletes the promoted task from staged_tasks
+ """
+ # Step 1: Check active AI task count
+ active_items = staged_tasks_db.get_active_ai_action_items(uid)
+ if len(active_items) >= 5:
+ return PromoteResponse(
+ promoted=False,
+ reason=f'Already have {len(active_items)} active AI tasks (max 5)',
+ )
+
+ # Build dedup set from existing descriptions
+ existing_descriptions = set()
+ for item in active_items:
+ desc = item.get('description', '')
+ normalized = desc.strip().removeprefix('[screen] ').removesuffix(' [screen]').lower()
+ existing_descriptions.add(normalized)
+
+ # Step 2: Get top-ranked staged tasks (batch of 20 for dedup scanning)
+ staged_items, _ = staged_tasks_db.get_staged_tasks(uid, limit=20, offset=0)
+ if not staged_items:
+ return PromoteResponse(promoted=False, reason='No staged tasks available')
+
+ # Step 3: Find first non-duplicate, collecting duplicates to delete
+ selected_task = None
+ seen_descriptions = set()
+ duplicate_ids = []
+
+ for task in staged_items:
+ normalized = task.get('description', '').strip().removeprefix('[screen] ').removesuffix(' [screen]').lower()
+ if normalized in existing_descriptions or normalized in seen_descriptions:
+ duplicate_ids.append(task['id'])
+ continue
+ seen_descriptions.add(normalized)
+ if selected_task is None:
+ selected_task = task
+
+ # Clean up duplicates
+ if duplicate_ids:
+ staged_tasks_db.delete_staged_tasks_batch(uid, duplicate_ids)
+ logger.info(f'Cleaned up {len(duplicate_ids)} duplicate staged tasks for user {uid}')
+
+ if selected_task is None:
+ return PromoteResponse(promoted=False, reason='All candidate staged tasks are duplicates')
+
+ # Step 4: Promote to action_items
+ promoted_item = staged_tasks_db.promote_staged_task(uid, selected_task)
+
+ # Step 5: Hard-delete from staged_tasks
+ staged_tasks_db.delete_staged_task(uid, selected_task['id'])
+
+ logger.info(f'Promoted staged task {selected_task["id"]} -> action item {promoted_item["id"]} for user {uid}')
+
+ return PromoteResponse(promoted=True, promoted_task=StagedTaskResponse(**promoted_item))
+
+
+# --- Desktop daily scores ---
+
+
+class DailyScoreResponse(BaseModel):
+ score: float
+ completed_tasks: int
+ total_tasks: int
+ date: str
+
+
+class ScoreData(BaseModel):
+ score: float
+ completed_tasks: int
+ total_tasks: int
+
+
+class ScoresResponse(BaseModel):
+ daily: ScoreData
+ weekly: ScoreData
+ overall: ScoreData
+ default_tab: str
+ date: str
+
+
+@router.get('/v1/daily-score', response_model=DailyScoreResponse, tags=['scores'])
+def get_daily_score(
+ date: Optional[str] = Query(default=None, description='Date in YYYY-MM-DD format'),
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ """Calculate daily score from action items due today (legacy endpoint)."""
+ if date:
+ try:
+ parsed = datetime.strptime(date, '%Y-%m-%d').date()
+ except ValueError:
+ raise HTTPException(status_code=400, detail='Invalid date format, use YYYY-MM-DD')
+ else:
+ parsed = datetime.now().date()
+
+ date_str = parsed.strftime('%Y-%m-%d')
+ due_start = f'{date_str}T00:00:00Z'
+ due_end = f'{date_str}T23:59:59.999Z'
+
+ completed, total = staged_tasks_db.get_action_items_for_daily_score(uid, due_start, due_end)
+ score = (completed / total * 100.0) if total > 0 else 0.0
+
+ return DailyScoreResponse(score=score, completed_tasks=completed, total_tasks=total, date=date_str)
+
+
+@router.get('/v1/scores', response_model=ScoresResponse, tags=['scores'])
+def get_scores(
+ date: Optional[str] = Query(default=None, description='Date in YYYY-MM-DD format'),
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ """Get daily, weekly, and overall scores with default tab selection."""
+ if date:
+ try:
+ parsed = datetime.strptime(date, '%Y-%m-%d').date()
+ except ValueError:
+ raise HTTPException(status_code=400, detail='Invalid date format, use YYYY-MM-DD')
+ else:
+ parsed = datetime.now().date()
+
+ date_str = parsed.strftime('%Y-%m-%d')
+
+ # Daily: tasks due today
+ today_start = f'{date_str}T00:00:00Z'
+ today_end = f'{date_str}T23:59:59.999Z'
+ daily_completed, daily_total = staged_tasks_db.get_action_items_for_daily_score(uid, today_start, today_end)
+
+ # Weekly: last 7 days
+ week_ago = parsed - timedelta(days=7)
+ week_start = f'{week_ago.strftime("%Y-%m-%d")}T00:00:00Z'
+ weekly_completed, weekly_total = staged_tasks_db.get_action_items_for_weekly_score(uid, week_start, today_end)
+
+ # Overall: all time
+ overall_completed, overall_total = staged_tasks_db.get_action_items_for_overall_score(uid)
+
+ def calc_score(completed, total):
+ return (completed / total * 100.0) if total > 0 else 0.0
+
+ daily = ScoreData(
+ score=calc_score(daily_completed, daily_total), completed_tasks=daily_completed, total_tasks=daily_total
+ )
+ weekly = ScoreData(
+ score=calc_score(weekly_completed, weekly_total), completed_tasks=weekly_completed, total_tasks=weekly_total
+ )
+ overall = ScoreData(
+ score=calc_score(overall_completed, overall_total), completed_tasks=overall_completed, total_tasks=overall_total
+ )
+
+ # Default tab: highest score, prefer daily if tied
+ if daily.total_tasks > 0 and daily.score >= weekly.score and daily.score >= overall.score:
+ default_tab = 'daily'
+ elif weekly.score >= overall.score:
+ default_tab = 'weekly'
+ else:
+ default_tab = 'overall'
+
+ return ScoresResponse(daily=daily, weekly=weekly, overall=overall, default_tab=default_tab, date=date_str)
diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py
index 24914a7702..fbd85413a6 100644
--- a/backend/routers/transcribe.py
+++ b/backend/routers/transcribe.py
@@ -51,16 +51,24 @@
TranscriptSegment,
)
from models.message_event import (
+ AdviceExtractedEvent,
ConversationEvent,
+ DedupCompleteEvent,
+ FocusResultEvent,
FREEMIUM_ACTION_SETUP_ON_DEVICE_STT,
FreemiumThresholdReachedEvent,
LastConversationEvent,
+ LiveNoteEvent,
+ MemoriesExtractedEvent,
MessageEvent,
MessageServiceStatusEvent,
PhotoDescribedEvent,
PhotoProcessingEvent,
+ ProfileUpdatedEvent,
+ RerankCompleteEvent,
SegmentsDeletedEvent,
SpeakerLabelSuggestionEvent,
+ TasksExtractedEvent,
TranslationEvent,
)
from models.transcript_segment import Translation
@@ -101,6 +109,13 @@
SPEAKER_MATCH_THRESHOLD,
)
from utils.speaker_sample_migration import maybe_migrate_person_samples
+from utils.desktop.advice import generate_advice
+from utils.desktop.focus import analyze_focus
+from utils.desktop.live_notes import generate_live_note
+from utils.desktop.memories import extract_memories
+from utils.desktop.profile import generate_profile
+from utils.desktop.task_ops import dedup_tasks, rerank_tasks
+from utils.desktop.tasks import extract_tasks
from utils.log_sanitizer import sanitize, sanitize_pii
logger = logging.getLogger(__name__)
@@ -2431,6 +2446,98 @@ async def close_soniox_profile():
logger.info(
f"Speaker assignment ignored: missing speaker_id/person_id/person_name. {uid} {session_id}"
)
+ # Desktop proactive AI — screen_frame analysis (#5396)
+ elif json_data.get('type') == 'screen_frame':
+ frame_id = json_data.get('frame_id', '')
+ image_b64 = json_data.get('image_b64', '')
+ analyze_types = json_data.get('analyze', [])
+ sf_app = json_data.get('app_name', '')
+ sf_wtitle = json_data.get('window_title', '')
+ if not image_b64:
+ logger.warning(f"screen_frame missing image_b64 {uid} {session_id}")
+ else:
+ # Fan out to parallel handlers per analyze type
+ if 'focus' in analyze_types:
+ async def _handle_focus(fid, img, app, wtitle):
+ try:
+ result = await analyze_focus(uid=uid, image_b64=img, app_name=app, window_title=wtitle)
+ _send_message_event(FocusResultEvent(
+ frame_id=fid, status=result['status'], app_or_site=result['app_or_site'],
+ description=result['description'], message=result.get('message'),
+ ))
+ except Exception as e:
+ logger.error(f"Focus analysis failed: {e} {uid} {session_id}")
+ spawn(_handle_focus(frame_id, image_b64, sf_app, sf_wtitle))
+
+ if 'tasks' in analyze_types:
+ async def _handle_tasks(fid, img, app, wtitle):
+ try:
+ result = await extract_tasks(uid=uid, image_b64=img, app_name=app, window_title=wtitle)
+ _send_message_event(TasksExtractedEvent(frame_id=fid, tasks=result.get('tasks', [])))
+ except Exception as e:
+ logger.error(f"Task extraction failed: {e} {uid} {session_id}")
+ spawn(_handle_tasks(frame_id, image_b64, sf_app, sf_wtitle))
+
+ if 'memories' in analyze_types:
+ async def _handle_memories(fid, img, app, wtitle):
+ try:
+ result = await extract_memories(uid=uid, image_b64=img, app_name=app, window_title=wtitle)
+ _send_message_event(MemoriesExtractedEvent(frame_id=fid, memories=result.get('memories', [])))
+ except Exception as e:
+ logger.error(f"Memory extraction failed: {e} {uid} {session_id}")
+ spawn(_handle_memories(frame_id, image_b64, sf_app, sf_wtitle))
+
+ if 'advice' in analyze_types:
+ async def _handle_advice(fid, img, app, wtitle):
+ try:
+ result = await generate_advice(uid=uid, image_b64=img, app_name=app, window_title=wtitle)
+ _send_message_event(AdviceExtractedEvent(
+ frame_id=fid, advice=result.get('advice'),
+ ))
+ except Exception as e:
+ logger.error(f"Advice generation failed: {e} {uid} {session_id}")
+ spawn(_handle_advice(frame_id, image_b64, sf_app, sf_wtitle))
+
+ # Desktop proactive AI — text-only message types (#5396)
+ elif json_data.get('type') == 'live_notes_text':
+ async def _handle_live_notes(text, ctx):
+ try:
+ result = await generate_live_note(text=text, session_context=ctx)
+ if result.get('text'):
+ _send_message_event(LiveNoteEvent(text=result['text']))
+ except Exception as e:
+ logger.error(f"Live note generation failed: {e} {uid} {session_id}")
+ spawn(_handle_live_notes(json_data.get('text', ''), json_data.get('session_context', '')))
+
+ elif json_data.get('type') == 'profile_request':
+ async def _handle_profile():
+ try:
+ result = await generate_profile(uid=uid)
+ _send_message_event(ProfileUpdatedEvent(profile_text=result['profile_text']))
+ except Exception as e:
+ logger.error(f"Profile generation failed: {e} {uid} {session_id}")
+ spawn(_handle_profile())
+
+ elif json_data.get('type') == 'task_rerank':
+ async def _handle_rerank():
+ try:
+ result = await rerank_tasks(uid=uid)
+ _send_message_event(RerankCompleteEvent(updated_tasks=result['updated_tasks']))
+ except Exception as e:
+ logger.error(f"Task reranking failed: {e} {uid} {session_id}")
+ spawn(_handle_rerank())
+
+ elif json_data.get('type') == 'task_dedup':
+ async def _handle_dedup():
+ try:
+ result = await dedup_tasks(uid=uid)
+ _send_message_event(DedupCompleteEvent(
+ deleted_ids=result['deleted_ids'], reason=result['reason'],
+ ))
+ except Exception as e:
+ logger.error(f"Task dedup failed: {e} {uid} {session_id}")
+ spawn(_handle_dedup())
+
except json.JSONDecodeError:
logger.info(
f"Received non-json text message: {sanitize(message.get('text'))} {uid} {session_id}"
diff --git a/backend/routers/users.py b/backend/routers/users.py
index 5958f7261d..62a274c856 100644
--- a/backend/routers/users.py
+++ b/backend/routers/users.py
@@ -1,4 +1,5 @@
import json
+import re
import threading
import uuid
from typing import List, Dict, Any, Union, Optional
@@ -1193,3 +1194,141 @@ def generate():
media_type='application/json',
headers={'Content-Disposition': 'attachment; filename="omi-export.json"'},
)
+
+
+# **************************************
+# ****** Assistant Settings ************
+# **************************************
+
+
+class SharedAssistantSettings(BaseModel):
+ cooldown_interval: Optional[int] = None
+ glow_overlay_enabled: Optional[bool] = None
+ analysis_delay: Optional[int] = None
+ screen_analysis_enabled: Optional[bool] = None
+
+
+class FocusSettings(BaseModel):
+ enabled: Optional[bool] = None
+ analysis_prompt: Optional[str] = None
+ cooldown_interval: Optional[int] = None
+ notifications_enabled: Optional[bool] = None
+ excluded_apps: Optional[List[str]] = None
+
+
+class TaskSettings(BaseModel):
+ enabled: Optional[bool] = None
+ analysis_prompt: Optional[str] = None
+ extraction_interval: Optional[float] = None
+ min_confidence: Optional[float] = None
+ notifications_enabled: Optional[bool] = None
+ allowed_apps: Optional[List[str]] = None
+ browser_keywords: Optional[List[str]] = None
+
+
+class AdviceSettings(BaseModel):
+ enabled: Optional[bool] = None
+ analysis_prompt: Optional[str] = None
+ extraction_interval: Optional[float] = None
+ min_confidence: Optional[float] = None
+ notifications_enabled: Optional[bool] = None
+ excluded_apps: Optional[List[str]] = None
+
+
+class MemorySettings(BaseModel):
+ enabled: Optional[bool] = None
+ analysis_prompt: Optional[str] = None
+ extraction_interval: Optional[float] = None
+ min_confidence: Optional[float] = None
+ notifications_enabled: Optional[bool] = None
+ excluded_apps: Optional[List[str]] = None
+
+
+class AssistantSettingsData(BaseModel):
+ shared: Optional[SharedAssistantSettings] = None
+ focus: Optional[FocusSettings] = None
+ task: Optional[TaskSettings] = None
+ advice: Optional[AdviceSettings] = None
+ memory: Optional[MemorySettings] = None
+ update_channel: Optional[str] = None
+
+
+def _validate_assistant_settings(data: AssistantSettingsData):
+ """Validate prompt lengths and list sizes matching Rust backend limits."""
+ for section_name in ('focus', 'task', 'advice', 'memory'):
+ section = getattr(data, section_name, None)
+ if section and section.analysis_prompt and len(section.analysis_prompt) > 10000:
+ raise HTTPException(status_code=400, detail=f'{section_name}.analysis_prompt exceeds 10000 chars')
+
+ if data.task:
+ if data.task.allowed_apps and len(data.task.allowed_apps) > 500:
+ raise HTTPException(status_code=400, detail='task.allowed_apps exceeds 500 items')
+ if data.task.browser_keywords and len(data.task.browser_keywords) > 500:
+ raise HTTPException(status_code=400, detail='task.browser_keywords exceeds 500 items')
+ if data.task.min_confidence is not None and not (0.0 <= data.task.min_confidence <= 1.0):
+ raise HTTPException(status_code=400, detail='task.min_confidence must be between 0.0 and 1.0')
+
+ if data.advice and data.advice.min_confidence is not None and not (0.0 <= data.advice.min_confidence <= 1.0):
+ raise HTTPException(status_code=400, detail='advice.min_confidence must be between 0.0 and 1.0')
+
+ if data.memory and data.memory.min_confidence is not None and not (0.0 <= data.memory.min_confidence <= 1.0):
+ raise HTTPException(status_code=400, detail='memory.min_confidence must be between 0.0 and 1.0')
+
+
+@router.get('/v1/users/assistant-settings', tags=['users'])
+def get_assistant_settings_endpoint(uid: str = Depends(auth.get_current_user_uid)):
+ return get_assistant_settings(uid)
+
+
+@router.patch('/v1/users/assistant-settings', tags=['users'])
+def update_assistant_settings_endpoint(
+ data: AssistantSettingsData,
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ _validate_assistant_settings(data)
+ update_data = data.model_dump(exclude_none=True)
+ if not update_data:
+ return get_assistant_settings(uid)
+ return update_assistant_settings(uid, update_data)
+
+
+# **************************************
+# ******** AI User Profile *************
+# **************************************
+
+
+class UpdateAIProfileRequest(BaseModel):
+ profile_text: str
+ generated_at: str
+ data_sources_used: int
+
+
+@router.get('/v1/users/ai-profile', tags=['users'])
+def get_ai_profile_endpoint(uid: str = Depends(auth.get_current_user_uid)):
+ return get_ai_user_profile(uid)
+
+
+@router.patch('/v1/users/ai-profile', tags=['users'])
+def update_ai_profile_endpoint(
+ data: UpdateAIProfileRequest,
+ uid: str = Depends(auth.get_current_user_uid),
+):
+ # Strict RFC3339: YYYY-MM-DDTHH:MM:SS[.frac](Z|+HH:MM|-HH:MM)
+ ts = data.generated_at
+ if not re.fullmatch(r'\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(\.\d+)?(Z|[+-]\d{2}:\d{2})', ts):
+ raise HTTPException(status_code=400, detail="generated_at must be a valid RFC3339 timestamp")
+ try:
+ parsed_ts = datetime.fromisoformat(ts.replace('Z', '+00:00'))
+ except (ValueError, AttributeError):
+ raise HTTPException(status_code=400, detail="generated_at must be a valid RFC3339 timestamp")
+
+ # Truncate profile_text to 10000 bytes (matching Rust behavior — truncate, not reject)
+ profile_bytes = data.profile_text.encode('utf-8')[:10000]
+ profile_text = profile_bytes.decode('utf-8', errors='ignore')
+
+ profile_data = {
+ 'profile_text': profile_text,
+ 'generated_at': parsed_ts,
+ 'data_sources_used': data.data_sources_used,
+ }
+ return update_ai_user_profile(uid, profile_data)
diff --git a/backend/templates/auth_callback.html b/backend/templates/auth_callback.html
index 7f5a820e6b..e4a591b0f7 100644
--- a/backend/templates/auth_callback.html
+++ b/backend/templates/auth_callback.html
@@ -108,17 +108,29 @@
Authentication Successful
spinnerElement.style.display = 'none';
messageElement.textContent = 'Please close this window and try again.';
} else if (code) {
- // Build the custom scheme redirect URL
- let redirectUrl = 'omi://auth/callback?code=' + encodeURIComponent(code);
+ // Build the custom scheme redirect URL using the redirect_uri from the auth session
+ const redirectUri = {{ redirect_uri|tojson }};
+
+ // Validate redirect scheme before use (defense-in-depth; server also validates)
+ const ALLOWED_SCHEMES = ['omi://', 'omi-computer://', 'omi-computer-dev://'];
+ const isAllowedScheme = ALLOWED_SCHEMES.some(s => redirectUri.startsWith(s));
+ if (!isAllowedScheme) {
+ errorElement.textContent = 'Invalid redirect scheme.';
+ spinnerElement.style.display = 'none';
+ messageElement.textContent = 'Please close this window and try again.';
+ }
+
+ let redirectUrl = redirectUri + '?code=' + encodeURIComponent(code);
if (state) {
redirectUrl += '&state=' + encodeURIComponent(state);
}
// Set up manual link
- manualLinkElement.href = redirectUrl;
+ manualLinkElement.href = isAllowedScheme ? redirectUrl : '#';
// Attempt automatic redirect
try {
+ if (!isAllowedScheme) throw new Error('Blocked redirect to disallowed scheme');
console.log('Redirecting to:', redirectUrl);
window.location.href = redirectUrl;
diff --git a/backend/test.sh b/backend/test.sh
index e431460b1e..aa5aa45627 100755
--- a/backend/test.sh
+++ b/backend/test.sh
@@ -43,3 +43,21 @@ pytest tests/unit/test_storage_upload_audio_chunk_data_protection.py -v
pytest tests/unit/test_people_conversations_500s.py -v
pytest tests/unit/test_firestore_read_ops_cache.py -v
pytest tests/unit/test_ws_auth_handshake.py -v
+pytest tests/unit/test_auth_routes.py -v
+pytest tests/unit/test_from_segments.py -v
+pytest tests/unit/test_desktop_chat.py -v
+pytest tests/unit/test_screen_activity_sync.py -v
+pytest tests/unit/test_assistant_settings_ai_profile.py -v
+pytest tests/unit/test_focus_sessions.py -v
+pytest tests/unit/test_advice.py -v
+pytest tests/unit/test_staged_tasks.py -v
+pytest tests/unit/test_chat_generate_title.py -v
+pytest tests/unit/test_conversations_count.py -v
+pytest tests/unit/test_desktop_focus.py -v
+pytest tests/unit/test_desktop_tasks.py -v
+pytest tests/unit/test_desktop_memories.py -v
+pytest tests/unit/test_desktop_advice.py -v
+pytest tests/unit/test_desktop_live_notes.py -v
+pytest tests/unit/test_desktop_profile.py -v
+pytest tests/unit/test_desktop_task_ops.py -v
+pytest tests/unit/test_agent_vm.py -v
diff --git a/backend/tests/unit/test_advice.py b/backend/tests/unit/test_advice.py
new file mode 100644
index 0000000000..657d909a90
--- /dev/null
+++ b/backend/tests/unit/test_advice.py
@@ -0,0 +1,237 @@
+import sys
+from datetime import datetime, timezone
+from unittest.mock import patch, MagicMock
+
+import pytest
+
+for mod_name in [
+ 'firebase_admin',
+ 'firebase_admin.auth',
+ 'firebase_admin.firestore',
+ 'firebase_admin.messaging',
+ 'google.cloud',
+ 'google.cloud.exceptions',
+ 'google.cloud.firestore',
+ 'google.cloud.firestore_v1',
+ 'google.cloud.firestore_v1.base_query',
+ 'google.cloud.firestore_v1.query',
+ 'google.cloud.storage',
+ 'google.cloud.storage.blob',
+ 'google.cloud.storage.bucket',
+ 'google.auth',
+ 'google.auth.transport',
+ 'google.auth.transport.requests',
+ 'google.oauth2',
+ 'google.oauth2.service_account',
+ 'pinecone',
+ 'typesense',
+]:
+ sys.modules.setdefault(mod_name, MagicMock())
+
+from routers.advice import router
+
+from fastapi import FastAPI, HTTPException
+from fastapi.testclient import TestClient
+
+
+@pytest.fixture
+def client():
+ with patch('routers.advice.auth.get_current_user_uid', return_value='uid-1'):
+ app = FastAPI()
+ app.include_router(router)
+ with TestClient(app) as c:
+ yield c
+
+
+@pytest.fixture
+def client_no_auth():
+ """Client without auth mock — for testing 401 responses."""
+ app = FastAPI()
+ app.include_router(router)
+ with TestClient(app) as c:
+ yield c
+
+
+AUTH = {"Authorization": "Bearer 123testuser"}
+
+
+class TestCreateAdvice:
+ def test_create_minimal(self, client):
+ data = {"content": "Take a break"}
+ with patch('routers.advice.advice_db.create_advice') as mock_create:
+ mock_create.return_value = {
+ "id": "adv-1", "content": "Take a break", "category": "other",
+ "confidence": 0.5, "is_read": False, "is_dismissed": False,
+ "created_at": datetime.now(timezone.utc),
+ }
+ resp = client.post("/v1/advice", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["content"] == "Take a break"
+ assert resp.json()["category"] == "other"
+
+ def test_create_with_all_fields(self, client):
+ data = {
+ "content": "Drink water", "category": "health", "reasoning": "Dehydrated",
+ "source_app": "Chrome", "confidence": 0.9, "context_summary": "Long session",
+ "current_activity": "Browsing",
+ }
+ with patch('routers.advice.advice_db.create_advice') as mock_create:
+ mock_create.return_value = {"id": "adv-2", **data, "is_read": False, "is_dismissed": False, "created_at": datetime.now(timezone.utc)}
+ resp = client.post("/v1/advice", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["category"] == "health"
+ assert resp.json()["confidence"] == 0.9
+
+ def test_create_invalid_category_returns_400(self, client):
+ data = {"content": "Test", "category": "invalid_cat"}
+ resp = client.post("/v1/advice", json=data, headers=AUTH)
+ assert resp.status_code == 400
+ assert "category" in resp.json()["detail"]
+
+ def test_create_confidence_below_zero_returns_400(self, client):
+ data = {"content": "Test", "confidence": -0.1}
+ resp = client.post("/v1/advice", json=data, headers=AUTH)
+ assert resp.status_code == 400
+
+ def test_create_confidence_above_one_returns_400(self, client):
+ data = {"content": "Test", "confidence": 1.1}
+ resp = client.post("/v1/advice", json=data, headers=AUTH)
+ assert resp.status_code == 400
+
+ def test_create_confidence_boundary_zero(self, client):
+ data = {"content": "Test", "confidence": 0.0}
+ with patch('routers.advice.advice_db.create_advice') as mock_create:
+ mock_create.return_value = {"id": "adv-3", "content": "Test", "confidence": 0.0, "category": "other", "is_read": False, "is_dismissed": False, "created_at": datetime.now(timezone.utc)}
+ resp = client.post("/v1/advice", json=data, headers=AUTH)
+ assert resp.status_code == 200
+
+ def test_create_confidence_boundary_one(self, client):
+ data = {"content": "Test", "confidence": 1.0}
+ with patch('routers.advice.advice_db.create_advice') as mock_create:
+ mock_create.return_value = {"id": "adv-4", "content": "Test", "confidence": 1.0, "category": "other", "is_read": False, "is_dismissed": False, "created_at": datetime.now(timezone.utc)}
+ resp = client.post("/v1/advice", json=data, headers=AUTH)
+ assert resp.status_code == 200
+
+ def test_create_no_auth_returns_401(self, client_no_auth):
+ with patch(
+ 'routers.advice.auth.get_current_user_uid',
+ side_effect=HTTPException(status_code=401, detail='Not authenticated'),
+ ):
+ resp = client_no_auth.post("/v1/advice", json={"content": "Test"})
+ assert resp.status_code == 401
+
+ def test_create_firestore_error_returns_500(self, client):
+ with patch('routers.advice.advice_db.create_advice', side_effect=Exception("DB down")):
+ resp = client.post("/v1/advice", json={"content": "Test"}, headers=AUTH)
+ assert resp.status_code == 500
+
+ def test_create_each_valid_category(self, client):
+ for cat in ('productivity', 'health', 'communication', 'learning', 'other'):
+ data = {"content": "Test", "category": cat}
+ with patch('routers.advice.advice_db.create_advice') as mock_create:
+ mock_create.return_value = {"id": "x", "content": "Test", "category": cat, "confidence": 0.5, "is_read": False, "is_dismissed": False, "created_at": datetime.now(timezone.utc)}
+ resp = client.post("/v1/advice", json=data, headers=AUTH)
+ assert resp.status_code == 200, f"Failed for category {cat}"
+
+
+class TestGetAdvice:
+ def test_get_empty(self, client):
+ with patch('routers.advice.advice_db.get_advice', return_value=[]):
+ resp = client.get("/v1/advice", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json() == []
+
+ def test_get_with_category_filter(self, client):
+ with patch('routers.advice.advice_db.get_advice', return_value=[]) as mock_get:
+ resp = client.get("/v1/advice?category=health", headers=AUTH)
+ assert resp.status_code == 200
+ assert mock_get.call_args[1]['category'] == 'health'
+
+ def test_get_invalid_category_skips_filter(self, client):
+ with patch('routers.advice.advice_db.get_advice', return_value=[]) as mock_get:
+ resp = client.get("/v1/advice?category=bad_cat", headers=AUTH)
+ assert resp.status_code == 200
+ assert mock_get.call_args[1]['category'] is None
+
+ def test_get_include_dismissed(self, client):
+ with patch('routers.advice.advice_db.get_advice', return_value=[]) as mock_get:
+ resp = client.get("/v1/advice?include_dismissed=true", headers=AUTH)
+ assert resp.status_code == 200
+ assert mock_get.call_args[1]['include_dismissed'] is True
+
+ def test_get_with_pagination(self, client):
+ with patch('routers.advice.advice_db.get_advice', return_value=[]) as mock_get:
+ resp = client.get("/v1/advice?limit=50&offset=20", headers=AUTH)
+ assert resp.status_code == 200
+ assert mock_get.call_args[1]['limit'] == 50
+ assert mock_get.call_args[1]['offset'] == 20
+
+ def test_get_firestore_error_returns_empty(self, client):
+ with patch('routers.advice.advice_db.get_advice', side_effect=Exception("err")):
+ resp = client.get("/v1/advice", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json() == []
+
+
+class TestUpdateAdvice:
+ def test_mark_as_read(self, client):
+ with patch('routers.advice.advice_db.update_advice') as mock_update:
+ mock_update.return_value = {"id": "adv-1", "is_read": True, "is_dismissed": False, "content": "x", "category": "other", "confidence": 0.5, "created_at": datetime.now(timezone.utc), "updated_at": datetime.now(timezone.utc)}
+ resp = client.patch("/v1/advice/adv-1", json={"is_read": True}, headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["is_read"] is True
+
+ def test_mark_as_dismissed(self, client):
+ with patch('routers.advice.advice_db.update_advice') as mock_update:
+ mock_update.return_value = {"id": "adv-1", "is_read": False, "is_dismissed": True, "content": "x", "category": "other", "confidence": 0.5, "created_at": datetime.now(timezone.utc), "updated_at": datetime.now(timezone.utc)}
+ resp = client.patch("/v1/advice/adv-1", json={"is_dismissed": True}, headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["is_dismissed"] is True
+
+ def test_empty_update_still_updates_timestamp(self, client):
+ with patch('routers.advice.advice_db.update_advice') as mock_update:
+ mock_update.return_value = {"id": "adv-1", "is_read": False, "is_dismissed": False, "content": "x", "category": "other", "confidence": 0.5, "created_at": datetime.now(timezone.utc), "updated_at": datetime.now(timezone.utc)}
+ resp = client.patch("/v1/advice/adv-1", json={}, headers=AUTH)
+ assert resp.status_code == 200
+
+ def test_update_not_found_returns_500(self, client):
+ with patch('routers.advice.advice_db.update_advice', return_value=None):
+ resp = client.patch("/v1/advice/adv-1", json={"is_read": True}, headers=AUTH)
+ assert resp.status_code == 500
+
+ def test_update_firestore_error_returns_500(self, client):
+ with patch('routers.advice.advice_db.update_advice', side_effect=Exception("err")):
+ resp = client.patch("/v1/advice/adv-1", json={"is_read": True}, headers=AUTH)
+ assert resp.status_code == 500
+
+
+class TestDeleteAdvice:
+ def test_delete_returns_ok(self, client):
+ with patch('routers.advice.advice_db.delete_advice', return_value=True):
+ resp = client.delete("/v1/advice/adv-1", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["status"] == "ok"
+
+ def test_delete_firestore_error_returns_500(self, client):
+ with patch('routers.advice.advice_db.delete_advice', side_effect=Exception("err")):
+ resp = client.delete("/v1/advice/adv-1", headers=AUTH)
+ assert resp.status_code == 500
+
+
+class TestMarkAllRead:
+ def test_mark_all_read_returns_count(self, client):
+ with patch('routers.advice.advice_db.mark_all_advice_read', return_value=5):
+ resp = client.post("/v1/advice/mark-all-read", headers=AUTH)
+ assert resp.status_code == 200
+ assert "5" in resp.json()["status"]
+
+ def test_mark_all_read_zero(self, client):
+ with patch('routers.advice.advice_db.mark_all_advice_read', return_value=0):
+ resp = client.post("/v1/advice/mark-all-read", headers=AUTH)
+ assert resp.status_code == 200
+ assert "0" in resp.json()["status"]
+
+ def test_mark_all_read_firestore_error(self, client):
+ with patch('routers.advice.advice_db.mark_all_advice_read', side_effect=Exception("err")):
+ resp = client.post("/v1/advice/mark-all-read", headers=AUTH)
+ assert resp.status_code == 500
diff --git a/backend/tests/unit/test_agent_vm.py b/backend/tests/unit/test_agent_vm.py
new file mode 100644
index 0000000000..7f5bcd4303
--- /dev/null
+++ b/backend/tests/unit/test_agent_vm.py
@@ -0,0 +1,328 @@
+"""Tests for agent VM endpoints — vm-ensure and vm-status.
+
+Verifies:
+- vm-ensure creates new VMs for users with no existing VM
+- vm-ensure restarts stopped/terminated VMs
+- vm-ensure is idempotent (doesn't double-provision)
+- vm-status returns full VM fields (vm_name, ip, auth_token, zone, created_at)
+- vm-status triggers restart for stopped VMs (Rust parity)
+- Response JSON matches desktop Swift AgentProvisionResponse/AgentStatusResponse
+"""
+
+import os
+import sys
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
+
+
+# Stub heavy imports before importing the router
+sys.modules.setdefault('database._client', MagicMock())
+sys.modules.setdefault('database.users', MagicMock())
+sys.modules.setdefault('utils.retrieval.agentic', MagicMock(agent_config_context=MagicMock(), CORE_TOOLS=[]))
+sys.modules.setdefault('utils.retrieval.tools.app_tools', MagicMock())
+
+from routers.agent_tools import router, _vm_response, _provision_vm_background, _restart_vm_background, GCE_ZONE
+from utils.other.endpoints import get_current_user_uid
+
+app = FastAPI()
+app.include_router(router)
+
+TEST_UID = "testuser1234abcd"
+
+app.dependency_overrides[get_current_user_uid] = lambda: TEST_UID
+client = TestClient(app)
+
+
+SAMPLE_VM = {
+ "vmName": "omi-agent-testuser1234",
+ "zone": "us-central1-a",
+ "ip": "35.192.1.1",
+ "status": "ready",
+ "authToken": "omi-abc123",
+ "createdAt": "2026-03-10T00:00:00+00:00",
+ "lastQueryAt": "2026-03-10T01:00:00+00:00",
+}
+
+
+# --------------- vm-status tests ---------------
+
+
+@patch("routers.agent_tools.get_agent_vm", return_value=None)
+def test_vm_status_no_vm(mock_get):
+ """vm-status returns has_vm=False when user has no VM."""
+ resp = client.get("/v1/agent/vm-status")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["has_vm"] is False
+
+
+@patch("routers.agent_tools._check_gce_status", new_callable=AsyncMock, return_value="RUNNING")
+@patch("routers.agent_tools.get_agent_vm", return_value=SAMPLE_VM)
+def test_vm_status_returns_full_fields(mock_get, mock_gce):
+ """vm-status returns all fields desktop needs: vm_name, ip, auth_token, zone, created_at."""
+ resp = client.get("/v1/agent/vm-status")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["has_vm"] is True
+ assert data["vm_name"] == "omi-agent-testuser1234"
+ assert data["ip"] == "35.192.1.1"
+ assert data["auth_token"] == "omi-abc123"
+ assert data["zone"] == "us-central1-a"
+ assert data["created_at"] == "2026-03-10T00:00:00+00:00"
+ assert data["last_query_at"] == "2026-03-10T01:00:00+00:00"
+ assert data["status"] == "ready"
+
+
+@patch("routers.agent_tools._restart_vm_background")
+@patch("routers.agent_tools._update_firestore_vm")
+@patch("routers.agent_tools._check_gce_status", new_callable=AsyncMock, return_value="TERMINATED")
+@patch("routers.agent_tools.get_agent_vm", return_value=SAMPLE_VM)
+def test_vm_status_restarts_stopped_vm(mock_get, mock_gce, mock_update, mock_restart):
+ """vm-status triggers restart when GCE status is TERMINATED (Rust parity)."""
+ resp = client.get("/v1/agent/vm-status")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["status"] == "provisioning"
+ mock_update.assert_called_once_with(TEST_UID, None, "provisioning")
+
+
+@patch("routers.agent_tools._check_gce_status", new_callable=AsyncMock, side_effect=Exception("GCE unreachable"))
+@patch("routers.agent_tools.get_agent_vm", return_value=SAMPLE_VM)
+def test_vm_status_gce_failure_returns_firestore_data(mock_get, mock_gce):
+ """vm-status returns Firestore data when GCE check fails."""
+ resp = client.get("/v1/agent/vm-status")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["has_vm"] is True
+ assert data["status"] == "ready"
+ assert data["vm_name"] == "omi-agent-testuser1234"
+
+
+# --------------- vm-ensure tests ---------------
+
+
+@patch("routers.agent_tools._provision_vm_background")
+@patch("routers.agent_tools._set_firestore_vm")
+@patch("routers.agent_tools.get_agent_vm", return_value=None)
+def test_vm_ensure_creates_new_vm(mock_get, mock_set_fs, mock_provision):
+ """vm-ensure creates a new VM when no Firestore record exists."""
+ resp = client.post("/v1/agent/vm-ensure")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["has_vm"] is True
+ assert data["status"] == "provisioning"
+ assert data["vm_name"] == "omi-agent-testuser1234"
+ assert data["auth_token"].startswith("omi-")
+ assert data["zone"] == "us-central1-a"
+ assert data["ip"] is None
+
+ # Verify Firestore was written
+ mock_set_fs.assert_called_once()
+ call_args = mock_set_fs.call_args
+ assert call_args[0][0] == TEST_UID
+ assert call_args[0][1] == "omi-agent-testuser1234"
+ assert call_args[0][4] == "provisioning"
+
+
+@patch(
+ "routers.agent_tools.get_agent_vm",
+ return_value={"vmName": "omi-agent-testuser1234", "status": "provisioning", "authToken": "omi-xyz"},
+)
+def test_vm_ensure_idempotent_provisioning(mock_get):
+ """vm-ensure doesn't double-provision when already provisioning."""
+ resp = client.post("/v1/agent/vm-ensure")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["has_vm"] is True
+ assert data["status"] == "provisioning"
+
+
+@patch("routers.agent_tools._restart_vm_background")
+@patch("routers.agent_tools._update_firestore_vm")
+@patch("routers.agent_tools._check_gce_status", new_callable=AsyncMock, return_value="STOPPED")
+@patch("routers.agent_tools.get_agent_vm", return_value=SAMPLE_VM)
+def test_vm_ensure_restarts_stopped_vm(mock_get, mock_gce, mock_update, mock_restart):
+ """vm-ensure restarts a stopped VM."""
+ resp = client.post("/v1/agent/vm-ensure")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["status"] == "provisioning"
+ mock_update.assert_called_once_with(TEST_UID, None, "provisioning")
+
+
+@patch("routers.agent_tools._update_firestore_vm")
+@patch("routers.agent_tools._check_gce_status", new_callable=AsyncMock, return_value="RUNNING")
+@patch(
+ "routers.agent_tools.get_agent_vm",
+ return_value={**SAMPLE_VM, "status": "error"},
+)
+def test_vm_ensure_recovers_running_but_error_status(mock_get, mock_gce, mock_update):
+ """vm-ensure recovers when GCE is RUNNING but Firestore says error."""
+ resp = client.post("/v1/agent/vm-ensure")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["status"] == "ready"
+ mock_update.assert_called_once_with(TEST_UID, "35.192.1.1", "ready")
+
+
+@patch("routers.agent_tools._check_gce_status", new_callable=AsyncMock, return_value="RUNNING")
+@patch("routers.agent_tools.get_agent_vm", return_value=SAMPLE_VM)
+def test_vm_ensure_returns_full_fields_for_ready_vm(mock_get, mock_gce):
+ """vm-ensure returns full VM fields when VM is ready."""
+ resp = client.post("/v1/agent/vm-ensure")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["has_vm"] is True
+ assert data["vm_name"] == "omi-agent-testuser1234"
+ assert data["ip"] == "35.192.1.1"
+ assert data["auth_token"] == "omi-abc123"
+
+
+# --------------- _vm_response tests ---------------
+
+
+def test_vm_response_maps_firestore_fields():
+ """_vm_response correctly maps Firestore camelCase to snake_case."""
+ result = _vm_response(SAMPLE_VM)
+ assert result["vm_name"] == "omi-agent-testuser1234"
+ assert result["auth_token"] == "omi-abc123"
+ assert result["created_at"] == "2026-03-10T00:00:00+00:00"
+ assert result["last_query_at"] == "2026-03-10T01:00:00+00:00"
+
+
+def test_vm_response_status_override():
+ """_vm_response applies status_override correctly."""
+ result = _vm_response(SAMPLE_VM, status_override="provisioning")
+ assert result["status"] == "provisioning"
+ assert result["vm_name"] == "omi-agent-testuser1234"
+
+
+# --------------- vm_name generation tests ---------------
+
+
+@patch("routers.agent_tools._provision_vm_background")
+@patch("routers.agent_tools._set_firestore_vm")
+@patch("routers.agent_tools.get_agent_vm", return_value=None)
+def test_vm_name_truncates_long_uid(mock_get, mock_set_fs, mock_provision):
+ """VM name uses first 12 chars of UID, lowercased."""
+ app.dependency_overrides[get_current_user_uid] = lambda: "ABCDEFghijklmnopqrstuvwxyz"
+ try:
+ resp = client.post("/v1/agent/vm-ensure")
+ data = resp.json()
+ assert data["vm_name"] == "omi-agent-abcdefghijkl"
+ finally:
+ app.dependency_overrides[get_current_user_uid] = lambda: TEST_UID
+
+
+@patch("routers.agent_tools._provision_vm_background")
+@patch("routers.agent_tools._set_firestore_vm")
+@patch("routers.agent_tools.get_agent_vm", return_value=None)
+def test_vm_name_short_uid(mock_get, mock_set_fs, mock_provision):
+ """Short UIDs use the full UID in VM name."""
+ app.dependency_overrides[get_current_user_uid] = lambda: "ShortUid"
+ try:
+ resp = client.post("/v1/agent/vm-ensure")
+ data = resp.json()
+ assert data["vm_name"] == "omi-agent-shortuid"
+ finally:
+ app.dependency_overrides[get_current_user_uid] = lambda: TEST_UID
+
+
+# --------------- UID boundary tests ---------------
+
+
+@patch("routers.agent_tools._provision_vm_background")
+@patch("routers.agent_tools._set_firestore_vm")
+@patch("routers.agent_tools.get_agent_vm", return_value=None)
+def test_vm_name_whitespace_uid(mock_get, mock_set_fs, mock_provision):
+ """UID with whitespace is lowercased and truncated normally."""
+ app.dependency_overrides[get_current_user_uid] = lambda: "User With Spaces"
+ try:
+ resp = client.post("/v1/agent/vm-ensure")
+ data = resp.json()
+ assert data["vm_name"] == "omi-agent-user with sp"
+ finally:
+ app.dependency_overrides[get_current_user_uid] = lambda: TEST_UID
+
+
+@patch("routers.agent_tools._provision_vm_background")
+@patch("routers.agent_tools._set_firestore_vm")
+@patch("routers.agent_tools.get_agent_vm", return_value=None)
+def test_vm_name_empty_string_uid(mock_get, mock_set_fs, mock_provision):
+ """Empty-string UID produces omi-agent- prefix with empty suffix."""
+ app.dependency_overrides[get_current_user_uid] = lambda: ""
+ try:
+ resp = client.post("/v1/agent/vm-ensure")
+ data = resp.json()
+ assert data["vm_name"] == "omi-agent-"
+ finally:
+ app.dependency_overrides[get_current_user_uid] = lambda: TEST_UID
+
+
+# --------------- background task error handling tests ---------------
+
+
+@pytest.mark.asyncio
+@patch("routers.agent_tools._update_firestore_vm")
+@patch("routers.agent_tools._create_gce_vm", new_callable=AsyncMock, side_effect=Exception("GCE insert timed out"))
+async def test_provision_vm_background_sets_error_on_failure(mock_create, mock_update):
+ """_provision_vm_background sets Firestore status to 'error' when GCE creation fails."""
+ await _provision_vm_background("uid123", "omi-agent-uid123", "omi-token")
+ mock_update.assert_called_once_with("uid123", None, "error")
+
+
+@pytest.mark.asyncio
+@patch("routers.agent_tools._update_firestore_vm")
+@patch("routers.agent_tools._start_vm_and_wait", new_callable=AsyncMock, side_effect=Exception("GCE start timed out"))
+async def test_restart_vm_background_sets_error_on_failure(mock_start, mock_update):
+ """_restart_vm_background sets Firestore status to 'error' when restart fails."""
+ await _restart_vm_background("uid123", "omi-agent-uid123", "us-central1-a")
+ mock_update.assert_called_once_with("uid123", None, "error")
+
+
+@pytest.mark.asyncio
+@patch("routers.agent_tools._set_firestore_vm")
+@patch("routers.agent_tools._create_gce_vm", new_callable=AsyncMock, return_value="10.0.0.1")
+async def test_provision_vm_background_sets_ready_on_success(mock_create, mock_set_fs):
+ """_provision_vm_background writes 'ready' status with IP on success."""
+ await _provision_vm_background("uid123", "omi-agent-uid123", "omi-token")
+ mock_set_fs.assert_called_once_with("uid123", "omi-agent-uid123", GCE_ZONE, "10.0.0.1", "ready", "omi-token")
+
+
+# --------------- incomplete Firestore payload tests ---------------
+
+
+@patch("routers.agent_tools.get_agent_vm", return_value={"status": "ready"})
+def test_vm_status_handles_missing_vm_name(mock_get):
+ """vm-status does not crash when vmName is missing from Firestore (skips GCE check)."""
+ resp = client.get("/v1/agent/vm-status")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["has_vm"] is True
+ assert data["vm_name"] is None
+
+
+@patch("routers.agent_tools._check_gce_status", new_callable=AsyncMock, return_value="RUNNING")
+@patch("routers.agent_tools.get_agent_vm", return_value={"vmName": "omi-agent-x", "status": "ready"})
+def test_vm_status_handles_missing_ip_and_auth(mock_get, mock_gce):
+ """vm-status returns None for ip and auth_token when missing from Firestore."""
+ resp = client.get("/v1/agent/vm-status")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["ip"] is None
+ assert data["auth_token"] is None
+ assert data["vm_name"] == "omi-agent-x"
+
+
+@patch("routers.agent_tools.get_agent_vm", return_value={})
+def test_vm_ensure_handles_empty_firestore_vm(mock_get):
+ """vm-ensure with empty Firestore dict (no status, falls through to _vm_response)."""
+ resp = client.post("/v1/agent/vm-ensure")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["has_vm"] is True
diff --git a/backend/tests/unit/test_assistant_settings_ai_profile.py b/backend/tests/unit/test_assistant_settings_ai_profile.py
new file mode 100644
index 0000000000..08908620a6
--- /dev/null
+++ b/backend/tests/unit/test_assistant_settings_ai_profile.py
@@ -0,0 +1,217 @@
+from datetime import datetime, timezone
+from unittest.mock import patch, MagicMock
+
+import pytest
+from fastapi.testclient import TestClient
+
+
+@pytest.fixture
+def client():
+ with patch('database.screen_activity.db'), \
+ patch('database.vector_db.Pinecone'), \
+ patch('database.vector_db.pc'), \
+ patch('database.vector_db.index'), \
+ patch('utils.llm.clients.embeddings'):
+ from main import app
+ with TestClient(app) as c:
+ yield c
+
+
+AUTH = {"Authorization": "Bearer 123testuser"}
+
+
+class TestAssistantSettingsValidation:
+ def test_get_empty_returns_200(self, client):
+ with patch('routers.users.get_assistant_settings', return_value={}):
+ resp = client.get("/v1/users/assistant-settings", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json() == {}
+
+ def test_patch_prompt_exceeds_10000_chars(self, client):
+ data = {"focus": {"analysis_prompt": "x" * 10001}}
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 400
+ assert "10000" in resp.json()["detail"]
+
+ def test_patch_prompt_at_10000_chars_accepted(self, client):
+ data = {"focus": {"analysis_prompt": "x" * 10000}}
+ with patch('routers.users.update_assistant_settings', return_value=data):
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 200
+
+ def test_patch_allowed_apps_exceeds_500(self, client):
+ data = {"task": {"allowed_apps": ["app"] * 501}}
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 400
+ assert "500" in resp.json()["detail"]
+
+ def test_patch_browser_keywords_exceeds_500(self, client):
+ data = {"task": {"browser_keywords": ["kw"] * 501}}
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 400
+
+ def test_patch_task_confidence_below_zero(self, client):
+ data = {"task": {"min_confidence": -0.1}}
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 400
+
+ def test_patch_task_confidence_above_one(self, client):
+ data = {"task": {"min_confidence": 1.5}}
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 400
+
+ def test_patch_task_confidence_zero_accepted(self, client):
+ data = {"task": {"min_confidence": 0.0}}
+ with patch('routers.users.update_assistant_settings', return_value=data):
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 200
+
+ def test_patch_task_confidence_one_accepted(self, client):
+ data = {"task": {"min_confidence": 1.0}}
+ with patch('routers.users.update_assistant_settings', return_value=data):
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 200
+
+ def test_patch_advice_confidence_above_one(self, client):
+ data = {"advice": {"min_confidence": 1.1}}
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 400
+
+ def test_patch_memory_confidence_below_zero(self, client):
+ data = {"memory": {"min_confidence": -0.5}}
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 400
+
+ def test_patch_empty_body_returns_current(self, client):
+ with patch('routers.users.get_assistant_settings', return_value={"focus": {"enabled": True}}):
+ resp = client.patch("/v1/users/assistant-settings", json={}, headers=AUTH)
+ assert resp.status_code == 200
+
+ def test_patch_task_prompt_exceeds_10000_chars(self, client):
+ data = {"task": {"analysis_prompt": "x" * 10001}}
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 400
+ assert "task" in resp.json()["detail"]
+
+ def test_patch_advice_prompt_exceeds_10000_chars(self, client):
+ data = {"advice": {"analysis_prompt": "x" * 10001}}
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 400
+ assert "advice" in resp.json()["detail"]
+
+ def test_patch_memory_prompt_exceeds_10000_chars(self, client):
+ data = {"memory": {"analysis_prompt": "x" * 10001}}
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 400
+ assert "memory" in resp.json()["detail"]
+
+ def test_patch_advice_confidence_below_zero(self, client):
+ data = {"advice": {"min_confidence": -0.1}}
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 400
+
+ def test_patch_memory_confidence_above_one(self, client):
+ data = {"memory": {"min_confidence": 1.5}}
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 400
+
+ def test_patch_excludes_none_fields(self, client):
+ data = {"task": {"enabled": True}}
+ with patch('routers.users.update_assistant_settings', return_value=data) as mock_update:
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ call_data = mock_update.call_args[0][1]
+ assert "min_confidence" not in call_data.get("task", {})
+
+ def test_patch_update_channel(self, client):
+ data = {"update_channel": "beta"}
+ with patch('routers.users.update_assistant_settings', return_value={"update_channel": "beta"}) as mock_update:
+ resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ call_data = mock_update.call_args[0][1]
+ assert call_data["update_channel"] == "beta"
+
+
+class TestAIProfileValidation:
+ def test_get_empty_returns_null(self, client):
+ with patch('routers.users.get_ai_user_profile', return_value=None):
+ resp = client.get("/v1/users/ai-profile", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json() is None
+
+ def test_patch_valid_rfc3339_z(self, client):
+ data = {"profile_text": "test", "generated_at": "2026-03-05T10:00:00Z", "data_sources_used": 1}
+ with patch('routers.users.update_ai_user_profile', return_value=data):
+ resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH)
+ assert resp.status_code == 200
+
+ def test_patch_valid_rfc3339_offset(self, client):
+ data = {"profile_text": "test", "generated_at": "2026-03-05T10:00:00+05:30", "data_sources_used": 1}
+ with patch('routers.users.update_ai_user_profile', return_value=data):
+ resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH)
+ assert resp.status_code == 200
+
+ def test_patch_valid_rfc3339_fractional(self, client):
+ data = {"profile_text": "test", "generated_at": "2026-03-05T10:00:00.123Z", "data_sources_used": 1}
+ with patch('routers.users.update_ai_user_profile', return_value=data):
+ resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH)
+ assert resp.status_code == 200
+
+ def test_patch_invalid_no_timezone(self, client):
+ data = {"profile_text": "test", "generated_at": "2026-03-05T10:00:00", "data_sources_used": 1}
+ resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH)
+ assert resp.status_code == 400
+
+ def test_patch_invalid_no_t_separator(self, client):
+ data = {"profile_text": "test", "generated_at": "2026-03-05 10:00:00Z", "data_sources_used": 1}
+ resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH)
+ assert resp.status_code == 400
+
+ def test_patch_invalid_short_offset(self, client):
+ data = {"profile_text": "test", "generated_at": "2026-03-05T10:00:00+00", "data_sources_used": 1}
+ resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH)
+ assert resp.status_code == 400
+
+ def test_patch_invalid_garbage(self, client):
+ data = {"profile_text": "test", "generated_at": "not-a-date", "data_sources_used": 1}
+ resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH)
+ assert resp.status_code == 400
+
+ def test_patch_invalid_calendar_date(self, client):
+ # Feb 30 passes regex but fails fromisoformat
+ data = {"profile_text": "test", "generated_at": "2026-02-30T10:00:00Z", "data_sources_used": 1}
+ resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH)
+ assert resp.status_code == 400
+
+ def test_patch_generated_at_stored_as_datetime(self, client):
+ data = {"profile_text": "test", "generated_at": "2026-03-05T10:00:00Z", "data_sources_used": 1}
+ with patch('routers.users.update_ai_user_profile') as mock_update:
+ mock_update.return_value = {}
+ resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ call_data = mock_update.call_args[0][1]
+ assert isinstance(call_data['generated_at'], datetime)
+ assert call_data['generated_at'].tzinfo is not None
+
+ def test_profile_text_truncation_at_boundary(self, client):
+ # 10001 bytes of ASCII should truncate to 10000
+ long_text = "x" * 10001
+ data = {"profile_text": long_text, "generated_at": "2026-03-05T10:00:00Z", "data_sources_used": 1}
+ with patch('routers.users.update_ai_user_profile') as mock_update:
+ mock_update.return_value = {"profile_text": "x" * 10000}
+ resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ call_data = mock_update.call_args[0][1]
+ assert len(call_data['profile_text']) == 10000
+
+ def test_profile_text_multibyte_truncation(self, client):
+ # Multibyte UTF-8: emoji is 4 bytes, test boundary doesn't split mid-char
+ text = "a" * 9998 + "\U0001F600" # 9998 + 4 bytes = 10002 bytes
+ data = {"profile_text": text, "generated_at": "2026-03-05T10:00:00Z", "data_sources_used": 1}
+ with patch('routers.users.update_ai_user_profile') as mock_update:
+ mock_update.return_value = {}
+ resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ call_data = mock_update.call_args[0][1]
+ # Should not have broken emoji — truncated to 9998 'a's
+ assert len(call_data['profile_text'].encode('utf-8')) <= 10000
diff --git a/backend/tests/unit/test_auth_routes.py b/backend/tests/unit/test_auth_routes.py
new file mode 100644
index 0000000000..68d4cc4c88
--- /dev/null
+++ b/backend/tests/unit/test_auth_routes.py
@@ -0,0 +1,242 @@
+"""Tests for auth endpoint redirect_uri validation and callback template rendering."""
+import sys
+import os
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from httpx import ASGITransport, AsyncClient
+
+os.environ.setdefault(
+ "ENCRYPTION_SECRET",
+ "omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv",
+)
+
+# Stub heavy dependencies before importing the module under test
+sys.modules.setdefault('firebase_admin', MagicMock())
+sys.modules.setdefault('firebase_admin.auth', MagicMock())
+sys.modules.setdefault('firebase_admin.firestore', MagicMock())
+sys.modules.setdefault('firebase_admin.messaging', MagicMock())
+sys.modules.setdefault('google.cloud', MagicMock())
+sys.modules.setdefault('google.cloud.firestore', MagicMock())
+sys.modules.setdefault('google.cloud.firestore_v1', MagicMock())
+sys.modules.setdefault('google.auth', MagicMock())
+sys.modules.setdefault('google.auth.transport.requests', MagicMock())
+
+from fastapi import FastAPI
+
+from routers.auth import router as auth_router
+
+# Minimal test app mounting only the auth router
+_test_app = FastAPI()
+_test_app.include_router(auth_router)
+
+
+# --- /v1/auth/authorize redirect_uri validation ---
+
+class TestAuthorizeRedirectUriValidation:
+ """Tests for redirect_uri allowlist at /v1/auth/authorize."""
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("bad_uri", [
+ "https://evil.com/steal",
+ "javascript:alert(1)",
+ "data:text/html,",
+ "ftp://example.com",
+ "",
+ ])
+ async def test_rejects_disallowed_redirect_uri(self, bad_uri):
+ async with AsyncClient(transport=ASGITransport(app=_test_app), base_url="http://test") as client:
+ resp = await client.get(
+ "/v1/auth/authorize",
+ params={"provider": "google", "redirect_uri": bad_uri, "state": "test"},
+ )
+ assert resp.status_code == 400
+ assert "allowed app URL scheme" in resp.json()["detail"]
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("good_uri", [
+ "omi://auth/callback",
+ "omi-computer://auth/callback",
+ "omi-computer-dev://auth/callback",
+ ])
+ @patch("routers.auth.set_auth_session")
+ async def test_accepts_allowed_redirect_schemes(self, mock_set_session, good_uri):
+ with patch("routers.auth.os.getenv") as mock_getenv:
+ mock_getenv.side_effect = lambda key, *args: {
+ "GOOGLE_CLIENT_ID": "test-client-id",
+ "GOOGLE_CLIENT_SECRET": "test-secret",
+ "BASE_API_URL": "https://api.omi.me",
+ "APPLE_CLIENT_ID": "me.omi.web",
+ "APPLE_TEAM_ID": "TEST",
+ "APPLE_KEY_ID": "TEST",
+ "APPLE_PRIVATE_KEY": "TEST",
+ }.get(key, args[0] if args else None)
+
+ async with AsyncClient(
+ transport=ASGITransport(app=_test_app),
+ base_url="http://test",
+ follow_redirects=False,
+ ) as client:
+ resp = await client.get(
+ "/v1/auth/authorize",
+ params={"provider": "google", "redirect_uri": good_uri, "state": "test123"},
+ )
+ # Should redirect to Google OAuth (307) or return 200, not 400
+ assert resp.status_code != 400
+ # Verify session was stored with the redirect_uri
+ mock_set_session.assert_called_once()
+ session_data = mock_set_session.call_args[0][1]
+ assert session_data["redirect_uri"] == good_uri
+
+ @pytest.mark.asyncio
+ async def test_rejects_missing_redirect_uri(self):
+ async with AsyncClient(transport=ASGITransport(app=_test_app), base_url="http://test") as client:
+ resp = await client.get(
+ "/v1/auth/authorize",
+ params={"provider": "google", "state": "test"},
+ )
+ # FastAPI returns 422 for missing required query param
+ assert resp.status_code == 422
+
+ @pytest.mark.asyncio
+ async def test_rejects_invalid_provider(self):
+ async with AsyncClient(transport=ASGITransport(app=_test_app), base_url="http://test") as client:
+ resp = await client.get(
+ "/v1/auth/authorize",
+ params={"provider": "github", "redirect_uri": "omi://auth/callback"},
+ )
+ assert resp.status_code == 400
+ assert "Unsupported provider" in resp.json()["detail"]
+
+
+# --- Google callback template rendering ---
+
+class TestGoogleCallbackRedirectUri:
+ """Tests for redirect_uri in Google OAuth callback template."""
+
+ @pytest.mark.asyncio
+ @patch("routers.auth.get_auth_session")
+ @patch("routers.auth._exchange_provider_code_for_oauth_credentials", new_callable=AsyncMock)
+ @patch("routers.auth.set_auth_code")
+ async def test_uses_session_redirect_uri(self, mock_set_code, mock_exchange, mock_get_session):
+ mock_get_session.return_value = {
+ "provider": "google",
+ "redirect_uri": "omi-computer://auth/callback",
+ "state": "test_state",
+ "flow_type": "user_auth",
+ }
+ mock_exchange.return_value = '{"id_token": "test"}'
+
+ async with AsyncClient(transport=ASGITransport(app=_test_app), base_url="http://test") as client:
+ resp = await client.get(
+ "/v1/auth/callback/google",
+ params={"code": "test_code", "state": "test_state"},
+ )
+ assert resp.status_code == 200
+ body = resp.text
+ # Template should contain the desktop redirect scheme
+ assert "omi-computer://auth/callback" in body
+
+ @pytest.mark.asyncio
+ @patch("routers.auth.get_auth_session")
+ @patch("routers.auth._exchange_provider_code_for_oauth_credentials", new_callable=AsyncMock)
+ @patch("routers.auth.set_auth_code")
+ async def test_falls_back_to_default_redirect_uri(self, mock_set_code, mock_exchange, mock_get_session):
+ mock_get_session.return_value = {
+ "provider": "google",
+ "state": "test_state",
+ "flow_type": "user_auth",
+ # No redirect_uri in session
+ }
+ mock_exchange.return_value = '{"id_token": "test"}'
+
+ async with AsyncClient(transport=ASGITransport(app=_test_app), base_url="http://test") as client:
+ resp = await client.get(
+ "/v1/auth/callback/google",
+ params={"code": "test_code", "state": "test_state"},
+ )
+ assert resp.status_code == 200
+ body = resp.text
+ # Should fall back to omi:// scheme
+ assert "omi://auth/callback" in body
+
+
+# --- Apple callback template rendering ---
+
+class TestAppleCallbackRedirectUri:
+ """Tests for redirect_uri in Apple OAuth callback template."""
+
+ @pytest.mark.asyncio
+ @patch("routers.auth.get_auth_session")
+ @patch("routers.auth._exchange_provider_code_for_oauth_credentials", new_callable=AsyncMock)
+ @patch("routers.auth.set_auth_code")
+ async def test_uses_session_redirect_uri(self, mock_set_code, mock_exchange, mock_get_session):
+ mock_get_session.return_value = {
+ "provider": "apple",
+ "redirect_uri": "omi-computer://auth/callback",
+ "state": "test_state",
+ "flow_type": "user_auth",
+ }
+ mock_exchange.return_value = '{"id_token": "test"}'
+
+ async with AsyncClient(transport=ASGITransport(app=_test_app), base_url="http://test") as client:
+ resp = await client.post(
+ "/v1/auth/callback/apple",
+ data={"code": "test_code", "state": "test_state"},
+ )
+ assert resp.status_code == 200
+ body = resp.text
+ assert "omi-computer://auth/callback" in body
+
+ @pytest.mark.asyncio
+ @patch("routers.auth.get_auth_session")
+ @patch("routers.auth._exchange_provider_code_for_oauth_credentials", new_callable=AsyncMock)
+ @patch("routers.auth.set_auth_code")
+ async def test_falls_back_to_default_redirect_uri(self, mock_set_code, mock_exchange, mock_get_session):
+ mock_get_session.return_value = {
+ "provider": "apple",
+ "state": "test_state",
+ "flow_type": "user_auth",
+ }
+ mock_exchange.return_value = '{"id_token": "test"}'
+
+ async with AsyncClient(transport=ASGITransport(app=_test_app), base_url="http://test") as client:
+ resp = await client.post(
+ "/v1/auth/callback/apple",
+ data={"code": "test_code", "state": "test_state"},
+ )
+ assert resp.status_code == 200
+ body = resp.text
+ assert "omi://auth/callback" in body
+
+
+# --- Template XSS safety ---
+
+class TestCallbackTemplateXssSafety:
+ """Verify that redirect_uri is safely serialized in the callback template."""
+
+ @pytest.mark.asyncio
+ @patch("routers.auth.get_auth_session")
+ @patch("routers.auth._exchange_provider_code_for_oauth_credentials", new_callable=AsyncMock)
+ @patch("routers.auth.set_auth_code")
+ async def test_redirect_uri_json_escaped(self, mock_set_code, mock_exchange, mock_get_session):
+ # Use a redirect_uri with quotes to test JSON escaping
+ mock_get_session.return_value = {
+ "provider": "google",
+ "redirect_uri": 'omi://auth/callback"test',
+ "state": "test_state",
+ "flow_type": "user_auth",
+ }
+ mock_exchange.return_value = '{"id_token": "test"}'
+
+ async with AsyncClient(transport=ASGITransport(app=_test_app), base_url="http://test") as client:
+ resp = await client.get(
+ "/v1/auth/callback/google",
+ params={"code": "test_code", "state": "test_state"},
+ )
+ assert resp.status_code == 200
+ body = resp.text
+ # The quote should be JSON-escaped, not raw
+ assert r'omi://auth/callback\"test' in body or r'omi:\/\/auth\/callback\"test' in body
+ # Should NOT contain unescaped quote that breaks out of the JS string
+ assert 'const redirectUri = "omi://auth/callback"test"' not in body
diff --git a/backend/tests/unit/test_chat_generate_title.py b/backend/tests/unit/test_chat_generate_title.py
new file mode 100644
index 0000000000..77232e83a1
--- /dev/null
+++ b/backend/tests/unit/test_chat_generate_title.py
@@ -0,0 +1,209 @@
+import sys
+from datetime import datetime, timezone
+from unittest.mock import patch, MagicMock
+
+import pytest
+
+for mod_name in [
+ 'firebase_admin',
+ 'firebase_admin.auth',
+ 'firebase_admin.firestore',
+ 'firebase_admin.messaging',
+ 'google.cloud',
+ 'google.cloud.exceptions',
+ 'google.cloud.firestore',
+ 'google.cloud.firestore_v1',
+ 'google.cloud.firestore_v1.base_query',
+ 'google.cloud.firestore_v1.query',
+ 'google.cloud.storage',
+ 'google.cloud.storage.blob',
+ 'google.cloud.storage.bucket',
+ 'google.auth',
+ 'google.auth.transport',
+ 'google.auth.transport.requests',
+ 'google.oauth2',
+ 'google.oauth2.service_account',
+ 'pinecone',
+ 'typesense',
+ 'openai',
+ 'langchain_openai',
+]:
+ sys.modules.setdefault(mod_name, MagicMock())
+
+# Mock llm_mini before importing the router
+mock_llm = MagicMock()
+mock_llm.invoke.return_value = MagicMock(content='Project Discussion')
+sys.modules.setdefault('utils.llm.clients', MagicMock(llm_mini=mock_llm))
+
+from routers.chat import router
+
+from fastapi import FastAPI, HTTPException
+from fastapi.testclient import TestClient
+
+
+@pytest.fixture
+def client():
+ with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'):
+ app = FastAPI()
+ app.include_router(router)
+ with TestClient(app) as c:
+ yield c
+
+
+@pytest.fixture
+def client_no_auth():
+ app = FastAPI()
+ app.include_router(router)
+ with TestClient(app) as c:
+ yield c
+
+
+AUTH = {"Authorization": "Bearer 123testuser"}
+
+
+class TestGenerateChatTitle:
+ def test_generate_title_success(self, client):
+ data = {
+ "session_id": "sess-1",
+ "messages": [
+ {"text": "How do I deploy to production?", "sender": "human"},
+ {"text": "You can use the CI/CD pipeline.", "sender": "ai"},
+ ],
+ }
+ with patch('routers.chat.llm_mini') as mock_llm:
+ mock_llm.invoke.return_value = MagicMock(content='Production Deployment')
+ with patch('routers.chat.chat_db.update_chat_session'):
+ resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["title"] == "Production Deployment"
+
+ def test_generate_title_strips_quotes(self, client):
+ data = {
+ "session_id": "sess-1",
+ "messages": [{"text": "Hello", "sender": "human"}],
+ }
+ with patch('routers.chat.llm_mini') as mock_llm:
+ mock_llm.invoke.return_value = MagicMock(content='"Greeting Chat"')
+ with patch('routers.chat.chat_db.update_chat_session'):
+ resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["title"] == "Greeting Chat"
+
+ def test_generate_title_empty_messages_returns_400(self, client):
+ data = {"session_id": "sess-1", "messages": []}
+ resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
+ assert resp.status_code == 400
+
+ def test_generate_title_no_messages_field_returns_422(self, client):
+ data = {"session_id": "sess-1"}
+ resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
+ assert resp.status_code == 422
+
+ def test_generate_title_llm_fallback(self, client):
+ data = {
+ "session_id": "sess-1",
+ "messages": [{"text": "What about the budget proposal?", "sender": "human"}],
+ }
+ with patch('routers.chat.llm_mini') as mock_llm:
+ mock_llm.invoke.side_effect = Exception("LLM down")
+ with patch('routers.chat.chat_db.update_chat_session'):
+ resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["title"] == "What about the budget proposal?"
+
+ def test_generate_title_updates_session(self, client):
+ data = {
+ "session_id": "sess-1",
+ "messages": [{"text": "Hello", "sender": "human"}],
+ }
+ with patch('routers.chat.llm_mini') as mock_llm:
+ mock_llm.invoke.return_value = MagicMock(content='Greeting')
+ with patch('routers.chat.chat_db.update_chat_session') as mock_update:
+ resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ mock_update.assert_called_once()
+ call_args = mock_update.call_args[0]
+ assert call_args[1] == 'sess-1'
+ assert call_args[2]['title'] == 'Greeting'
+
+ def test_generate_title_session_update_failure_still_returns(self, client):
+ data = {
+ "session_id": "sess-1",
+ "messages": [{"text": "Hello", "sender": "human"}],
+ }
+ with patch('routers.chat.llm_mini') as mock_llm:
+ mock_llm.invoke.return_value = MagicMock(content='Greeting')
+ with patch('routers.chat.chat_db.update_chat_session', side_effect=Exception("DB err")):
+ resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["title"] == "Greeting"
+
+ def test_generate_title_truncates_long_title(self, client):
+ data = {
+ "session_id": "sess-1",
+ "messages": [{"text": "Hello", "sender": "human"}],
+ }
+ with patch('routers.chat.llm_mini') as mock_llm:
+ mock_llm.invoke.return_value = MagicMock(content='A' * 200)
+ with patch('routers.chat.chat_db.update_chat_session'):
+ resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ assert len(resp.json()["title"]) <= 100
+
+ def test_generate_title_no_auth_returns_401(self, client_no_auth):
+ data = {
+ "session_id": "sess-1",
+ "messages": [{"text": "Hello", "sender": "human"}],
+ }
+ with patch(
+ 'routers.chat.auth.get_current_user_uid',
+ side_effect=HTTPException(status_code=401, detail='Not authenticated'),
+ ):
+ resp = client_no_auth.post("/v2/chat/generate-title", json=data)
+ assert resp.status_code == 401
+
+ def test_generate_title_limits_messages(self, client):
+ """Only first 10 messages should be sent to LLM."""
+ data = {
+ "session_id": "sess-1",
+ "messages": [{"text": f"Message {i}", "sender": "human"} for i in range(20)],
+ }
+ with patch('routers.chat.llm_mini') as mock_llm:
+ mock_llm.invoke.return_value = MagicMock(content='Long Chat')
+ with patch('routers.chat.chat_db.update_chat_session'):
+ resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ prompt = mock_llm.invoke.call_args[0][0]
+ assert 'Message 9' in prompt
+ assert 'Message 10' not in prompt
+
+ def test_generate_title_fallback_truncates_to_50_chars(self, client):
+ """When LLM fails, fallback title is truncated to 50 chars."""
+ long_text = 'A' * 100
+ data = {
+ "session_id": "sess-1",
+ "messages": [{"text": long_text, "sender": "human"}],
+ }
+ with patch('routers.chat.llm_mini') as mock_llm:
+ mock_llm.invoke.side_effect = Exception("LLM down")
+ with patch('routers.chat.chat_db.update_chat_session'):
+ resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ assert len(resp.json()["title"]) == 50
+
+ def test_generate_title_truncates_message_text_to_500_chars(self, client):
+ """Each message text is truncated to 500 chars in the transcript sent to LLM."""
+ long_text = 'B' * 1000
+ data = {
+ "session_id": "sess-1",
+ "messages": [{"text": long_text, "sender": "human"}],
+ }
+ with patch('routers.chat.llm_mini') as mock_llm:
+ mock_llm.invoke.return_value = MagicMock(content='Title')
+ with patch('routers.chat.chat_db.update_chat_session'):
+ resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ prompt = mock_llm.invoke.call_args[0][0]
+ # The transcript line should contain exactly 500 B's, not 1000
+ assert 'B' * 500 in prompt
+ assert 'B' * 501 not in prompt
diff --git a/backend/tests/unit/test_conversations_count.py b/backend/tests/unit/test_conversations_count.py
new file mode 100644
index 0000000000..5ac7767ace
--- /dev/null
+++ b/backend/tests/unit/test_conversations_count.py
@@ -0,0 +1,149 @@
+import sys
+from unittest.mock import patch, MagicMock
+
+import pytest
+
+for mod_name in [
+ 'firebase_admin',
+ 'firebase_admin.auth',
+ 'firebase_admin.firestore',
+ 'firebase_admin.messaging',
+ 'google.cloud',
+ 'google.cloud.exceptions',
+ 'google.cloud.firestore',
+ 'google.cloud.firestore_v1',
+ 'google.cloud.firestore_v1.base_query',
+ 'google.cloud.firestore_v1.query',
+ 'google.cloud.storage',
+ 'google.cloud.storage.blob',
+ 'google.cloud.storage.bucket',
+ 'google.auth',
+ 'google.auth.transport',
+ 'google.auth.transport.requests',
+ 'google.oauth2',
+ 'google.oauth2.service_account',
+ 'pinecone',
+ 'typesense',
+ 'openai',
+ 'langchain_openai',
+]:
+ sys.modules.setdefault(mod_name, MagicMock())
+
+from routers.conversations import router
+
+from fastapi import FastAPI, HTTPException
+from fastapi.testclient import TestClient
+
+
+@pytest.fixture
+def client():
+ with patch('routers.conversations.auth.get_current_user_uid', return_value='uid-1'):
+ app = FastAPI()
+ app.include_router(router)
+ with TestClient(app) as c:
+ yield c
+
+
+@pytest.fixture
+def client_no_auth():
+ app = FastAPI()
+ app.include_router(router)
+ with TestClient(app) as c:
+ yield c
+
+
+AUTH = {"Authorization": "Bearer 123testuser"}
+
+
+class TestConversationsCount:
+ def test_count_default_statuses(self, client):
+ with patch('routers.conversations.conversations_db.count_conversations', return_value=42) as mock_count:
+ resp = client.get("/v1/conversations/count", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["count"] == 42
+ args = mock_count.call_args
+ assert args[1]['statuses'] == ['processing', 'completed']
+
+ def test_count_custom_statuses(self, client):
+ with patch('routers.conversations.conversations_db.count_conversations', return_value=10) as mock_count:
+ resp = client.get("/v1/conversations/count?statuses=completed", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["count"] == 10
+ assert mock_count.call_args[1]['statuses'] == ['completed']
+
+ def test_count_empty_statuses(self, client):
+ with patch('routers.conversations.conversations_db.count_conversations', return_value=0) as mock_count:
+ resp = client.get("/v1/conversations/count?statuses=", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["count"] == 0
+ assert mock_count.call_args[1]['statuses'] == []
+
+ def test_count_zero(self, client):
+ with patch('routers.conversations.conversations_db.count_conversations', return_value=0):
+ resp = client.get("/v1/conversations/count", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["count"] == 0
+
+ def test_count_fallback_on_aggregation_error(self, client):
+ """If Firestore count() aggregation fails, falls back to stream_conversations."""
+ with patch(
+ 'routers.conversations.conversations_db.count_conversations', side_effect=Exception("aggregation err")
+ ):
+ with patch('routers.conversations.conversations_db.stream_conversations', return_value=iter([1, 2, 3])):
+ resp = client.get("/v1/conversations/count", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["count"] == 3
+
+ def test_count_no_auth_returns_401(self, client_no_auth):
+ with patch(
+ 'routers.conversations.auth.get_current_user_uid',
+ side_effect=HTTPException(status_code=401, detail='Not authenticated'),
+ ):
+ resp = client_no_auth.get("/v1/conversations/count")
+ assert resp.status_code == 401
+
+ def test_count_multiple_statuses(self, client):
+ with patch('routers.conversations.conversations_db.count_conversations', return_value=25) as mock_count:
+ resp = client.get("/v1/conversations/count?statuses=processing,completed,in_progress", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["count"] == 25
+ assert mock_count.call_args[1]['statuses'] == ['processing', 'completed', 'in_progress']
+
+ def test_count_too_many_statuses_returns_400(self, client):
+ statuses = ','.join(f'status{i}' for i in range(11))
+ resp = client.get(f"/v1/conversations/count?statuses={statuses}", headers=AUTH)
+ assert resp.status_code == 400
+ assert 'max 10' in resp.json()['detail']
+
+ def test_count_stream_fallback(self, client):
+ """Fallback uses stream_conversations for unbounded counting."""
+ with patch(
+ 'routers.conversations.conversations_db.count_conversations', side_effect=Exception("no aggregation")
+ ):
+ with patch(
+ 'routers.conversations.conversations_db.stream_conversations', return_value=iter([1, 2, 3, 4, 5])
+ ):
+ resp = client.get("/v1/conversations/count", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["count"] == 5
+
+ def test_count_statuses_whitespace_normalization(self, client):
+ """Parser strips whitespace and drops empty segments from statuses."""
+ with patch('routers.conversations.conversations_db.count_conversations', return_value=7) as mock_count:
+ resp = client.get("/v1/conversations/count?statuses= processing , , completed ", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["count"] == 7
+ assert mock_count.call_args[1]['statuses'] == ['processing', 'completed']
+
+ def test_count_fallback_receives_parsed_statuses(self, client):
+ """Fallback stream_conversations receives the same parsed status list."""
+ with patch(
+ 'routers.conversations.conversations_db.count_conversations', side_effect=Exception("err")
+ ):
+ with patch(
+ 'routers.conversations.conversations_db.stream_conversations', return_value=iter([1, 2])
+ ) as mock_stream:
+ resp = client.get("/v1/conversations/count?statuses=completed,processing", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["count"] == 2
+ assert mock_stream.call_args[1]['statuses'] == ['completed', 'processing']
diff --git a/backend/tests/unit/test_desktop_advice.py b/backend/tests/unit/test_desktop_advice.py
new file mode 100644
index 0000000000..e1699f3eb7
--- /dev/null
+++ b/backend/tests/unit/test_desktop_advice.py
@@ -0,0 +1,159 @@
+"""Tests for desktop advice handler (Phase 2 — #5396)."""
+
+import asyncio
+import sys
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+sys.modules.setdefault('firebase_admin', MagicMock())
+sys.modules.setdefault('firebase_admin.auth', MagicMock())
+sys.modules.setdefault('firebase_admin.firestore', MagicMock())
+sys.modules.setdefault('database._client', MagicMock())
+_mock_clients = MagicMock()
+sys.modules.setdefault('utils.llm.clients', _mock_clients)
+
+from utils.desktop.advice import (
+ AdviceResult,
+ ADVICE_SYSTEM_PROMPT,
+ _build_advice_context,
+ generate_advice,
+)
+from models.message_event import AdviceExtractedEvent
+
+
+class TestAdviceResultModel:
+ def test_advice_with_content(self):
+ r = AdviceResult(has_advice=True, content="Take a break", category="health", confidence=0.8)
+ assert r.has_advice is True
+ assert r.content == "Take a break"
+ assert r.category == "health"
+
+ def test_no_advice(self):
+ r = AdviceResult(has_advice=False, confidence=0.1)
+ assert r.has_advice is False
+ assert r.content is None
+ assert r.category is None
+
+ def test_confidence_bounds(self):
+ with pytest.raises(Exception):
+ AdviceResult(has_advice=True, confidence=2.0)
+
+
+class TestAdviceExtractedEvent:
+ def test_event_with_advice(self):
+ event = AdviceExtractedEvent(
+ frame_id="frame789",
+ advice={"content": "Try dark mode", "category": "productivity", "confidence": 0.7},
+ )
+ data = event.to_json()
+ assert data["type"] == "advice_extracted"
+ assert data["frame_id"] == "frame789"
+ assert data["advice"]["content"] == "Try dark mode"
+
+ def test_event_no_advice(self):
+ event = AdviceExtractedEvent(frame_id="frame789", advice=None)
+ data = event.to_json()
+ assert data["advice"] is None
+
+
+class TestBuildAdviceContext:
+ @patch('utils.desktop.advice.get_action_items')
+ @patch('utils.desktop.advice.get_user_goals')
+ def test_goals_and_tasks_in_context(self, mock_goals, mock_tasks):
+ mock_goals.return_value = [{'title': 'Ship v2'}]
+ mock_tasks.return_value = [{'description': 'Write tests'}]
+ ctx = _build_advice_context("uid1")
+ assert "Ship v2" in ctx
+ assert "Write tests" in ctx
+
+ @patch('utils.desktop.advice.get_action_items')
+ @patch('utils.desktop.advice.get_user_goals')
+ def test_empty_context(self, mock_goals, mock_tasks):
+ mock_goals.return_value = []
+ mock_tasks.return_value = []
+ ctx = _build_advice_context("uid1")
+ assert ctx == ""
+
+ @patch('utils.desktop.advice.get_action_items')
+ @patch('utils.desktop.advice.get_user_goals')
+ def test_graceful_on_errors(self, mock_goals, mock_tasks):
+ mock_goals.side_effect = Exception("DB error")
+ mock_tasks.side_effect = Exception("DB error")
+ ctx = _build_advice_context("uid1")
+ assert ctx == ""
+
+ @patch('utils.desktop.advice.get_action_items')
+ @patch('utils.desktop.advice.get_user_goals')
+ def test_goals_fallback_to_description(self, mock_goals, mock_tasks):
+ mock_goals.return_value = [{'description': 'Fallback goal'}]
+ mock_tasks.return_value = []
+ ctx = _build_advice_context("uid1")
+ assert "Fallback goal" in ctx
+
+
+class TestGenerateAdvice:
+ @patch('utils.desktop.advice._build_advice_context')
+ @patch('utils.desktop.advice.llm_gemini_flash')
+ def test_returns_advice(self, mock_llm, mock_ctx):
+ mock_ctx.return_value = ""
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(
+ return_value=AdviceResult(
+ has_advice=True,
+ content="Consider using a linter",
+ category="productivity",
+ confidence=0.75,
+ )
+ )
+ result = asyncio.get_event_loop().run_until_complete(
+ generate_advice("uid1", "base64img", "VS Code", "main.py")
+ )
+ assert result["has_advice"] is True
+ assert result["advice"]["content"] == "Consider using a linter"
+ assert result["advice"]["category"] == "productivity"
+
+ @patch('utils.desktop.advice._build_advice_context')
+ @patch('utils.desktop.advice.llm_gemini_flash')
+ def test_no_advice(self, mock_llm, mock_ctx):
+ mock_ctx.return_value = ""
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(
+ return_value=AdviceResult(has_advice=False, confidence=0.1)
+ )
+ result = asyncio.get_event_loop().run_until_complete(
+ generate_advice("uid1", "base64img")
+ )
+ assert result["has_advice"] is False
+ assert result["advice"] is None
+
+ @patch('utils.desktop.advice._build_advice_context')
+ @patch('utils.desktop.advice.llm_gemini_flash')
+ def test_includes_app_info(self, mock_llm, mock_ctx):
+ mock_ctx.return_value = ""
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(
+ return_value=AdviceResult(has_advice=False, confidence=0.1)
+ )
+ asyncio.get_event_loop().run_until_complete(
+ generate_advice("uid1", "base64img", "Chrome", "Stack Overflow")
+ )
+ call_args = mock_parser.ainvoke.call_args[0][0]
+ human_msg = call_args[1]
+ text_content = human_msg.content[0]["text"]
+ assert "Chrome" in text_content
+ assert "Stack Overflow" in text_content
+
+
+class TestAdviceSystemPrompt:
+ def test_includes_categories(self):
+ assert "productivity" in ADVICE_SYSTEM_PROMPT
+ assert "mistake_prevention" in ADVICE_SYSTEM_PROMPT
+ assert "health" in ADVICE_SYSTEM_PROMPT
+ assert "goal_alignment" in ADVICE_SYSTEM_PROMPT
+
+ def test_includes_tone_guidance(self):
+ assert "TONE" in ADVICE_SYSTEM_PROMPT
diff --git a/backend/tests/unit/test_desktop_chat.py b/backend/tests/unit/test_desktop_chat.py
new file mode 100644
index 0000000000..00fe444042
--- /dev/null
+++ b/backend/tests/unit/test_desktop_chat.py
@@ -0,0 +1,390 @@
+"""Tests for desktop chat sessions CRUD + message rating endpoints."""
+import sys
+from unittest.mock import patch, MagicMock
+from datetime import datetime, timezone
+
+import pytest
+
+for mod_name in [
+ 'firebase_admin', 'firebase_admin.auth', 'firebase_admin.firestore', 'firebase_admin.messaging',
+ 'google.cloud', 'google.cloud.exceptions', 'google.cloud.firestore', 'google.cloud.firestore_v1',
+ 'google.cloud.firestore_v1.base_query', 'google.cloud.firestore_v1.query',
+ 'google.cloud.storage', 'google.cloud.storage.blob', 'google.cloud.storage.bucket',
+ 'google.auth', 'google.auth.transport', 'google.auth.transport.requests',
+ 'google.oauth2', 'google.oauth2.service_account',
+ 'pinecone', 'typesense',
+]:
+ sys.modules.setdefault(mod_name, MagicMock())
+
+from routers.chat import (
+ CreateChatSessionRequest,
+ UpdateChatSessionRequest,
+ ChatSessionResponse,
+ SaveMessageRequest,
+ SaveMessageResponse,
+ RateMessageRequest,
+ StatusResponse,
+ router,
+)
+
+
+class TestChatSessionModels:
+ def test_create_request_defaults(self):
+ req = CreateChatSessionRequest()
+ assert req.title is None
+ assert req.app_id is None
+
+ def test_update_request_partial(self):
+ req = UpdateChatSessionRequest(title="New Title")
+ assert req.title == "New Title"
+ assert req.starred is None
+
+ def test_session_response(self):
+ now = datetime.now(timezone.utc)
+ resp = ChatSessionResponse(id="s1", title="Test", created_at=now, updated_at=now)
+ assert resp.message_count == 0
+ assert resp.starred is False
+
+ def test_save_message_request(self):
+ req = SaveMessageRequest(text="Hello", sender="human")
+ assert req.app_id is None
+ assert req.session_id is None
+
+ def test_rate_request(self):
+ req = RateMessageRequest(rating=1)
+ assert req.rating == 1
+ req2 = RateMessageRequest()
+ assert req2.rating is None
+
+
+class TestChatSessionEndpoints:
+ def _make_app(self):
+ from fastapi import FastAPI
+ app = FastAPI()
+ app.include_router(router)
+ return app
+
+ @pytest.fixture
+ def client(self):
+ from fastapi.testclient import TestClient
+ return TestClient(self._make_app())
+
+ def test_create_session(self, client):
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.add_chat_session') as mock_add,
+ ):
+ mock_add.side_effect = lambda uid, data: data
+ response = client.post(
+ '/v2/chat-sessions',
+ json={'title': 'My Chat', 'app_id': None},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert data['title'] == 'My Chat'
+ assert data['message_count'] == 0
+ assert data['starred'] is False
+ assert 'id' in data
+
+ def test_create_session_default_title(self, client):
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.add_chat_session') as mock_add,
+ ):
+ mock_add.side_effect = lambda uid, data: data
+ response = client.post(
+ '/v2/chat-sessions',
+ json={},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 200
+ assert response.json()['title'] == 'New Chat'
+
+ def test_list_sessions(self, client):
+ now = datetime.now(timezone.utc)
+ mock_sessions = [
+ {'id': 's1', 'title': 'Chat 1', 'created_at': now, 'updated_at': now, 'message_count': 5, 'starred': False},
+ {'id': 's2', 'title': 'Chat 2', 'created_at': now, 'updated_at': now, 'message_count': 3, 'starred': True},
+ ]
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.get_chat_sessions', return_value=mock_sessions),
+ ):
+ response = client.get('/v2/chat-sessions', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ data = response.json()
+ assert len(data) == 2
+ assert data[0]['title'] == 'Chat 1'
+
+ def test_get_session(self, client):
+ now = datetime.now(timezone.utc)
+ mock_session = {'id': 's1', 'title': 'Chat', 'created_at': now, 'updated_at': now, 'message_count': 0, 'starred': False}
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.get_chat_session_by_id', return_value=mock_session),
+ ):
+ response = client.get('/v2/chat-sessions/s1', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ assert response.json()['id'] == 's1'
+
+ def test_get_session_not_found(self, client):
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.get_chat_session_by_id', return_value=None),
+ ):
+ response = client.get('/v2/chat-sessions/missing', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 404
+
+ def test_update_session_returns_full_session(self, client):
+ now = datetime.now(timezone.utc)
+ mock_session = {'id': 's1', 'title': 'Old', 'created_at': now, 'updated_at': now, 'message_count': 0, 'starred': False}
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.get_chat_session_by_id', return_value=mock_session),
+ patch('routers.chat.chat_db.update_chat_session') as mock_update,
+ ):
+ response = client.patch(
+ '/v2/chat-sessions/s1',
+ json={'title': 'Renamed', 'starred': True},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert data['title'] == 'Renamed'
+ assert data['starred'] is True
+ assert data['id'] == 's1'
+
+ def test_delete_session_cascades_messages(self, client):
+ now = datetime.now(timezone.utc)
+ mock_session = {'id': 's1', 'title': 'Del', 'created_at': now, 'updated_at': now}
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.get_chat_session_by_id', return_value=mock_session),
+ patch('routers.chat.chat_db.delete_chat_session_messages') as mock_del_msgs,
+ patch('routers.chat.chat_db.delete_chat_session') as mock_del,
+ ):
+ response = client.delete('/v2/chat-sessions/s1', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ assert mock_del_msgs.called
+ assert mock_del_msgs.call_args[0][1] == 's1'
+ assert mock_del.called
+ assert mock_del.call_args[0][1] == 's1'
+
+
+class TestDesktopMessageEndpoints:
+ def _make_app(self):
+ from fastapi import FastAPI
+ app = FastAPI()
+ app.include_router(router)
+ return app
+
+ @pytest.fixture
+ def client(self):
+ from fastapi.testclient import TestClient
+ return TestClient(self._make_app())
+
+ def test_save_message(self, client):
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.save_message') as mock_save,
+ patch('routers.chat.chat_db.add_message_to_chat_session'),
+ ):
+ mock_save.side_effect = lambda uid, data: data
+ response = client.post(
+ '/v2/messages/save',
+ json={'text': 'Hello', 'sender': 'human', 'session_id': 's1'},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert 'id' in data
+ assert 'created_at' in data
+
+ def test_save_message_empty_text_422(self, client):
+ with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'):
+ response = client.post(
+ '/v2/messages/save',
+ json={'text': ' ', 'sender': 'human'},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 422
+
+ def test_save_message_invalid_sender_422(self, client):
+ with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'):
+ response = client.post(
+ '/v2/messages/save',
+ json={'text': 'Hello', 'sender': 'bot'},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 422
+
+ def test_rate_message_thumbs_up(self, client):
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.update_message_rating', return_value=True) as mock_rate,
+ ):
+ response = client.patch(
+ '/v2/messages/msg-1/rating',
+ json={'rating': 1},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 200
+ assert mock_rate.called
+ assert mock_rate.call_args[0][1] == 'msg-1'
+ assert mock_rate.call_args[0][2] == 1
+
+ def test_rate_message_clear(self, client):
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.update_message_rating', return_value=True) as mock_rate,
+ ):
+ response = client.patch(
+ '/v2/messages/msg-1/rating',
+ json={'rating': None},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 200
+ assert mock_rate.called
+ assert mock_rate.call_args[0][1] == 'msg-1'
+ assert mock_rate.call_args[0][2] is None
+
+ def test_rate_message_thumbs_down(self, client):
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.update_message_rating', return_value=True) as mock_rate,
+ ):
+ response = client.patch(
+ '/v2/messages/msg-1/rating',
+ json={'rating': -1},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 200
+ assert mock_rate.call_args[0][2] == -1
+
+ def test_rate_message_not_found_404(self, client):
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.update_message_rating', return_value=False),
+ ):
+ response = client.patch(
+ '/v2/messages/msg-missing/rating',
+ json={'rating': 1},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 404
+
+ def test_rate_message_invalid_value_422(self, client):
+ with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'):
+ response = client.patch(
+ '/v2/messages/msg-1/rating',
+ json={'rating': 5},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 422
+
+ def test_update_session_not_found(self, client):
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.get_chat_session_by_id', return_value=None),
+ ):
+ response = client.patch(
+ '/v2/chat-sessions/missing',
+ json={'title': 'Renamed'},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 404
+
+ def test_delete_session_not_found(self, client):
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.get_chat_session_by_id', return_value=None),
+ ):
+ response = client.delete(
+ '/v2/chat-sessions/missing',
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 404
+
+ def test_save_message_session_link_failure_still_succeeds(self, client):
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.save_message') as mock_save,
+ patch('routers.chat.chat_db.add_message_to_chat_session', side_effect=Exception('Firestore error')),
+ ):
+ mock_save.side_effect = lambda uid, data: data
+ response = client.post(
+ '/v2/messages/save',
+ json={'text': 'Hello', 'sender': 'human', 'session_id': 's1'},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 200
+ assert 'id' in response.json()
+
+ def test_create_session_malformed_body_422(self, client):
+ with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'):
+ response = client.post(
+ '/v2/chat-sessions',
+ content=b'not json',
+ headers={'Authorization': 'Bearer test', 'Content-Type': 'application/json'},
+ )
+ assert response.status_code == 422
+
+ def test_list_sessions_limit_max_valid(self, client):
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.get_chat_sessions', return_value=[]),
+ ):
+ response = client.get(
+ '/v2/chat-sessions?limit=200',
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 200
+
+ def test_list_sessions_limit_over_max_422(self, client):
+ with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'):
+ response = client.get(
+ '/v2/chat-sessions?limit=201',
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 422
+
+ def test_list_sessions_app_id_filter(self, client):
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.get_chat_sessions', return_value=[]) as mock_get,
+ ):
+ response = client.get(
+ '/v2/chat-sessions?app_id=my-app',
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 200
+ assert mock_get.call_args[1]['app_id'] == 'my-app'
+
+ def test_list_sessions_starred_filter(self, client):
+ with (
+ patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.chat.chat_db.get_chat_sessions', return_value=[]) as mock_get,
+ ):
+ response = client.get(
+ '/v2/chat-sessions?starred=true',
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 200
+ assert mock_get.call_args[1]['starred'] is True
+
+ def test_list_sessions_limit_validation(self, client):
+ with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'):
+ response = client.get(
+ '/v2/chat-sessions?limit=0',
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 422
+
+ def test_list_sessions_offset_negative_validation(self, client):
+ with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'):
+ response = client.get(
+ '/v2/chat-sessions?offset=-1',
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 422
diff --git a/backend/tests/unit/test_desktop_focus.py b/backend/tests/unit/test_desktop_focus.py
new file mode 100644
index 0000000000..60b52c80af
--- /dev/null
+++ b/backend/tests/unit/test_desktop_focus.py
@@ -0,0 +1,382 @@
+"""Tests for desktop focus analysis (Phase 2 — #5396)."""
+
+import asyncio
+import sys
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+# Mock heavy dependencies before any project imports
+sys.modules.setdefault('firebase_admin', MagicMock())
+sys.modules.setdefault('firebase_admin.auth', MagicMock())
+sys.modules.setdefault('firebase_admin.firestore', MagicMock())
+sys.modules.setdefault('database._client', MagicMock())
+_mock_clients = MagicMock()
+sys.modules.setdefault('utils.llm.clients', _mock_clients)
+
+# Now safe to import
+from utils.desktop.focus import FocusResult, FOCUS_SYSTEM_PROMPT, _build_context
+from models.message_event import FocusResultEvent
+
+# --- FocusResult model tests ---
+
+
+class TestFocusResultModel:
+ def test_focus_result_focused(self):
+ result = FocusResult(
+ status="focused",
+ app_or_site="VS Code",
+ description="Writing Python code",
+ message="Great focus!",
+ )
+ assert result.status == "focused"
+ assert result.app_or_site == "VS Code"
+ assert result.description == "Writing Python code"
+ assert result.message == "Great focus!"
+
+ def test_focus_result_distracted(self):
+ result = FocusResult(
+ status="distracted",
+ app_or_site="YouTube",
+ description="Watching videos",
+ message="Time to refocus!",
+ )
+ assert result.status == "distracted"
+ assert result.app_or_site == "YouTube"
+
+ def test_focus_result_message_optional(self):
+ result = FocusResult(
+ status="focused",
+ app_or_site="Terminal",
+ description="Running tests",
+ )
+ assert result.message is None
+
+ def test_focus_result_message_none_explicit(self):
+ result = FocusResult(
+ status="focused",
+ app_or_site="Terminal",
+ description="Running tests",
+ message=None,
+ )
+ assert result.message is None
+
+
+# --- FocusResultEvent tests ---
+
+
+class TestFocusResultEvent:
+ def test_focus_result_event_to_json(self):
+ event = FocusResultEvent(
+ frame_id="abc-123",
+ status="focused",
+ app_or_site="VS Code",
+ description="Writing code",
+ message="Keep it up!",
+ )
+ j = event.to_json()
+ assert j["type"] == "focus_result"
+ assert j["frame_id"] == "abc-123"
+ assert j["status"] == "focused"
+ assert j["app_or_site"] == "VS Code"
+ assert j["description"] == "Writing code"
+ assert j["message"] == "Keep it up!"
+ assert "event_type" not in j
+
+ def test_focus_result_event_null_message(self):
+ event = FocusResultEvent(
+ frame_id="def-456",
+ status="distracted",
+ app_or_site="Twitter",
+ description="Browsing feed",
+ )
+ j = event.to_json()
+ assert j["type"] == "focus_result"
+ assert j["message"] is None
+
+ def test_focus_result_event_default_type(self):
+ event = FocusResultEvent(
+ frame_id="x",
+ status="focused",
+ app_or_site="Code",
+ description="Working",
+ )
+ assert event.event_type == "focus_result"
+
+
+# --- Context building tests ---
+
+
+class TestBuildContext:
+ @patch('utils.desktop.focus.get_memories', return_value=[])
+ @patch('utils.desktop.focus.get_action_items', return_value=[])
+ @patch('utils.desktop.focus.get_user_goals', return_value=[])
+ def test_empty_context(self, mock_goals, mock_tasks, mock_memories):
+ result = _build_context("test-uid")
+ assert result == ""
+
+ @patch('utils.desktop.focus.get_memories', return_value=[])
+ @patch('utils.desktop.focus.get_action_items', return_value=[])
+ @patch(
+ 'utils.desktop.focus.get_user_goals',
+ return_value=[
+ {"title": "Ship Phase 2"},
+ {"title": "Learn Rust"},
+ ],
+ )
+ def test_goals_in_context(self, mock_goals, mock_tasks, mock_memories):
+ result = _build_context("test-uid")
+ assert "Active Goals:" in result
+ assert "Ship Phase 2" in result
+ assert "Learn Rust" in result
+
+ @patch('utils.desktop.focus.get_memories', return_value=[])
+ @patch(
+ 'utils.desktop.focus.get_action_items',
+ return_value=[
+ {"description": "Fix login bug"},
+ {"description": "Review PR #42"},
+ ],
+ )
+ @patch('utils.desktop.focus.get_user_goals', return_value=[])
+ def test_tasks_in_context(self, mock_goals, mock_tasks, mock_memories):
+ result = _build_context("test-uid")
+ assert "Current Tasks:" in result
+ assert "Fix login bug" in result
+ assert "Review PR #42" in result
+
+ @patch(
+ 'utils.desktop.focus.get_memories',
+ return_value=[
+ {"structured": {"title": "Learned about WebSockets"}},
+ ],
+ )
+ @patch('utils.desktop.focus.get_action_items', return_value=[])
+ @patch('utils.desktop.focus.get_user_goals', return_value=[])
+ def test_memories_in_context(self, mock_goals, mock_tasks, mock_memories):
+ result = _build_context("test-uid")
+ assert "Recent Memories:" in result
+ assert "Learned about WebSockets" in result
+
+ @patch('utils.desktop.focus.get_memories', side_effect=Exception("DB error"))
+ @patch('utils.desktop.focus.get_action_items', side_effect=Exception("DB error"))
+ @patch('utils.desktop.focus.get_user_goals', side_effect=Exception("DB error"))
+ def test_context_graceful_on_errors(self, mock_goals, mock_tasks, mock_memories):
+ result = _build_context("test-uid")
+ assert result == ""
+
+ @patch('utils.desktop.focus.get_memories', return_value=[])
+ @patch('utils.desktop.focus.get_action_items', return_value=[])
+ @patch(
+ 'utils.desktop.focus.get_user_goals',
+ return_value=[
+ {"description": "Goal without title"},
+ ],
+ )
+ def test_goals_fallback_to_description(self, mock_goals, mock_tasks, mock_memories):
+ result = _build_context("test-uid")
+ assert "Goal without title" in result
+
+ @patch(
+ 'utils.desktop.focus.get_memories',
+ return_value=[
+ {"content": "Memory without structured field"},
+ ],
+ )
+ @patch('utils.desktop.focus.get_action_items', return_value=[])
+ @patch('utils.desktop.focus.get_user_goals', return_value=[])
+ def test_memories_fallback_to_content(self, mock_goals, mock_tasks, mock_memories):
+ result = _build_context("test-uid")
+ assert "Memory without structured field" in result
+
+
+# --- analyze_focus integration tests ---
+
+
+class TestAnalyzeFocus:
+ @patch('utils.desktop.focus._build_context', return_value="")
+ @patch('utils.desktop.focus.llm_gemini_flash')
+ def test_analyze_focus_returns_result(self, mock_llm, mock_ctx):
+ from utils.desktop.focus import analyze_focus
+
+ mock_parser = MagicMock()
+ mock_parser.ainvoke = AsyncMock(
+ return_value=FocusResult(
+ status="focused",
+ app_or_site="VS Code",
+ description="Editing Python",
+ message="Nice work!",
+ )
+ )
+ mock_llm.with_structured_output.return_value = mock_parser
+
+ result = asyncio.get_event_loop().run_until_complete(
+ analyze_focus(uid="test", image_b64="base64data", app_name="VS Code", window_title="main.py")
+ )
+
+ assert result["status"] == "focused"
+ assert result["app_or_site"] == "VS Code"
+ assert result["description"] == "Editing Python"
+ assert result["message"] == "Nice work!"
+
+ @patch('utils.desktop.focus._build_context', return_value="Active Goals:\n- Ship code")
+ @patch('utils.desktop.focus.llm_gemini_flash')
+ def test_analyze_focus_includes_context_in_prompt(self, mock_llm, mock_ctx):
+ from utils.desktop.focus import analyze_focus
+
+ mock_parser = MagicMock()
+ mock_parser.ainvoke = AsyncMock(
+ return_value=FocusResult(
+ status="distracted",
+ app_or_site="Twitter",
+ description="Browsing",
+ )
+ )
+ mock_llm.with_structured_output.return_value = mock_parser
+
+ asyncio.get_event_loop().run_until_complete(analyze_focus(uid="test", image_b64="data"))
+
+ call_args = mock_parser.ainvoke.call_args[0][0]
+ human_msg = call_args[1]
+ prompt_text = human_msg.content[0]["text"]
+ assert "Active Goals:" in prompt_text
+
+ @patch('utils.desktop.focus._build_context', return_value="")
+ @patch('utils.desktop.focus.llm_gemini_flash')
+ def test_analyze_focus_includes_history(self, mock_llm, mock_ctx):
+ from utils.desktop.focus import analyze_focus
+
+ mock_parser = MagicMock()
+ mock_parser.ainvoke = AsyncMock(
+ return_value=FocusResult(
+ status="focused",
+ app_or_site="Terminal",
+ description="Running tests",
+ )
+ )
+ mock_llm.with_structured_output.return_value = mock_parser
+
+ asyncio.get_event_loop().run_until_complete(
+ analyze_focus(
+ uid="test",
+ image_b64="data",
+ history="1. [focused] VS Code: Writing code",
+ )
+ )
+
+ call_args = mock_parser.ainvoke.call_args[0][0]
+ human_msg = call_args[1]
+ prompt_text = human_msg.content[0]["text"]
+ assert "Recent activity" in prompt_text
+
+ @patch('utils.desktop.focus._build_context', return_value="")
+ @patch('utils.desktop.focus.llm_gemini_flash')
+ def test_analyze_focus_includes_app_and_window(self, mock_llm, mock_ctx):
+ from utils.desktop.focus import analyze_focus
+
+ mock_parser = MagicMock()
+ mock_parser.ainvoke = AsyncMock(
+ return_value=FocusResult(
+ status="focused",
+ app_or_site="Safari",
+ description="Reading docs",
+ )
+ )
+ mock_llm.with_structured_output.return_value = mock_parser
+
+ asyncio.get_event_loop().run_until_complete(
+ analyze_focus(uid="test", image_b64="data", app_name="Safari", window_title="MDN Web Docs")
+ )
+
+ call_args = mock_parser.ainvoke.call_args[0][0]
+ human_msg = call_args[1]
+ prompt_text = human_msg.content[0]["text"]
+ assert "Safari" in prompt_text
+ assert "MDN Web Docs" in prompt_text
+
+ @patch('utils.desktop.focus._build_context', return_value="")
+ @patch('utils.desktop.focus.llm_gemini_flash')
+ def test_analyze_focus_sends_image_as_base64(self, mock_llm, mock_ctx):
+ from utils.desktop.focus import analyze_focus
+
+ mock_parser = MagicMock()
+ mock_parser.ainvoke = AsyncMock(
+ return_value=FocusResult(
+ status="focused",
+ app_or_site="Code",
+ description="Coding",
+ )
+ )
+ mock_llm.with_structured_output.return_value = mock_parser
+
+ asyncio.get_event_loop().run_until_complete(analyze_focus(uid="test", image_b64="FAKE_BASE64_IMAGE"))
+
+ call_args = mock_parser.ainvoke.call_args[0][0]
+ human_msg = call_args[1]
+ image_part = human_msg.content[1]
+ assert image_part["type"] == "image_url"
+ assert "FAKE_BASE64_IMAGE" in image_part["image_url"]["url"]
+
+ @patch('utils.desktop.focus._build_context', return_value="")
+ @patch('utils.desktop.focus.llm_gemini_flash')
+ def test_analyze_focus_sends_system_prompt(self, mock_llm, mock_ctx):
+ from utils.desktop.focus import analyze_focus
+
+ mock_parser = MagicMock()
+ mock_parser.ainvoke = AsyncMock(
+ return_value=FocusResult(
+ status="focused",
+ app_or_site="Code",
+ description="Coding",
+ )
+ )
+ mock_llm.with_structured_output.return_value = mock_parser
+
+ asyncio.get_event_loop().run_until_complete(analyze_focus(uid="test", image_b64="data"))
+
+ call_args = mock_parser.ainvoke.call_args[0][0]
+ system_msg = call_args[0]
+ assert FOCUS_SYSTEM_PROMPT in system_msg.content
+
+ @patch('utils.desktop.focus._build_context', return_value="")
+ @patch('utils.desktop.focus.llm_gemini_flash')
+ def test_analyze_focus_distracted_result(self, mock_llm, mock_ctx):
+ from utils.desktop.focus import analyze_focus
+
+ mock_parser = MagicMock()
+ mock_parser.ainvoke = AsyncMock(
+ return_value=FocusResult(
+ status="distracted",
+ app_or_site="Reddit",
+ description="Scrolling r/programming",
+ message="Back to work!",
+ )
+ )
+ mock_llm.with_structured_output.return_value = mock_parser
+
+ result = asyncio.get_event_loop().run_until_complete(analyze_focus(uid="test", image_b64="data"))
+
+ assert result["status"] == "distracted"
+ assert result["app_or_site"] == "Reddit"
+ assert result["message"] == "Back to work!"
+
+
+# --- System prompt content tests ---
+
+
+class TestFocusSystemPrompt:
+ def test_prompt_includes_focused_criteria(self):
+ assert "Code editors" in FOCUS_SYSTEM_PROMPT
+
+ def test_prompt_includes_distracted_criteria(self):
+ assert "YouTube" in FOCUS_SYSTEM_PROMPT
+ assert "Twitter" in FOCUS_SYSTEM_PROMPT
+
+ def test_prompt_warns_about_log_text(self):
+ assert "log text" in FOCUS_SYSTEM_PROMPT
+
+ def test_prompt_mentions_context_aware(self):
+ assert "CONTEXT-AWARE" in FOCUS_SYSTEM_PROMPT
+
+ def test_prompt_coaching_message_guidance(self):
+ assert "100 characters max" in FOCUS_SYSTEM_PROMPT
diff --git a/backend/tests/unit/test_desktop_live_notes.py b/backend/tests/unit/test_desktop_live_notes.py
new file mode 100644
index 0000000000..7969427ce2
--- /dev/null
+++ b/backend/tests/unit/test_desktop_live_notes.py
@@ -0,0 +1,99 @@
+"""Tests for desktop live notes handler (Phase 2 — #5396)."""
+
+import asyncio
+import sys
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+sys.modules.setdefault('firebase_admin', MagicMock())
+sys.modules.setdefault('firebase_admin.auth', MagicMock())
+sys.modules.setdefault('firebase_admin.firestore', MagicMock())
+sys.modules.setdefault('database._client', MagicMock())
+_mock_clients = MagicMock()
+sys.modules.setdefault('utils.llm.clients', _mock_clients)
+
+from utils.desktop.live_notes import (
+ LiveNoteResult,
+ LIVE_NOTES_SYSTEM_PROMPT,
+ generate_live_note,
+)
+from models.message_event import LiveNoteEvent
+
+
+class TestLiveNoteResultModel:
+ def test_note_with_text(self):
+ r = LiveNoteResult(text="Key decision: ship by Friday")
+ assert r.text == "Key decision: ship by Friday"
+
+ def test_empty_note(self):
+ r = LiveNoteResult(text="")
+ assert r.text == ""
+
+
+class TestLiveNoteEvent:
+ def test_event_structure(self):
+ event = LiveNoteEvent(text="Meeting note content")
+ data = event.to_json()
+ assert data["type"] == "live_note"
+ assert data["text"] == "Meeting note content"
+
+
+class TestGenerateLiveNote:
+ @patch('utils.desktop.live_notes.llm_mini')
+ def test_returns_note(self, mock_llm):
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(
+ return_value=LiveNoteResult(text="- Decision: use Redis for caching")
+ )
+ result = asyncio.get_event_loop().run_until_complete(
+ generate_live_note("We decided to use Redis for caching the API responses")
+ )
+ assert result["text"] == "- Decision: use Redis for caching"
+
+ @patch('utils.desktop.live_notes.llm_mini')
+ def test_empty_result(self, mock_llm):
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(return_value=LiveNoteResult(text=""))
+ result = asyncio.get_event_loop().run_until_complete(
+ generate_live_note("um yeah so like um")
+ )
+ assert result["text"] == ""
+
+ @patch('utils.desktop.live_notes.llm_mini')
+ def test_includes_session_context(self, mock_llm):
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(return_value=LiveNoteResult(text="note"))
+ asyncio.get_event_loop().run_until_complete(
+ generate_live_note("transcript text", session_context="Sprint planning")
+ )
+ call_args = mock_parser.ainvoke.call_args[0][0]
+ human_msg = call_args[1]
+ assert "Sprint planning" in human_msg.content
+
+ @patch('utils.desktop.live_notes.llm_mini')
+ def test_sends_system_prompt(self, mock_llm):
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(return_value=LiveNoteResult(text=""))
+ asyncio.get_event_loop().run_until_complete(
+ generate_live_note("test text")
+ )
+ call_args = mock_parser.ainvoke.call_args[0][0]
+ sys_msg = call_args[0]
+ assert LIVE_NOTES_SYSTEM_PROMPT in sys_msg.content
+
+
+class TestLiveNotesSystemPrompt:
+ def test_includes_condensation_rules(self):
+ assert "Condense" in LIVE_NOTES_SYSTEM_PROMPT
+
+ def test_includes_word_limit(self):
+ assert "200 words" in LIVE_NOTES_SYSTEM_PROMPT
+
+ def test_includes_preservation_rules(self):
+ assert "names" in LIVE_NOTES_SYSTEM_PROMPT
+ assert "decisions" in LIVE_NOTES_SYSTEM_PROMPT
diff --git a/backend/tests/unit/test_desktop_memories.py b/backend/tests/unit/test_desktop_memories.py
new file mode 100644
index 0000000000..158760e2c4
--- /dev/null
+++ b/backend/tests/unit/test_desktop_memories.py
@@ -0,0 +1,150 @@
+"""Tests for desktop memory extraction handler (Phase 2 — #5396)."""
+
+import asyncio
+import sys
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+sys.modules.setdefault('firebase_admin', MagicMock())
+sys.modules.setdefault('firebase_admin.auth', MagicMock())
+sys.modules.setdefault('firebase_admin.firestore', MagicMock())
+sys.modules.setdefault('database._client', MagicMock())
+_mock_clients = MagicMock()
+sys.modules.setdefault('utils.llm.clients', _mock_clients)
+
+from utils.desktop.memories import (
+ ExtractedMemory,
+ MemoryExtractionResult,
+ MEMORY_SYSTEM_PROMPT,
+ _build_memory_context,
+ extract_memories,
+)
+from models.message_event import MemoriesExtractedEvent
+
+
+class TestExtractedMemoryModel:
+ def test_memory_all_fields(self):
+ m = ExtractedMemory(content="User prefers dark mode", category="system", confidence=0.95)
+ assert m.content == "User prefers dark mode"
+ assert m.category == "system"
+ assert m.confidence == 0.95
+
+ def test_memory_interesting_category(self):
+ m = ExtractedMemory(content="AI tip from article", category="interesting", confidence=0.7)
+ assert m.category == "interesting"
+
+ def test_confidence_bounds(self):
+ with pytest.raises(Exception):
+ ExtractedMemory(content="test", category="system", confidence=1.5)
+
+
+class TestMemoryExtractionResult:
+ def test_result_with_memories(self):
+ result = MemoryExtractionResult(
+ memories=[ExtractedMemory(content="Fact 1", category="system", confidence=0.8)]
+ )
+ assert len(result.memories) == 1
+
+ def test_result_empty(self):
+ result = MemoryExtractionResult()
+ assert result.memories == []
+
+
+class TestMemoriesExtractedEvent:
+ def test_event_structure(self):
+ event = MemoriesExtractedEvent(
+ frame_id="frame456",
+ memories=[{"content": "Test fact", "category": "system", "confidence": 0.9}],
+ )
+ data = event.to_json()
+ assert data["type"] == "memories_extracted"
+ assert data["frame_id"] == "frame456"
+ assert len(data["memories"]) == 1
+
+
+class TestBuildMemoryContext:
+ @patch('utils.desktop.memories.get_memories')
+ def test_existing_memories_in_context(self, mock_get):
+ mock_get.return_value = [
+ {'structured': {'content': 'User likes Python'}},
+ {'content': 'Fallback content'},
+ ]
+ ctx = _build_memory_context("uid1")
+ assert "User likes Python" in ctx
+ assert "Fallback content" in ctx
+ assert "DO NOT extract duplicates" in ctx
+
+ @patch('utils.desktop.memories.get_memories')
+ def test_empty_context(self, mock_get):
+ mock_get.return_value = []
+ ctx = _build_memory_context("uid1")
+ assert ctx == ""
+
+ @patch('utils.desktop.memories.get_memories')
+ def test_graceful_on_errors(self, mock_get):
+ mock_get.side_effect = Exception("DB error")
+ ctx = _build_memory_context("uid1")
+ assert ctx == ""
+
+
+class TestExtractMemories:
+ @patch('utils.desktop.memories._build_memory_context')
+ @patch('utils.desktop.memories.llm_gemini_flash')
+ def test_extract_returns_memories(self, mock_llm, mock_ctx):
+ mock_ctx.return_value = ""
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(
+ return_value=MemoryExtractionResult(
+ memories=[
+ ExtractedMemory(content="User works on Omi project", category="system", confidence=0.85),
+ ]
+ )
+ )
+ result = asyncio.get_event_loop().run_until_complete(
+ extract_memories("uid1", "base64img", "VS Code", "omi/main.py")
+ )
+ assert len(result["memories"]) == 1
+ assert result["memories"][0]["content"] == "User works on Omi project"
+ assert result["memories"][0]["category"] == "system"
+
+ @patch('utils.desktop.memories._build_memory_context')
+ @patch('utils.desktop.memories.llm_gemini_flash')
+ def test_extract_empty_result(self, mock_llm, mock_ctx):
+ mock_ctx.return_value = ""
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(return_value=MemoryExtractionResult())
+ result = asyncio.get_event_loop().run_until_complete(
+ extract_memories("uid1", "base64img")
+ )
+ assert result["memories"] == []
+
+ @patch('utils.desktop.memories._build_memory_context')
+ @patch('utils.desktop.memories.llm_gemini_flash')
+ def test_sends_image_and_system_prompt(self, mock_llm, mock_ctx):
+ mock_ctx.return_value = ""
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(return_value=MemoryExtractionResult())
+ asyncio.get_event_loop().run_until_complete(
+ extract_memories("uid1", "testimg64")
+ )
+ call_args = mock_parser.ainvoke.call_args[0][0]
+ sys_msg = call_args[0]
+ human_msg = call_args[1]
+ assert MEMORY_SYSTEM_PROMPT in sys_msg.content
+ assert human_msg.content[1]["image_url"]["url"] == "data:image/jpeg;base64,testimg64"
+
+
+class TestMemorySystemPrompt:
+ def test_includes_extraction_rules(self):
+ assert "EXTRACTION RULES" in MEMORY_SYSTEM_PROMPT
+
+ def test_includes_dedup(self):
+ assert "DEDUPLICATION" in MEMORY_SYSTEM_PROMPT
+
+ def test_includes_categories(self):
+ assert "system" in MEMORY_SYSTEM_PROMPT
+ assert "interesting" in MEMORY_SYSTEM_PROMPT
diff --git a/backend/tests/unit/test_desktop_profile.py b/backend/tests/unit/test_desktop_profile.py
new file mode 100644
index 0000000000..ea6fea4234
--- /dev/null
+++ b/backend/tests/unit/test_desktop_profile.py
@@ -0,0 +1,103 @@
+"""Tests for desktop profile generation handler (Phase 2 — #5396)."""
+
+import asyncio
+import sys
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+sys.modules.setdefault('firebase_admin', MagicMock())
+sys.modules.setdefault('firebase_admin.auth', MagicMock())
+sys.modules.setdefault('firebase_admin.firestore', MagicMock())
+sys.modules.setdefault('database._client', MagicMock())
+_mock_clients = MagicMock()
+sys.modules.setdefault('utils.llm.clients', _mock_clients)
+
+from utils.desktop.profile import (
+ ProfileResult,
+ PROFILE_SYSTEM_PROMPT,
+ generate_profile,
+)
+from models.message_event import ProfileUpdatedEvent
+
+
+class TestProfileResultModel:
+ def test_profile_text(self):
+ r = ProfileResult(profile_text="The user is a backend engineer focused on Python.")
+ assert "backend engineer" in r.profile_text
+
+
+class TestProfileUpdatedEvent:
+ def test_event_structure(self):
+ event = ProfileUpdatedEvent(profile_text="User profile text")
+ data = event.to_json()
+ assert data["type"] == "profile_updated"
+ assert data["profile_text"] == "User profile text"
+
+
+class TestGenerateProfile:
+ @patch('utils.desktop.profile.get_memories')
+ @patch('utils.desktop.profile.get_action_items')
+ @patch('utils.desktop.profile.get_user_goals')
+ @patch('utils.desktop.profile.llm_mini')
+ def test_generates_profile(self, mock_llm, mock_goals, mock_tasks, mock_memories):
+ mock_goals.return_value = [{'title': 'Ship v2'}]
+ mock_tasks.return_value = [{'description': 'Fix auth bug'}]
+ mock_memories.return_value = [{'structured': {'content': 'User prefers Python'}}]
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(
+ return_value=ProfileResult(profile_text="The user is a developer focused on shipping v2.")
+ )
+ result = asyncio.get_event_loop().run_until_complete(generate_profile("uid1"))
+ assert "developer" in result["profile_text"]
+
+ @patch('utils.desktop.profile.get_memories')
+ @patch('utils.desktop.profile.get_action_items')
+ @patch('utils.desktop.profile.get_user_goals')
+ def test_no_data_returns_default(self, mock_goals, mock_tasks, mock_memories):
+ mock_goals.return_value = []
+ mock_tasks.return_value = []
+ mock_memories.return_value = []
+ result = asyncio.get_event_loop().run_until_complete(generate_profile("uid1"))
+ assert "No data available" in result["profile_text"]
+
+ @patch('utils.desktop.profile.get_memories')
+ @patch('utils.desktop.profile.get_action_items')
+ @patch('utils.desktop.profile.get_user_goals')
+ @patch('utils.desktop.profile.llm_mini')
+ def test_graceful_on_db_errors(self, mock_llm, mock_goals, mock_tasks, mock_memories):
+ mock_goals.side_effect = Exception("DB error")
+ mock_tasks.side_effect = Exception("DB error")
+ mock_memories.side_effect = Exception("DB error")
+ result = asyncio.get_event_loop().run_until_complete(generate_profile("uid1"))
+ assert "No data available" in result["profile_text"]
+
+ @patch('utils.desktop.profile.get_memories')
+ @patch('utils.desktop.profile.get_action_items')
+ @patch('utils.desktop.profile.get_user_goals')
+ @patch('utils.desktop.profile.llm_mini')
+ def test_includes_goals_in_prompt(self, mock_llm, mock_goals, mock_tasks, mock_memories):
+ mock_goals.return_value = [{'title': 'Learn Rust'}]
+ mock_tasks.return_value = []
+ mock_memories.return_value = []
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(
+ return_value=ProfileResult(profile_text="Profile text")
+ )
+ asyncio.get_event_loop().run_until_complete(generate_profile("uid1"))
+ call_args = mock_parser.ainvoke.call_args[0][0]
+ human_msg = call_args[1]
+ assert "Learn Rust" in human_msg.content
+
+
+class TestProfileSystemPrompt:
+ def test_third_person_format(self):
+ assert "third person" in PROFILE_SYSTEM_PROMPT
+
+ def test_word_limit(self):
+ assert "300 words" in PROFILE_SYSTEM_PROMPT
+
+ def test_factual_requirement(self):
+ assert "factual" in PROFILE_SYSTEM_PROMPT
diff --git a/backend/tests/unit/test_desktop_task_ops.py b/backend/tests/unit/test_desktop_task_ops.py
new file mode 100644
index 0000000000..6cd9803a5f
--- /dev/null
+++ b/backend/tests/unit/test_desktop_task_ops.py
@@ -0,0 +1,175 @@
+"""Tests for desktop task operations (rerank + dedup) handlers (Phase 2 — #5396)."""
+
+import asyncio
+import sys
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+sys.modules.setdefault('firebase_admin', MagicMock())
+sys.modules.setdefault('firebase_admin.auth', MagicMock())
+sys.modules.setdefault('firebase_admin.firestore', MagicMock())
+sys.modules.setdefault('database._client', MagicMock())
+_mock_clients = MagicMock()
+sys.modules.setdefault('utils.llm.clients', _mock_clients)
+
+from utils.desktop.task_ops import (
+ RankedTask,
+ RerankResult,
+ DedupGroup,
+ DedupResult,
+ RERANK_SYSTEM_PROMPT,
+ DEDUP_SYSTEM_PROMPT,
+ rerank_tasks,
+ dedup_tasks,
+)
+from models.message_event import RerankCompleteEvent, DedupCompleteEvent
+
+
+# --- Rerank tests ---
+
+
+class TestRankedTaskModel:
+ def test_ranked_task(self):
+ t = RankedTask(id="task1", new_position=1)
+ assert t.id == "task1"
+ assert t.new_position == 1
+
+
+class TestRerankResult:
+ def test_rerank_result(self):
+ r = RerankResult(updated_tasks=[RankedTask(id="t1", new_position=1)])
+ assert len(r.updated_tasks) == 1
+
+
+class TestRerankCompleteEvent:
+ def test_event_structure(self):
+ event = RerankCompleteEvent(updated_tasks=[{"id": "t1", "new_position": 1}])
+ data = event.to_json()
+ assert data["type"] == "rerank_complete"
+ assert len(data["updated_tasks"]) == 1
+
+
+class TestRerankTasks:
+ @patch('utils.desktop.task_ops.get_action_items')
+ @patch('utils.desktop.task_ops.llm_mini')
+ def test_rerank_returns_order(self, mock_llm, mock_get):
+ mock_get.return_value = [
+ {'id': 't1', 'description': 'Low priority', 'priority': 'low'},
+ {'id': 't2', 'description': 'Urgent fix', 'priority': 'high', 'due_at': '2026-03-08'},
+ ]
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(
+ return_value=RerankResult(
+ updated_tasks=[
+ RankedTask(id="t2", new_position=1),
+ RankedTask(id="t1", new_position=2),
+ ]
+ )
+ )
+ result = asyncio.get_event_loop().run_until_complete(rerank_tasks("uid1"))
+ assert result["updated_tasks"][0]["id"] == "t2"
+ assert result["updated_tasks"][0]["new_position"] == 1
+
+ @patch('utils.desktop.task_ops.get_action_items')
+ def test_rerank_empty_tasks(self, mock_get):
+ mock_get.return_value = []
+ result = asyncio.get_event_loop().run_until_complete(rerank_tasks("uid1"))
+ assert result["updated_tasks"] == []
+
+ @patch('utils.desktop.task_ops.get_action_items')
+ def test_rerank_db_error(self, mock_get):
+ mock_get.side_effect = Exception("DB error")
+ result = asyncio.get_event_loop().run_until_complete(rerank_tasks("uid1"))
+ assert result["updated_tasks"] == []
+
+
+# --- Dedup tests ---
+
+
+class TestDedupGroupModel:
+ def test_dedup_group(self):
+ g = DedupGroup(keep_id="t1", delete_ids=["t2", "t3"], reason="Same task")
+ assert g.keep_id == "t1"
+ assert len(g.delete_ids) == 2
+
+
+class TestDedupResult:
+ def test_dedup_with_groups(self):
+ r = DedupResult(groups=[DedupGroup(keep_id="t1", delete_ids=["t2"], reason="Duplicate")])
+ assert len(r.groups) == 1
+
+ def test_dedup_no_groups(self):
+ r = DedupResult()
+ assert r.groups == []
+
+
+class TestDedupCompleteEvent:
+ def test_event_structure(self):
+ event = DedupCompleteEvent(deleted_ids=["t2", "t3"], reason="Duplicate tasks")
+ data = event.to_json()
+ assert data["type"] == "dedup_complete"
+ assert data["deleted_ids"] == ["t2", "t3"]
+ assert data["reason"] == "Duplicate tasks"
+
+
+class TestDedupTasks:
+ @patch('utils.desktop.task_ops.get_action_items')
+ @patch('utils.desktop.task_ops.llm_mini')
+ def test_dedup_finds_duplicates(self, mock_llm, mock_get):
+ mock_get.return_value = [
+ {'id': 't1', 'description': 'Call John'},
+ {'id': 't2', 'description': 'Phone John'},
+ {'id': 't3', 'description': 'Write report'},
+ ]
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(
+ return_value=DedupResult(
+ groups=[DedupGroup(keep_id="t1", delete_ids=["t2"], reason="Same action: contact John")]
+ )
+ )
+ result = asyncio.get_event_loop().run_until_complete(dedup_tasks("uid1"))
+ assert result["deleted_ids"] == ["t2"]
+ assert "contact John" in result["reason"]
+
+ @patch('utils.desktop.task_ops.get_action_items')
+ @patch('utils.desktop.task_ops.llm_mini')
+ def test_dedup_no_duplicates(self, mock_llm, mock_get):
+ mock_get.return_value = [
+ {'id': 't1', 'description': 'Task A'},
+ {'id': 't2', 'description': 'Task B'},
+ ]
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(return_value=DedupResult())
+ result = asyncio.get_event_loop().run_until_complete(dedup_tasks("uid1"))
+ assert result["deleted_ids"] == []
+ assert result["reason"] == "No duplicates found"
+
+ @patch('utils.desktop.task_ops.get_action_items')
+ def test_dedup_too_few_tasks(self, mock_get):
+ mock_get.return_value = [{'id': 't1', 'description': 'Only one'}]
+ result = asyncio.get_event_loop().run_until_complete(dedup_tasks("uid1"))
+ assert result["deleted_ids"] == []
+ assert "Not enough" in result["reason"]
+
+ @patch('utils.desktop.task_ops.get_action_items')
+ def test_dedup_db_error(self, mock_get):
+ mock_get.side_effect = Exception("DB error")
+ result = asyncio.get_event_loop().run_until_complete(dedup_tasks("uid1"))
+ assert result["deleted_ids"] == []
+ assert "Failed" in result["reason"]
+
+
+class TestRerankSystemPrompt:
+ def test_includes_rules(self):
+ assert "RULES" in RERANK_SYSTEM_PROMPT
+ assert "deadlines" in RERANK_SYSTEM_PROMPT
+
+
+class TestDedupSystemPrompt:
+ def test_includes_rules(self):
+ assert "RULES" in DEDUP_SYSTEM_PROMPT
+ assert "duplicates" in DEDUP_SYSTEM_PROMPT.lower()
diff --git a/backend/tests/unit/test_desktop_tasks.py b/backend/tests/unit/test_desktop_tasks.py
new file mode 100644
index 0000000000..908b4d9a50
--- /dev/null
+++ b/backend/tests/unit/test_desktop_tasks.py
@@ -0,0 +1,238 @@
+"""Tests for desktop task extraction handler (Phase 2 — #5396)."""
+
+import asyncio
+import sys
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+# Mock heavy dependencies before any project imports
+sys.modules.setdefault('firebase_admin', MagicMock())
+sys.modules.setdefault('firebase_admin.auth', MagicMock())
+sys.modules.setdefault('firebase_admin.firestore', MagicMock())
+sys.modules.setdefault('database._client', MagicMock())
+_mock_clients = MagicMock()
+sys.modules.setdefault('utils.llm.clients', _mock_clients)
+
+from utils.desktop.tasks import (
+ ExtractedTask,
+ TaskExtractionResult,
+ TASK_SYSTEM_PROMPT,
+ _build_task_context,
+ extract_tasks,
+)
+from models.message_event import TasksExtractedEvent
+
+
+class TestExtractedTaskModel:
+ def test_task_with_all_fields(self):
+ task = ExtractedTask(
+ title="Review pull request 42 for authentication changes",
+ description="Check auth middleware",
+ priority="high",
+ tags=["code-review", "auth"],
+ source_app="GitHub",
+ inferred_deadline="2026-03-10",
+ confidence=0.9,
+ source_category="direct_request",
+ )
+ assert task.title == "Review pull request 42 for authentication changes"
+ assert task.priority == "high"
+ assert task.confidence == 0.9
+
+ def test_task_defaults(self):
+ task = ExtractedTask(
+ title="Update the README with new API docs",
+ priority="low",
+ confidence=0.5,
+ )
+ assert task.description == ""
+ assert task.tags == []
+ assert task.source_app == ""
+ assert task.inferred_deadline is None
+ assert task.source_category == "reactive"
+
+ def test_task_confidence_bounds(self):
+ with pytest.raises(Exception):
+ ExtractedTask(title="Test", priority="high", confidence=1.5)
+ with pytest.raises(Exception):
+ ExtractedTask(title="Test", priority="high", confidence=-0.1)
+
+
+class TestTaskExtractionResult:
+ def test_result_with_tasks(self):
+ result = TaskExtractionResult(
+ has_new_tasks=True,
+ tasks=[
+ ExtractedTask(title="Call John about the project deadline", priority="high", confidence=0.8),
+ ],
+ context_summary="Slack messages",
+ current_activity="Reading messages",
+ )
+ assert result.has_new_tasks is True
+ assert len(result.tasks) == 1
+
+ def test_result_no_tasks(self):
+ result = TaskExtractionResult(
+ has_new_tasks=False,
+ context_summary="IDE open",
+ current_activity="Coding",
+ )
+ assert result.has_new_tasks is False
+ assert result.tasks == []
+
+
+class TestTasksExtractedEvent:
+ def test_event_structure(self):
+ event = TasksExtractedEvent(
+ frame_id="frame123",
+ tasks=[{"title": "Test task", "priority": "high"}],
+ )
+ data = event.to_json()
+ assert data["type"] == "tasks_extracted"
+ assert data["frame_id"] == "frame123"
+ assert len(data["tasks"]) == 1
+
+
+class TestBuildTaskContext:
+ @patch('utils.desktop.tasks.get_action_items')
+ def test_active_tasks_in_context(self, mock_get):
+ mock_get.return_value = [
+ {'description': 'Write tests', 'due_at': '2026-03-10'},
+ {'description': 'Fix bug'},
+ ]
+ ctx = _build_task_context("uid1")
+ assert "Write tests" in ctx
+ assert "Due: 2026-03-10" in ctx
+ assert "Fix bug" in ctx
+ assert "Pending" in ctx
+
+ @patch('utils.desktop.tasks.get_action_items')
+ def test_completed_tasks_in_context(self, mock_get):
+ mock_get.side_effect = [
+ [], # active tasks
+ [{'description': 'Done task'}], # completed tasks
+ ]
+ ctx = _build_task_context("uid1")
+ assert "Done task" in ctx
+ assert "Completed" in ctx
+
+ @patch('utils.desktop.tasks.get_action_items')
+ def test_empty_context(self, mock_get):
+ mock_get.return_value = []
+ ctx = _build_task_context("uid1")
+ assert ctx == ""
+
+ @patch('utils.desktop.tasks.get_action_items')
+ def test_graceful_on_errors(self, mock_get):
+ mock_get.side_effect = Exception("DB error")
+ ctx = _build_task_context("uid1")
+ assert ctx == ""
+
+
+class TestExtractTasks:
+ @patch('utils.desktop.tasks._build_task_context')
+ @patch('utils.desktop.tasks.llm_gemini_flash')
+ def test_extract_tasks_returns_result(self, mock_llm, mock_ctx):
+ mock_ctx.return_value = ""
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(
+ return_value=TaskExtractionResult(
+ has_new_tasks=True,
+ tasks=[
+ ExtractedTask(
+ title="Review pull request 42 for auth changes",
+ priority="high",
+ confidence=0.9,
+ source_app="GitHub",
+ )
+ ],
+ context_summary="GitHub PR page",
+ current_activity="Reviewing code",
+ )
+ )
+ result = asyncio.get_event_loop().run_until_complete(
+ extract_tasks("uid1", "base64img", "Chrome", "GitHub PR")
+ )
+ assert result["has_new_tasks"] is True
+ assert len(result["tasks"]) == 1
+ assert result["tasks"][0]["title"] == "Review pull request 42 for auth changes"
+ assert result["tasks"][0]["source_app"] == "GitHub"
+
+ @patch('utils.desktop.tasks._build_task_context')
+ @patch('utils.desktop.tasks.llm_gemini_flash')
+ def test_extract_tasks_no_tasks(self, mock_llm, mock_ctx):
+ mock_ctx.return_value = ""
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(
+ return_value=TaskExtractionResult(
+ has_new_tasks=False,
+ context_summary="Desktop idle",
+ current_activity="Nothing",
+ )
+ )
+ result = asyncio.get_event_loop().run_until_complete(
+ extract_tasks("uid1", "base64img")
+ )
+ assert result["has_new_tasks"] is False
+ assert result["tasks"] == []
+
+ @patch('utils.desktop.tasks._build_task_context')
+ @patch('utils.desktop.tasks.llm_gemini_flash')
+ def test_source_app_fallback(self, mock_llm, mock_ctx):
+ mock_ctx.return_value = ""
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(
+ return_value=TaskExtractionResult(
+ has_new_tasks=True,
+ tasks=[
+ ExtractedTask(
+ title="Send email to team about deadline update",
+ priority="medium",
+ confidence=0.7,
+ source_app="", # empty
+ )
+ ],
+ )
+ )
+ result = asyncio.get_event_loop().run_until_complete(
+ extract_tasks("uid1", "base64img", "Slack", "General")
+ )
+ # Falls back to app_name when source_app is empty
+ assert result["tasks"][0]["source_app"] == "Slack"
+
+ @patch('utils.desktop.tasks._build_task_context')
+ @patch('utils.desktop.tasks.llm_gemini_flash')
+ def test_includes_context_in_prompt(self, mock_llm, mock_ctx):
+ mock_ctx.return_value = "Existing active tasks:\n- Write tests [Pending]"
+ mock_parser = MagicMock()
+ mock_llm.with_structured_output.return_value = mock_parser
+ mock_parser.ainvoke = AsyncMock(
+ return_value=TaskExtractionResult(has_new_tasks=False)
+ )
+ asyncio.get_event_loop().run_until_complete(
+ extract_tasks("uid1", "base64img", "VS Code", "main.py")
+ )
+ call_args = mock_parser.ainvoke.call_args[0][0]
+ human_msg = call_args[1]
+ text_content = human_msg.content[0]["text"]
+ assert "Write tests" in text_content
+ assert "VS Code" in text_content
+
+
+class TestTaskSystemPrompt:
+ def test_prompt_includes_dedup_rules(self):
+ assert "DEDUPLICATION" in TASK_SYSTEM_PROMPT
+
+ def test_prompt_includes_priority_guidelines(self):
+ assert "high" in TASK_SYSTEM_PROMPT
+ assert "medium" in TASK_SYSTEM_PROMPT
+ assert "low" in TASK_SYSTEM_PROMPT
+
+ def test_prompt_includes_source_categories(self):
+ assert "direct_request" in TASK_SYSTEM_PROMPT
+ assert "self_generated" in TASK_SYSTEM_PROMPT
+ assert "calendar_driven" in TASK_SYSTEM_PROMPT
diff --git a/backend/tests/unit/test_focus_sessions.py b/backend/tests/unit/test_focus_sessions.py
new file mode 100644
index 0000000000..bc78913fe3
--- /dev/null
+++ b/backend/tests/unit/test_focus_sessions.py
@@ -0,0 +1,237 @@
+import sys
+from datetime import datetime, timezone
+from unittest.mock import patch, MagicMock
+
+import pytest
+
+for mod_name in [
+ 'firebase_admin',
+ 'firebase_admin.auth',
+ 'firebase_admin.firestore',
+ 'firebase_admin.messaging',
+ 'google.cloud',
+ 'google.cloud.exceptions',
+ 'google.cloud.firestore',
+ 'google.cloud.firestore_v1',
+ 'google.cloud.firestore_v1.base_query',
+ 'google.cloud.firestore_v1.query',
+ 'google.cloud.storage',
+ 'google.cloud.storage.blob',
+ 'google.cloud.storage.bucket',
+ 'google.auth',
+ 'google.auth.transport',
+ 'google.auth.transport.requests',
+ 'google.oauth2',
+ 'google.oauth2.service_account',
+ 'pinecone',
+ 'typesense',
+]:
+ sys.modules.setdefault(mod_name, MagicMock())
+
+from routers.focus_sessions import router
+
+from fastapi import FastAPI, HTTPException
+from fastapi.testclient import TestClient
+
+
+@pytest.fixture
+def client():
+ with patch('routers.focus_sessions.auth.get_current_user_uid', return_value='uid-1'):
+ app = FastAPI()
+ app.include_router(router)
+ with TestClient(app) as c:
+ yield c
+
+
+@pytest.fixture
+def client_no_auth():
+ """Client without auth mock — for testing 401 responses."""
+ app = FastAPI()
+ app.include_router(router)
+ with TestClient(app) as c:
+ yield c
+
+
+AUTH = {"Authorization": "Bearer 123testuser"}
+
+
+class TestCreateFocusSession:
+ def test_create_focused_session(self, client):
+ data = {"status": "focused", "app_or_site": "VSCode", "description": "Coding"}
+ with patch('routers.focus_sessions.focus_sessions_db.create_focus_session') as mock_create:
+ mock_create.return_value = {
+ "id": "abc-123", "status": "focused", "app_or_site": "VSCode",
+ "description": "Coding", "created_at": datetime.now(timezone.utc),
+ }
+ resp = client.post("/v1/focus-sessions", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["status"] == "focused"
+
+ def test_create_distracted_session(self, client):
+ data = {"status": "distracted", "app_or_site": "Twitter", "description": "Scrolling"}
+ with patch('routers.focus_sessions.focus_sessions_db.create_focus_session') as mock_create:
+ mock_create.return_value = {
+ "id": "abc-456", "status": "distracted", "app_or_site": "Twitter",
+ "description": "Scrolling", "created_at": datetime.now(timezone.utc),
+ }
+ resp = client.post("/v1/focus-sessions", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["status"] == "distracted"
+
+ def test_create_invalid_status_returns_400(self, client):
+ data = {"status": "invalid", "app_or_site": "X", "description": "Y"}
+ resp = client.post("/v1/focus-sessions", json=data, headers=AUTH)
+ assert resp.status_code == 400
+ assert "focused" in resp.json()["detail"]
+
+ def test_create_with_optional_fields(self, client):
+ data = {
+ "status": "focused", "app_or_site": "VSCode", "description": "Coding",
+ "message": "Keep going!", "duration_seconds": 300,
+ }
+ with patch('routers.focus_sessions.focus_sessions_db.create_focus_session') as mock_create:
+ mock_create.return_value = {
+ "id": "abc-789", "message": "Keep going!", "duration_seconds": 300,
+ **{k: v for k, v in data.items()}, "created_at": datetime.now(timezone.utc),
+ }
+ resp = client.post("/v1/focus-sessions", json=data, headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["message"] == "Keep going!"
+ assert resp.json()["duration_seconds"] == 300
+
+ def test_create_no_auth_returns_401(self, client_no_auth):
+ data = {"status": "focused", "app_or_site": "X", "description": "Y"}
+ with patch(
+ 'routers.focus_sessions.auth.get_current_user_uid',
+ side_effect=HTTPException(status_code=401, detail='Not authenticated'),
+ ):
+ resp = client_no_auth.post("/v1/focus-sessions", json=data)
+ assert resp.status_code == 401
+
+ def test_create_firestore_error_returns_500(self, client):
+ data = {"status": "focused", "app_or_site": "X", "description": "Y"}
+ with patch('routers.focus_sessions.focus_sessions_db.create_focus_session', side_effect=Exception("DB down")):
+ resp = client.post("/v1/focus-sessions", json=data, headers=AUTH)
+ assert resp.status_code == 500
+
+
+class TestGetFocusSessions:
+ def test_get_empty_returns_list(self, client):
+ with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions', return_value=[]):
+ resp = client.get("/v1/focus-sessions", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json() == []
+
+ def test_get_with_date_filter(self, client):
+ with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions', return_value=[]) as mock_get:
+ resp = client.get("/v1/focus-sessions?date=2026-03-05", headers=AUTH)
+ assert resp.status_code == 200
+ mock_get.assert_called_once()
+ assert mock_get.call_args[1]['date'] == '2026-03-05'
+
+ def test_get_invalid_date_skips_filter(self, client):
+ with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions', return_value=[]) as mock_get:
+ resp = client.get("/v1/focus-sessions?date=not-a-date", headers=AUTH)
+ assert resp.status_code == 200
+ assert mock_get.call_args[1]['date'] is None
+
+ def test_get_with_limit_and_offset(self, client):
+ with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions', return_value=[]) as mock_get:
+ resp = client.get("/v1/focus-sessions?limit=50&offset=10", headers=AUTH)
+ assert resp.status_code == 200
+ mock_get.assert_called_once()
+ assert mock_get.call_args[1]['limit'] == 50
+ assert mock_get.call_args[1]['offset'] == 10
+
+ def test_get_firestore_error_returns_empty(self, client):
+ with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions', side_effect=Exception("err")):
+ resp = client.get("/v1/focus-sessions", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json() == []
+
+
+class TestDeleteFocusSession:
+ def test_delete_returns_ok(self, client):
+ with patch('routers.focus_sessions.focus_sessions_db.delete_focus_session', return_value=True):
+ resp = client.delete("/v1/focus-sessions/abc-123", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["status"] == "ok"
+
+ def test_delete_firestore_error_returns_500(self, client):
+ with patch('routers.focus_sessions.focus_sessions_db.delete_focus_session', side_effect=Exception("err")):
+ resp = client.delete("/v1/focus-sessions/abc-123", headers=AUTH)
+ assert resp.status_code == 500
+
+
+class TestFocusStats:
+ def test_stats_empty_sessions(self, client):
+ with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions_for_stats', return_value=[]):
+ resp = client.get("/v1/focus-stats?date=2026-03-05", headers=AUTH)
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["date"] == "2026-03-05"
+ assert data["session_count"] == 0
+ assert data["focused_count"] == 0
+ assert data["distracted_count"] == 0
+ assert data["top_distractions"] == []
+
+ def test_stats_with_sessions(self, client):
+ sessions = [
+ {"status": "focused", "app_or_site": "VSCode", "duration_seconds": 120},
+ {"status": "distracted", "app_or_site": "Twitter", "duration_seconds": 60},
+ {"status": "distracted", "app_or_site": "Twitter", "duration_seconds": 90},
+ {"status": "distracted", "app_or_site": "Reddit", "duration_seconds": 30},
+ ]
+ with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions_for_stats', return_value=sessions):
+ resp = client.get("/v1/focus-stats?date=2026-03-05", headers=AUTH)
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["focused_count"] == 1
+ assert data["distracted_count"] == 3
+ assert data["session_count"] == 4
+ assert len(data["top_distractions"]) == 2
+ assert data["top_distractions"][0]["app_or_site"] == "Twitter"
+ assert data["top_distractions"][0]["total_seconds"] == 150
+
+ def test_stats_defaults_to_today(self, client):
+ with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions_for_stats', return_value=[]) as mock_get:
+ resp = client.get("/v1/focus-stats", headers=AUTH)
+ assert resp.status_code == 200
+ called_date = mock_get.call_args[0][1]
+ today = datetime.now(timezone.utc).strftime('%Y-%m-%d')
+ assert called_date == today
+
+ def test_stats_invalid_date_defaults_to_today(self, client):
+ with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions_for_stats', return_value=[]) as mock_get:
+ resp = client.get("/v1/focus-stats?date=bad", headers=AUTH)
+ assert resp.status_code == 200
+ today = datetime.now(timezone.utc).strftime('%Y-%m-%d')
+ assert mock_get.call_args[0][1] == today
+
+ def test_stats_distraction_without_duration_defaults_60(self, client):
+ sessions = [
+ {"status": "distracted", "app_or_site": "YouTube"},
+ ]
+ with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions_for_stats', return_value=sessions):
+ resp = client.get("/v1/focus-stats?date=2026-03-05", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["top_distractions"][0]["total_seconds"] == 60
+
+ def test_stats_distraction_with_zero_duration_keeps_zero(self, client):
+ sessions = [
+ {"status": "distracted", "app_or_site": "Slack", "duration_seconds": 0},
+ ]
+ with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions_for_stats', return_value=sessions):
+ resp = client.get("/v1/focus-stats?date=2026-03-05", headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["top_distractions"][0]["total_seconds"] == 0
+
+ def test_stats_top5_limit(self, client):
+ sessions = [
+ {"status": "distracted", "app_or_site": f"App{i}", "duration_seconds": i * 10}
+ for i in range(8)
+ ]
+ with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions_for_stats', return_value=sessions):
+ resp = client.get("/v1/focus-stats?date=2026-03-05", headers=AUTH)
+ assert resp.status_code == 200
+ assert len(resp.json()["top_distractions"]) == 5
diff --git a/backend/tests/unit/test_from_segments.py b/backend/tests/unit/test_from_segments.py
new file mode 100644
index 0000000000..068e3134f9
--- /dev/null
+++ b/backend/tests/unit/test_from_segments.py
@@ -0,0 +1,301 @@
+"""Tests for POST /v1/conversations/from-segments endpoint models and validation."""
+import sys
+from unittest.mock import MagicMock
+from datetime import datetime, timezone, timedelta
+
+import pytest
+
+# Stub ALL heavy dependencies before any import that could transitively pull them in.
+# Order matters: stub parent packages before child packages.
+for mod_name in [
+ 'firebase_admin', 'firebase_admin.auth', 'firebase_admin.firestore', 'firebase_admin.messaging',
+ 'google.cloud', 'google.cloud.exceptions', 'google.cloud.firestore', 'google.cloud.firestore_v1',
+ 'google.cloud.firestore_v1.base_query', 'google.cloud.firestore_v1.query',
+ 'google.cloud.storage', 'google.cloud.storage.blob', 'google.cloud.storage.bucket',
+ 'google.auth', 'google.auth.transport', 'google.auth.transport.requests',
+ 'google.oauth2', 'google.oauth2.service_account',
+ 'pinecone',
+ 'typesense',
+]:
+ sys.modules.setdefault(mod_name, MagicMock())
+
+from routers.conversations import (
+ FromSegmentsTranscriptSegment,
+ CreateConversationFromSegmentsRequest,
+ FromSegmentsResponse,
+)
+
+
+@pytest.fixture
+def valid_segments():
+ return [
+ FromSegmentsTranscriptSegment(text="Hello there", speaker="SPEAKER_00", is_user=True, start=0.0, end=2.5),
+ FromSegmentsTranscriptSegment(text="Hi, how are you?", speaker="SPEAKER_01", is_user=False, start=2.8, end=5.2),
+ ]
+
+
+class TestFromSegmentsModels:
+ def test_segment_defaults(self):
+ seg = FromSegmentsTranscriptSegment(text="Hello", start=0.0, end=1.0)
+ assert seg.speaker == "SPEAKER_00"
+ assert seg.is_user is False
+ assert seg.person_id is None
+ assert seg.speaker_id is None
+
+ def test_request_defaults(self, valid_segments):
+ req = CreateConversationFromSegmentsRequest(transcript_segments=valid_segments)
+ assert req.source == "desktop"
+ assert req.language == "en"
+ assert req.started_at is None
+ assert req.finished_at is None
+ assert req.geolocation is None
+
+ def test_response_model(self):
+ resp = FromSegmentsResponse(id="conv123", status="completed", discarded=False)
+ assert resp.id == "conv123"
+ assert resp.status == "completed"
+ assert resp.discarded is False
+
+
+class TestFromSegmentsValidation:
+ def test_segment_with_all_fields(self):
+ seg = FromSegmentsTranscriptSegment(
+ text="Hello",
+ speaker="SPEAKER_01",
+ speaker_id=1,
+ is_user=True,
+ person_id="person123",
+ start=10.5,
+ end=15.3,
+ )
+ assert seg.speaker_id == 1
+ assert seg.person_id == "person123"
+
+ def test_desktop_source_default(self, valid_segments):
+ req = CreateConversationFromSegmentsRequest(transcript_segments=valid_segments)
+ assert req.source == "desktop"
+
+ def test_custom_source(self, valid_segments):
+ req = CreateConversationFromSegmentsRequest(transcript_segments=valid_segments, source="phone")
+ assert req.source == "phone"
+
+ def test_started_finished_at(self, valid_segments):
+ now = datetime.now(timezone.utc)
+ later = now + timedelta(minutes=5)
+ req = CreateConversationFromSegmentsRequest(
+ transcript_segments=valid_segments,
+ started_at=now,
+ finished_at=later,
+ )
+ assert req.started_at == now
+ assert req.finished_at == later
+
+ def test_500_segments_accepted(self):
+ segs = [FromSegmentsTranscriptSegment(text=f"seg {i}", start=float(i), end=float(i + 1)) for i in range(500)]
+ req = CreateConversationFromSegmentsRequest(transcript_segments=segs)
+ assert len(req.transcript_segments) == 500
+
+ def test_geolocation_accepted(self, valid_segments):
+ req = CreateConversationFromSegmentsRequest(
+ transcript_segments=valid_segments,
+ geolocation={'latitude': 37.7749, 'longitude': -122.4194},
+ )
+ assert req.geolocation is not None
+
+
+class TestFromSegmentsEndpoint:
+ """Endpoint-level tests using FastAPI TestClient with mocked auth and processing."""
+
+ def _make_app(self):
+ from fastapi import FastAPI
+ from routers.conversations import router
+ app = FastAPI()
+ app.include_router(router)
+ return app
+
+ @pytest.fixture
+ def client(self):
+ from fastapi.testclient import TestClient
+ return TestClient(self._make_app())
+
+ def test_successful_creation(self, client):
+ with (
+ patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'),
+ patch('routers.conversations.process_conversation') as mock_process,
+ patch('routers.conversations.get_google_maps_location'),
+ ):
+ mock_conv = MagicMock()
+ mock_conv.id = 'conv-abc'
+ mock_conv.status.value = 'completed'
+ mock_conv.discarded = False
+ mock_process.return_value = mock_conv
+
+ response = client.post(
+ '/v1/conversations/from-segments',
+ json={
+ 'transcript_segments': [
+ {'text': 'Hello there', 'speaker': 'SPEAKER_00', 'is_user': True, 'start': 0.0, 'end': 2.5},
+ {'text': 'Hi!', 'speaker': 'SPEAKER_01', 'is_user': False, 'start': 2.8, 'end': 5.2},
+ ],
+ 'source': 'desktop',
+ 'language': 'en',
+ },
+ headers={'Authorization': 'Bearer test-token'},
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert data['id'] == 'conv-abc'
+ assert data['status'] == 'completed'
+ assert data['discarded'] is False
+ mock_process.assert_called_once()
+
+ def test_invalid_segment_times_returns_422(self, client):
+ with patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'):
+ response = client.post(
+ '/v1/conversations/from-segments',
+ json={'transcript_segments': [{'text': 'Hello', 'start': 5.0, 'end': 3.0}]},
+ headers={'Authorization': 'Bearer test-token'},
+ )
+ assert response.status_code == 422
+
+ def test_empty_text_returns_422(self, client):
+ with patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'):
+ response = client.post(
+ '/v1/conversations/from-segments',
+ json={'transcript_segments': [{'text': ' ', 'start': 0.0, 'end': 1.0}]},
+ headers={'Authorization': 'Bearer test-token'},
+ )
+ assert response.status_code == 422
+
+ def test_negative_start_returns_422(self, client):
+ with patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'):
+ response = client.post(
+ '/v1/conversations/from-segments',
+ json={'transcript_segments': [{'text': 'Hello', 'start': -1.0, 'end': 1.0}]},
+ headers={'Authorization': 'Bearer test-token'},
+ )
+ assert response.status_code == 422
+
+ def test_finished_at_auto_calculated(self, client):
+ with (
+ patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'),
+ patch('routers.conversations.process_conversation') as mock_process,
+ patch('routers.conversations.get_google_maps_location'),
+ ):
+ mock_conv = MagicMock()
+ mock_conv.id = 'conv-calc'
+ mock_conv.status.value = 'completed'
+ mock_conv.discarded = False
+ mock_process.return_value = mock_conv
+
+ response = client.post(
+ '/v1/conversations/from-segments',
+ json={'transcript_segments': [{'text': 'Hello', 'start': 0.0, 'end': 30.0}], 'source': 'desktop'},
+ headers={'Authorization': 'Bearer test-token'},
+ )
+ assert response.status_code == 200
+ create_obj = mock_process.call_args[0][2]
+ assert create_obj.finished_at > create_obj.started_at
+
+ def test_source_defaults_to_desktop(self, client):
+ with (
+ patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'),
+ patch('routers.conversations.process_conversation') as mock_process,
+ patch('routers.conversations.get_google_maps_location'),
+ ):
+ mock_conv = MagicMock()
+ mock_conv.id = 'conv-def'
+ mock_conv.status.value = 'completed'
+ mock_conv.discarded = False
+ mock_process.return_value = mock_conv
+
+ response = client.post(
+ '/v1/conversations/from-segments',
+ json={'transcript_segments': [{'text': 'Hello', 'start': 0.0, 'end': 1.0}]},
+ headers={'Authorization': 'Bearer test-token'},
+ )
+ assert response.status_code == 200
+ create_obj = mock_process.call_args[0][2]
+ assert create_obj.source.value == 'desktop'
+
+ def test_empty_segments_list_returns_422(self, client):
+ with patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'):
+ response = client.post(
+ '/v1/conversations/from-segments',
+ json={'transcript_segments': []},
+ headers={'Authorization': 'Bearer test-token'},
+ )
+ assert response.status_code == 422
+
+ def test_exactly_500_segments_succeeds(self, client):
+ with (
+ patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'),
+ patch('routers.conversations.process_conversation') as mock_process,
+ patch('routers.conversations.get_google_maps_location'),
+ ):
+ mock_conv = MagicMock()
+ mock_conv.id = 'conv-500'
+ mock_conv.status.value = 'completed'
+ mock_conv.discarded = False
+ mock_process.return_value = mock_conv
+
+ segments = [{'text': f'seg {i}', 'start': float(i), 'end': float(i + 1)} for i in range(500)]
+ response = client.post(
+ '/v1/conversations/from-segments',
+ json={'transcript_segments': segments},
+ headers={'Authorization': 'Bearer test-token'},
+ )
+ assert response.status_code == 200
+
+ def test_over_500_segments_returns_422(self, client):
+ with patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'):
+ segments = [{'text': f'seg {i}', 'start': float(i), 'end': float(i + 1)} for i in range(501)]
+ response = client.post(
+ '/v1/conversations/from-segments',
+ json={'transcript_segments': segments},
+ headers={'Authorization': 'Bearer test-token'},
+ )
+ assert response.status_code == 422
+ assert '500' in response.json()['detail']
+
+ def test_finished_at_before_started_at_returns_422(self, client):
+ with patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'):
+ now = datetime.now(timezone.utc)
+ earlier = now - timedelta(hours=1)
+ response = client.post(
+ '/v1/conversations/from-segments',
+ json={
+ 'transcript_segments': [{'text': 'Hello', 'start': 0.0, 'end': 1.0}],
+ 'started_at': now.isoformat(),
+ 'finished_at': earlier.isoformat(),
+ },
+ headers={'Authorization': 'Bearer test-token'},
+ )
+ assert response.status_code == 422
+ assert 'finished_at' in response.json()['detail']
+
+ def test_geolocation_enrichment_failure_continues(self, client):
+ with (
+ patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'),
+ patch('routers.conversations.process_conversation') as mock_process,
+ patch('routers.conversations.get_google_maps_location', side_effect=Exception('API error')),
+ ):
+ mock_conv = MagicMock()
+ mock_conv.id = 'conv-geo'
+ mock_conv.status.value = 'completed'
+ mock_conv.discarded = False
+ mock_process.return_value = mock_conv
+
+ response = client.post(
+ '/v1/conversations/from-segments',
+ json={
+ 'transcript_segments': [{'text': 'Hello', 'start': 0.0, 'end': 1.0}],
+ 'geolocation': {'latitude': 37.7749, 'longitude': -122.4194},
+ },
+ headers={'Authorization': 'Bearer test-token'},
+ )
+ assert response.status_code == 200
+
+
+# Keep patch import at module scope for the with-statement usage
+from unittest.mock import patch
diff --git a/backend/tests/unit/test_screen_activity_sync.py b/backend/tests/unit/test_screen_activity_sync.py
new file mode 100644
index 0000000000..bb15a6ec40
--- /dev/null
+++ b/backend/tests/unit/test_screen_activity_sync.py
@@ -0,0 +1,94 @@
+import threading
+from unittest.mock import patch, MagicMock
+
+import pytest
+from fastapi.testclient import TestClient
+
+
+@pytest.fixture
+def client():
+ with patch('database.screen_activity.db'), \
+ patch('database.vector_db.Pinecone'), \
+ patch('database.vector_db.pc'), \
+ patch('database.vector_db.index'), \
+ patch('utils.llm.clients.embeddings'):
+ from main import app
+ with TestClient(app) as c:
+ yield c
+
+
+AUTH = {"Authorization": "Bearer 123testuser"}
+
+
+class TestScreenActivitySyncValidation:
+ def test_empty_rows_returns_zero(self, client):
+ resp = client.post("/v1/screen-activity/sync", json={"rows": []}, headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json() == {"synced": 0, "last_id": 0}
+
+ def test_exceeds_100_rows_returns_400(self, client):
+ rows = [{"id": i, "timestamp": "2026-01-01T00:00:00Z", "appName": "A", "windowTitle": "W", "ocrText": "x"} for i in range(101)]
+ resp = client.post("/v1/screen-activity/sync", json={"rows": rows}, headers=AUTH)
+ assert resp.status_code == 400
+ assert "100" in resp.json()["detail"]
+
+ def test_exactly_100_rows_accepted(self, client):
+ rows = [{"id": i, "timestamp": "2026-01-01T00:00:00Z"} for i in range(100)]
+ with patch('routers.screen_activity.screen_activity_db.upsert_screen_activity', return_value=100):
+ resp = client.post("/v1/screen-activity/sync", json={"rows": rows}, headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["synced"] == 100
+
+ def test_no_auth_returns_401(self, client):
+ resp = client.post("/v1/screen-activity/sync", json={"rows": []})
+ assert resp.status_code == 401
+
+ def test_last_id_is_max_from_batch(self, client):
+ rows = [
+ {"id": 5, "timestamp": "2026-01-01T00:00:00Z"},
+ {"id": 99, "timestamp": "2026-01-01T00:01:00Z"},
+ {"id": 3, "timestamp": "2026-01-01T00:02:00Z"},
+ ]
+ with patch('routers.screen_activity.screen_activity_db.upsert_screen_activity', return_value=3):
+ resp = client.post("/v1/screen-activity/sync", json={"rows": rows}, headers=AUTH)
+ assert resp.status_code == 200
+ assert resp.json()["last_id"] == 99
+
+ def test_firestore_error_returns_500(self, client):
+ rows = [{"id": 1, "timestamp": "2026-01-01T00:00:00Z"}]
+ with patch('routers.screen_activity.screen_activity_db.upsert_screen_activity', side_effect=Exception("Firestore down")):
+ resp = client.post("/v1/screen-activity/sync", json={"rows": rows}, headers=AUTH)
+ assert resp.status_code == 500
+
+ def test_rows_with_embeddings_spawn_thread(self, client):
+ rows = [{"id": 1, "timestamp": "2026-01-01T00:00:00Z", "embedding": [0.1] * 3072}]
+ with patch('routers.screen_activity.screen_activity_db.upsert_screen_activity', return_value=1), \
+ patch('routers.screen_activity.threading.Thread') as mock_thread:
+ mock_thread.return_value = MagicMock()
+ resp = client.post("/v1/screen-activity/sync", json={"rows": rows}, headers=AUTH)
+ assert resp.status_code == 200
+ mock_thread.assert_called_once()
+ mock_thread.return_value.start.assert_called_once()
+
+ def test_rows_without_embeddings_no_thread(self, client):
+ rows = [{"id": 1, "timestamp": "2026-01-01T00:00:00Z"}]
+ with patch('routers.screen_activity.screen_activity_db.upsert_screen_activity', return_value=1), \
+ patch('routers.screen_activity.threading.Thread') as mock_thread:
+ resp = client.post("/v1/screen-activity/sync", json={"rows": rows}, headers=AUTH)
+ assert resp.status_code == 200
+ mock_thread.assert_not_called()
+
+
+class TestUpsertVectorsBackground:
+ def test_vector_upsert_exception_is_logged(self):
+ from routers.screen_activity import _upsert_vectors_background
+ with patch('routers.screen_activity.vector_db.upsert_screen_activity_vectors', side_effect=Exception("Pinecone down")), \
+ patch('routers.screen_activity.logger') as mock_logger:
+ _upsert_vectors_background("uid123", [{"id": 1, "embedding": [0.1]}])
+ mock_logger.exception.assert_called_once()
+
+ def test_vector_upsert_success(self):
+ from routers.screen_activity import _upsert_vectors_background
+ with patch('routers.screen_activity.vector_db.upsert_screen_activity_vectors') as mock_upsert:
+ _upsert_vectors_background("uid123", [{"id": 1, "embedding": [0.1]}])
+ mock_upsert.assert_called_once_with("uid123", [{"id": 1, "embedding": [0.1]}])
diff --git a/backend/tests/unit/test_staged_tasks.py b/backend/tests/unit/test_staged_tasks.py
new file mode 100644
index 0000000000..60c414d8f8
--- /dev/null
+++ b/backend/tests/unit/test_staged_tasks.py
@@ -0,0 +1,817 @@
+"""Tests for desktop staged tasks + daily scores endpoints."""
+
+import sys
+from unittest.mock import patch, MagicMock
+from datetime import datetime, timezone
+
+import pytest
+
+for mod_name in [
+ 'firebase_admin',
+ 'firebase_admin.auth',
+ 'firebase_admin.firestore',
+ 'firebase_admin.messaging',
+ 'google.cloud',
+ 'google.cloud.exceptions',
+ 'google.cloud.firestore',
+ 'google.cloud.firestore_v1',
+ 'google.cloud.firestore_v1.base_query',
+ 'google.cloud.firestore_v1.query',
+ 'google.cloud.storage',
+ 'google.cloud.storage.blob',
+ 'google.cloud.storage.bucket',
+ 'google.auth',
+ 'google.auth.transport',
+ 'google.auth.transport.requests',
+ 'google.oauth2',
+ 'google.oauth2.service_account',
+ 'pinecone',
+ 'typesense',
+]:
+ sys.modules.setdefault(mod_name, MagicMock())
+
+from routers.staged_tasks import (
+ CreateStagedTaskRequest,
+ StagedTaskResponse,
+ StagedTasksListResponse,
+ BatchUpdateScoresRequest,
+ ScoreUpdate,
+ PromoteResponse,
+ DailyScoreResponse,
+ ScoresResponse,
+ ScoreData,
+ StatusResponse,
+ router,
+)
+
+# --- Model Tests ---
+
+
+class TestStagedTaskModels:
+ def test_create_request_required_fields(self):
+ req = CreateStagedTaskRequest(description='Buy groceries')
+ assert req.description == 'Buy groceries'
+ assert req.source is None
+ assert req.relevance_score is None
+
+ def test_create_request_all_fields(self):
+ req = CreateStagedTaskRequest(
+ description='Ship feature',
+ source='screenshot',
+ priority='high',
+ metadata='{"app": "Safari"}',
+ category='work',
+ relevance_score=3,
+ )
+ assert req.priority == 'high'
+ assert req.relevance_score == 3
+
+ def test_create_request_blank_description_rejected(self):
+ with pytest.raises(Exception):
+ CreateStagedTaskRequest(description=' ')
+
+ def test_batch_scores_request(self):
+ req = BatchUpdateScoresRequest(scores=[ScoreUpdate(id='t1', relevance_score=5)])
+ assert len(req.scores) == 1
+
+ def test_batch_scores_empty_rejected(self):
+ with pytest.raises(Exception):
+ BatchUpdateScoresRequest(scores=[])
+
+ def test_promote_response(self):
+ resp = PromoteResponse(promoted=True, promoted_task=StagedTaskResponse(id='t1', description='Task'))
+ assert resp.promoted is True
+ assert resp.promoted_task.id == 't1'
+
+ def test_daily_score_response(self):
+ resp = DailyScoreResponse(score=75.0, completed_tasks=3, total_tasks=4, date='2026-03-05')
+ assert resp.score == 75.0
+
+ def test_scores_response(self):
+ data = ScoreData(score=50.0, completed_tasks=1, total_tasks=2)
+ resp = ScoresResponse(daily=data, weekly=data, overall=data, default_tab='daily', date='2026-03-05')
+ assert resp.default_tab == 'daily'
+
+
+# --- Endpoint Tests ---
+
+
+class TestStagedTaskEndpoints:
+ def _make_app(self):
+ from fastapi import FastAPI
+
+ app = FastAPI()
+ app.include_router(router)
+ return app
+
+ @pytest.fixture
+ def client(self):
+ from fastapi.testclient import TestClient
+
+ return TestClient(self._make_app())
+
+ def test_create_staged_task(self, client):
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.create_staged_task') as mock_create,
+ ):
+ mock_create.return_value = {
+ 'id': 'st-1',
+ 'description': 'Buy milk',
+ 'completed': False,
+ 'created_at': datetime.now(timezone.utc),
+ 'updated_at': datetime.now(timezone.utc),
+ }
+ response = client.post(
+ '/v1/staged-tasks',
+ json={'description': 'Buy milk', 'source': 'screenshot', 'relevance_score': 5},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 200
+ assert response.json()['id'] == 'st-1'
+ assert response.json()['description'] == 'Buy milk'
+
+ def test_create_staged_task_blank_desc_422(self, client):
+ with patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'):
+ response = client.post(
+ '/v1/staged-tasks',
+ json={'description': ' '},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 422
+
+ def test_list_staged_tasks(self, client):
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks') as mock_get,
+ ):
+ mock_get.return_value = (
+ [
+ {'id': 'st-1', 'description': 'Task 1', 'completed': False, 'relevance_score': 1},
+ {'id': 'st-2', 'description': 'Task 2', 'completed': False, 'relevance_score': 3},
+ ],
+ False,
+ )
+ response = client.get('/v1/staged-tasks', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ data = response.json()
+ assert len(data['items']) == 2
+ assert data['has_more'] is False
+
+ def test_list_staged_tasks_with_pagination(self, client):
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks') as mock_get,
+ ):
+ mock_get.return_value = ([], True)
+ response = client.get(
+ '/v1/staged-tasks?limit=10&offset=20',
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 200
+ assert mock_get.called
+ assert mock_get.call_args[1] == {'limit': 10, 'offset': 20}
+
+ def test_list_staged_tasks_limit_over_max_422(self, client):
+ with patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'):
+ response = client.get(
+ '/v1/staged-tasks?limit=501',
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 422
+
+ def test_delete_staged_task(self, client):
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.delete_staged_task') as mock_del,
+ ):
+ response = client.delete('/v1/staged-tasks/st-1', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ assert response.json()['status'] == 'ok'
+ assert mock_del.called
+
+ def test_delete_staged_task_idempotent(self, client):
+ """Delete returns 200 even for non-existent task (matches Rust behavior)."""
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.delete_staged_task'),
+ ):
+ response = client.delete('/v1/staged-tasks/missing', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ assert response.json()['status'] == 'ok'
+
+ def test_batch_update_scores(self, client):
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.batch_update_scores') as mock_batch,
+ ):
+ response = client.patch(
+ '/v1/staged-tasks/batch-scores',
+ json={'scores': [{'id': 'st-1', 'relevance_score': 10}, {'id': 'st-2', 'relevance_score': 3}]},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 200
+ assert mock_batch.called
+ assert len(mock_batch.call_args[0][1]) == 2
+
+ def test_batch_update_scores_empty_422(self, client):
+ with patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'):
+ response = client.patch(
+ '/v1/staged-tasks/batch-scores',
+ json={'scores': []},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 422
+
+ def test_promote_success(self, client):
+ now = datetime.now(timezone.utc)
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_active_ai_action_items', return_value=[]),
+ patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks') as mock_staged,
+ patch('routers.staged_tasks.staged_tasks_db.promote_staged_task') as mock_promote,
+ patch('routers.staged_tasks.staged_tasks_db.delete_staged_task'),
+ ):
+ mock_staged.return_value = (
+ [
+ {'id': 'st-1', 'description': 'Top task', 'completed': False, 'relevance_score': 1},
+ ],
+ False,
+ )
+ mock_promote.return_value = {
+ 'id': 'ai-1',
+ 'description': 'Top task',
+ 'completed': False,
+ 'created_at': now,
+ 'updated_at': now,
+ 'from_staged': True,
+ }
+ response = client.post('/v1/staged-tasks/promote', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ data = response.json()
+ assert data['promoted'] is True
+ assert data['promoted_task']['id'] == 'ai-1'
+
+ def test_promote_max_active_returns_false(self, client):
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_active_ai_action_items') as mock_active,
+ ):
+ mock_active.return_value = [{'id': f'ai-{i}', 'description': f'Task {i}'} for i in range(5)]
+ response = client.post('/v1/staged-tasks/promote', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ data = response.json()
+ assert data['promoted'] is False
+ assert 'max 5' in data['reason']
+
+ def test_promote_no_staged_tasks(self, client):
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_active_ai_action_items', return_value=[]),
+ patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks', return_value=([], False)),
+ ):
+ response = client.post('/v1/staged-tasks/promote', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ assert response.json()['promoted'] is False
+ assert 'No staged tasks' in response.json()['reason']
+
+ def test_promote_skips_duplicates(self, client):
+ now = datetime.now(timezone.utc)
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_active_ai_action_items') as mock_active,
+ patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks') as mock_staged,
+ patch('routers.staged_tasks.staged_tasks_db.promote_staged_task') as mock_promote,
+ patch('routers.staged_tasks.staged_tasks_db.delete_staged_task'),
+ patch('routers.staged_tasks.staged_tasks_db.delete_staged_tasks_batch') as mock_batch_del,
+ ):
+ mock_active.return_value = [{'id': 'ai-1', 'description': 'Buy groceries'}]
+ mock_staged.return_value = (
+ [
+ {'id': 'st-1', 'description': 'buy groceries', 'completed': False, 'relevance_score': 1},
+ {'id': 'st-2', 'description': 'Ship feature', 'completed': False, 'relevance_score': 2},
+ ],
+ False,
+ )
+ mock_promote.return_value = {
+ 'id': 'ai-2',
+ 'description': 'Ship feature',
+ 'completed': False,
+ 'created_at': now,
+ 'updated_at': now,
+ }
+ response = client.post('/v1/staged-tasks/promote', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ assert response.json()['promoted'] is True
+ assert response.json()['promoted_task']['description'] == 'Ship feature'
+ # st-1 should be batch-deleted as duplicate
+ assert mock_batch_del.called
+ assert mock_batch_del.call_args[0][1] == ['st-1']
+
+ def test_promote_all_duplicates(self, client):
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_active_ai_action_items') as mock_active,
+ patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks') as mock_staged,
+ patch('routers.staged_tasks.staged_tasks_db.delete_staged_tasks_batch'),
+ ):
+ mock_active.return_value = [{'id': 'ai-1', 'description': 'Task A'}]
+ mock_staged.return_value = (
+ [
+ {'id': 'st-1', 'description': 'task a', 'completed': False, 'relevance_score': 1},
+ ],
+ False,
+ )
+ response = client.post('/v1/staged-tasks/promote', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ assert response.json()['promoted'] is False
+ assert 'duplicates' in response.json()['reason']
+
+
+class TestDailyScoreEndpoints:
+ def _make_app(self):
+ from fastapi import FastAPI
+
+ app = FastAPI()
+ app.include_router(router)
+ return app
+
+ @pytest.fixture
+ def client(self):
+ from fastapi.testclient import TestClient
+
+ return TestClient(self._make_app())
+
+ def test_daily_score_today(self, client):
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_daily_score', return_value=(3, 4)),
+ ):
+ response = client.get('/v1/daily-score', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ data = response.json()
+ assert data['score'] == 75.0
+ assert data['completed_tasks'] == 3
+ assert data['total_tasks'] == 4
+
+ def test_daily_score_specific_date(self, client):
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_daily_score', return_value=(0, 0)),
+ ):
+ response = client.get('/v1/daily-score?date=2026-01-15', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ assert response.json()['date'] == '2026-01-15'
+ assert response.json()['score'] == 0.0
+
+ def test_daily_score_invalid_date_400(self, client):
+ with patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'):
+ response = client.get('/v1/daily-score?date=not-a-date', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 400
+
+ def test_scores_all_three(self, client):
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_daily_score', return_value=(2, 4)),
+ patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_weekly_score', return_value=(10, 20)),
+ patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_overall_score', return_value=(50, 100)),
+ ):
+ response = client.get('/v1/scores', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ data = response.json()
+ assert data['daily']['score'] == 50.0
+ assert data['weekly']['score'] == 50.0
+ assert data['overall']['score'] == 50.0
+
+ def test_scores_default_tab_daily_when_highest(self, client):
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_daily_score', return_value=(4, 4)),
+ patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_weekly_score', return_value=(5, 10)),
+ patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_overall_score', return_value=(10, 30)),
+ ):
+ response = client.get('/v1/scores', headers={'Authorization': 'Bearer test'})
+ assert response.json()['default_tab'] == 'daily'
+
+ def test_scores_default_tab_weekly_when_no_daily(self, client):
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_daily_score', return_value=(0, 0)),
+ patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_weekly_score', return_value=(5, 10)),
+ patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_overall_score', return_value=(10, 30)),
+ ):
+ response = client.get('/v1/scores', headers={'Authorization': 'Bearer test'})
+ assert response.json()['default_tab'] == 'weekly'
+
+ def test_scores_invalid_date_400(self, client):
+ with patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'):
+ response = client.get('/v1/scores?date=bad', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 400
+
+ def test_scores_no_tasks_zero(self, client):
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_daily_score', return_value=(0, 0)),
+ patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_weekly_score', return_value=(0, 0)),
+ patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_overall_score', return_value=(0, 0)),
+ ):
+ response = client.get('/v1/scores', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ data = response.json()
+ assert data['daily']['score'] == 0.0
+ assert data['weekly']['score'] == 0.0
+ assert data['overall']['score'] == 0.0
+
+ def test_create_dedup_returns_existing(self, client):
+ """Create returns existing task if description matches (case-insensitive)."""
+ now = datetime.now(timezone.utc)
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.create_staged_task') as mock_create,
+ ):
+ # Simulate dedup returning existing task
+ mock_create.return_value = {
+ 'id': 'existing-1',
+ 'description': 'Buy milk',
+ 'completed': False,
+ 'created_at': now,
+ 'updated_at': now,
+ }
+ response = client.post(
+ '/v1/staged-tasks',
+ json={'description': 'buy milk'},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 200
+ assert response.json()['id'] == 'existing-1'
+
+ def test_weekly_score_uses_created_at(self, client):
+ """Weekly score filters by created_at range, not due_at."""
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_daily_score', return_value=(1, 2)),
+ patch(
+ 'routers.staged_tasks.staged_tasks_db.get_action_items_for_weekly_score', return_value=(7, 14)
+ ) as mock_weekly,
+ patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_overall_score', return_value=(20, 40)),
+ ):
+ response = client.get('/v1/scores?date=2026-03-05', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ assert mock_weekly.called
+ # Weekly should use a 7-day window ending today
+ week_start_arg = mock_weekly.call_args[0][1]
+ assert '2026-02-26' in week_start_arg
+
+ # --- Promote with [screen] prefix/suffix normalization ---
+
+ def test_promote_skips_screen_prefix_duplicate(self, client):
+ """Promote dedup strips [screen] prefix when comparing descriptions."""
+ now = datetime.now(timezone.utc)
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_active_ai_action_items') as mock_active,
+ patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks') as mock_staged,
+ patch('routers.staged_tasks.staged_tasks_db.promote_staged_task') as mock_promote,
+ patch('routers.staged_tasks.staged_tasks_db.delete_staged_task'),
+ patch('routers.staged_tasks.staged_tasks_db.delete_staged_tasks_batch') as mock_batch_del,
+ ):
+ # Active item without [screen] prefix
+ mock_active.return_value = [{'id': 'ai-1', 'description': 'Buy milk'}]
+ # Staged item with [screen] prefix — should be detected as duplicate
+ mock_staged.return_value = (
+ [
+ {'id': 'st-1', 'description': '[screen] Buy milk', 'completed': False, 'relevance_score': 1},
+ {'id': 'st-2', 'description': 'New unique task', 'completed': False, 'relevance_score': 2},
+ ],
+ False,
+ )
+ mock_promote.return_value = {
+ 'id': 'ai-2',
+ 'description': 'New unique task',
+ 'completed': False,
+ 'created_at': now,
+ 'updated_at': now,
+ }
+ response = client.post('/v1/staged-tasks/promote', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ assert response.json()['promoted'] is True
+ assert response.json()['promoted_task']['description'] == 'New unique task'
+ # st-1 with [screen] prefix should be deleted as duplicate
+ assert mock_batch_del.called
+ assert 'st-1' in mock_batch_del.call_args[0][1]
+
+ def test_promote_skips_screen_suffix_duplicate(self, client):
+ """Promote dedup strips [screen] suffix when comparing descriptions."""
+ now = datetime.now(timezone.utc)
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_active_ai_action_items') as mock_active,
+ patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks') as mock_staged,
+ patch('routers.staged_tasks.staged_tasks_db.promote_staged_task') as mock_promote,
+ patch('routers.staged_tasks.staged_tasks_db.delete_staged_task'),
+ patch('routers.staged_tasks.staged_tasks_db.delete_staged_tasks_batch') as mock_batch_del,
+ ):
+ # Active item with [screen] suffix
+ mock_active.return_value = [{'id': 'ai-1', 'description': 'Buy milk [screen]'}]
+ # Staged item without [screen] — should be detected as duplicate
+ mock_staged.return_value = (
+ [
+ {'id': 'st-1', 'description': 'buy milk', 'completed': False, 'relevance_score': 1},
+ {'id': 'st-2', 'description': 'Different task', 'completed': False, 'relevance_score': 2},
+ ],
+ False,
+ )
+ mock_promote.return_value = {
+ 'id': 'ai-2',
+ 'description': 'Different task',
+ 'completed': False,
+ 'created_at': now,
+ 'updated_at': now,
+ }
+ response = client.post('/v1/staged-tasks/promote', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ assert response.json()['promoted'] is True
+ # st-1 should be deleted as duplicate
+ assert mock_batch_del.called
+ assert 'st-1' in mock_batch_del.call_args[0][1]
+
+ # --- Promote boundary: 4 active should still promote ---
+
+ def test_promote_with_4_active_succeeds(self, client):
+ """Promote succeeds when exactly 4 active AI tasks (under max 5)."""
+ now = datetime.now(timezone.utc)
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_active_ai_action_items') as mock_active,
+ patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks') as mock_staged,
+ patch('routers.staged_tasks.staged_tasks_db.promote_staged_task') as mock_promote,
+ patch('routers.staged_tasks.staged_tasks_db.delete_staged_task'),
+ ):
+ mock_active.return_value = [{'id': f'ai-{i}', 'description': f'Task {i}'} for i in range(4)]
+ mock_staged.return_value = (
+ [{'id': 'st-1', 'description': 'New task', 'completed': False, 'relevance_score': 1}],
+ False,
+ )
+ mock_promote.return_value = {
+ 'id': 'ai-5',
+ 'description': 'New task',
+ 'completed': False,
+ 'created_at': now,
+ 'updated_at': now,
+ }
+ response = client.post('/v1/staged-tasks/promote', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+ assert response.json()['promoted'] is True
+
+ # --- Cap boundary tests ---
+
+ def test_create_description_max_length_accepted(self, client):
+ """Description at exactly 2000 chars is accepted."""
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.create_staged_task') as mock_create,
+ ):
+ desc = 'A' * 2000
+ mock_create.return_value = {
+ 'id': 'st-1',
+ 'description': desc,
+ 'completed': False,
+ }
+ response = client.post(
+ '/v1/staged-tasks',
+ json={'description': desc},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 200
+
+ def test_create_description_over_max_rejected(self, client):
+ """Description at 2001 chars is rejected."""
+ with patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'):
+ response = client.post(
+ '/v1/staged-tasks',
+ json={'description': 'A' * 2001},
+ headers={'Authorization': 'Bearer test'},
+ )
+ assert response.status_code == 422
+
+ def test_list_limit_1_accepted(self, client):
+ """List with limit=1 is accepted."""
+ with (
+ patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'),
+ patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks', return_value=([], False)),
+ ):
+ response = client.get('/v1/staged-tasks?limit=1', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 200
+
+ def test_list_limit_0_rejected(self, client):
+ """List with limit=0 is rejected (min 1)."""
+ with patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'):
+ response = client.get('/v1/staged-tasks?limit=0', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 422
+
+ def test_list_offset_negative_rejected(self, client):
+ """List with offset=-1 is rejected (min 0)."""
+ with patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'):
+ response = client.get('/v1/staged-tasks?offset=-1', headers={'Authorization': 'Bearer test'})
+ assert response.status_code == 422
+
+
+# --- DB Unit Tests ---
+
+
+class _MockDoc:
+ """Mock Firestore document snapshot."""
+
+ def __init__(self, doc_id, data, exists=True):
+ self.id = doc_id
+ self._data = data
+ self.exists = exists
+
+ def to_dict(self):
+ return self._data.copy()
+
+
+class TestStagedTasksDB:
+ """Unit tests for database/staged_tasks.py functions with mocked Firestore."""
+
+ def test_create_dedup_case_insensitive(self):
+ """create_staged_task returns existing task if description matches case-insensitively."""
+ import database.staged_tasks as db_mod
+
+ existing_doc = _MockDoc('existing-1', {'description': 'Buy Milk', 'completed': False})
+ mock_ref = MagicMock()
+ mock_ref.stream.return_value = [existing_doc]
+
+ with patch.object(db_mod, 'db') as mock_db:
+ mock_db.collection.return_value.document.return_value.collection.return_value = mock_ref
+ result = db_mod.create_staged_task('uid-1', {'description': 'buy milk'})
+ assert result['id'] == 'existing-1'
+ assert result['description'] == 'Buy Milk'
+ # Should NOT have called add (dedup returned existing)
+ mock_ref.add.assert_not_called()
+
+ def test_create_dedup_whitespace_trim(self):
+ """create_staged_task trims whitespace before dedup comparison."""
+ import database.staged_tasks as db_mod
+
+ existing_doc = _MockDoc('existing-1', {'description': 'Buy Milk', 'completed': False})
+ mock_ref = MagicMock()
+ mock_ref.stream.return_value = [existing_doc]
+
+ with patch.object(db_mod, 'db') as mock_db:
+ mock_db.collection.return_value.document.return_value.collection.return_value = mock_ref
+ result = db_mod.create_staged_task('uid-1', {'description': ' buy milk '})
+ assert result['id'] == 'existing-1'
+ mock_ref.add.assert_not_called()
+
+ def test_create_dedup_skips_deleted(self):
+ """create_staged_task ignores soft-deleted tasks during dedup scan."""
+ import database.staged_tasks as db_mod
+
+ deleted_doc = _MockDoc('del-1', {'description': 'Buy Milk', 'completed': False, 'deleted': True})
+ mock_ref = MagicMock()
+ mock_ref.stream.return_value = [deleted_doc]
+ mock_ref.add.return_value = (None, MagicMock(id='new-1'))
+
+ with patch.object(db_mod, 'db') as mock_db:
+ mock_db.collection.return_value.document.return_value.collection.return_value = mock_ref
+ result = db_mod.create_staged_task('uid-1', {'description': 'Buy Milk'})
+ # Should create new since deleted match doesn't count
+ assert result['id'] == 'new-1'
+ mock_ref.add.assert_called_once()
+
+ def test_create_empty_description_raises(self):
+ """create_staged_task raises ValueError for empty/whitespace description."""
+ import database.staged_tasks as db_mod
+
+ with pytest.raises(ValueError, match='description must not be empty'):
+ db_mod.create_staged_task('uid-1', {'description': ' '})
+
+ def test_get_staged_tasks_filters_completed_and_deleted(self):
+ """get_staged_tasks uses completed=false filter and skips deleted client-side."""
+ import database.staged_tasks as db_mod
+
+ docs = [
+ _MockDoc('t-1', {'description': 'Active', 'completed': False, 'relevance_score': 1}),
+ _MockDoc('t-2', {'description': 'Deleted', 'completed': False, 'deleted': True, 'relevance_score': 2}),
+ _MockDoc('t-3', {'description': 'Also active', 'completed': False, 'relevance_score': 3}),
+ ]
+
+ mock_query = MagicMock()
+ mock_query.where.return_value = mock_query
+ mock_query.order_by.return_value = mock_query
+ mock_query.limit.return_value = mock_query
+ mock_query.stream.return_value = docs
+
+ with patch.object(db_mod, 'db') as mock_db:
+ mock_db.collection.return_value.document.return_value.collection.return_value = mock_query
+ items, has_more = db_mod.get_staged_tasks('uid-1', limit=10)
+ # Should have 2 items (t-2 is deleted, filtered out)
+ assert len(items) == 2
+ assert items[0]['id'] == 't-1'
+ assert items[1]['id'] == 't-3'
+ assert has_more is False
+
+ def test_get_staged_tasks_queries_completed_false(self):
+ """get_staged_tasks passes completed=false FieldFilter to Firestore."""
+ import database.staged_tasks as db_mod
+
+ mock_query = MagicMock()
+ mock_query.where.return_value = mock_query
+ mock_query.order_by.return_value = mock_query
+ mock_query.limit.return_value = mock_query
+ mock_query.stream.return_value = []
+
+ with (
+ patch.object(db_mod, 'db') as mock_db,
+ patch.object(db_mod, 'firestore') as mock_fs,
+ ):
+ mock_db.collection.return_value.document.return_value.collection.return_value = mock_query
+ mock_fs.FieldFilter.return_value = 'completed_filter'
+ mock_fs.Query.ASCENDING = 'ASC'
+ mock_fs.Query.DESCENDING = 'DESC'
+
+ db_mod.get_staged_tasks('uid-1')
+
+ # Verify FieldFilter was called with completed=false
+ mock_fs.FieldFilter.assert_called_once_with('completed', '==', False)
+ mock_query.where.assert_called_once_with(filter='completed_filter')
+
+ def test_daily_score_uses_due_at(self):
+ """get_action_items_for_daily_score filters by due_at range."""
+ import database.staged_tasks as db_mod
+
+ mock_query = MagicMock()
+ mock_query.where.return_value = mock_query
+ mock_query.stream.return_value = []
+
+ with (
+ patch.object(db_mod, 'db') as mock_db,
+ patch.object(db_mod, 'firestore') as mock_fs,
+ ):
+ mock_db.collection.return_value.document.return_value.collection.return_value = mock_query
+ mock_fs.FieldFilter.side_effect = lambda field, op, val: f'{field}_{op}_{val}'
+
+ db_mod.get_action_items_for_daily_score('uid-1', '2026-03-05T00:00:00Z', '2026-03-05T23:59:59.999Z')
+
+ # Should have called FieldFilter with 'due_at' (not 'created_at')
+ calls = mock_fs.FieldFilter.call_args_list
+ fields_used = [c[0][0] for c in calls]
+ assert 'due_at' in fields_used
+ assert 'created_at' not in fields_used
+
+ def test_weekly_score_uses_created_at(self):
+ """get_action_items_for_weekly_score filters by created_at range (not due_at)."""
+ import database.staged_tasks as db_mod
+
+ mock_query = MagicMock()
+ mock_query.where.return_value = mock_query
+ mock_query.stream.return_value = []
+
+ with (
+ patch.object(db_mod, 'db') as mock_db,
+ patch.object(db_mod, 'firestore') as mock_fs,
+ ):
+ mock_db.collection.return_value.document.return_value.collection.return_value = mock_query
+ mock_fs.FieldFilter.side_effect = lambda field, op, val: f'{field}_{op}_{val}'
+
+ db_mod.get_action_items_for_weekly_score('uid-1', '2026-02-26T00:00:00Z', '2026-03-05T23:59:59.999Z')
+
+ # Should have called FieldFilter with 'created_at' (not 'due_at')
+ calls = mock_fs.FieldFilter.call_args_list
+ fields_used = [c[0][0] for c in calls]
+ assert 'created_at' in fields_used
+ assert 'due_at' not in fields_used
+
+ def test_overall_score_counts_all_non_deleted(self):
+ """get_action_items_for_overall_score scans all docs, skips deleted."""
+ import database.staged_tasks as db_mod
+
+ docs = [
+ _MockDoc('a-1', {'completed': True}),
+ _MockDoc('a-2', {'completed': False}),
+ _MockDoc('a-3', {'completed': True, 'deleted': True}), # Should be skipped
+ _MockDoc('a-4', {'completed': False}),
+ ]
+
+ mock_ref = MagicMock()
+ mock_ref.stream.return_value = docs
+
+ with patch.object(db_mod, 'db') as mock_db:
+ mock_db.collection.return_value.document.return_value.collection.return_value = mock_ref
+ completed, total = db_mod.get_action_items_for_overall_score('uid-1')
+ assert completed == 1 # Only a-1 (a-3 is deleted)
+ assert total == 3 # a-1, a-2, a-4 (a-3 is deleted)
+
+ def test_delete_is_idempotent(self):
+ """delete_staged_task calls Firestore delete without checking existence."""
+ import database.staged_tasks as db_mod
+
+ mock_doc_ref = MagicMock()
+ with patch.object(db_mod, 'db') as mock_db:
+ mock_db.collection.return_value.document.return_value.collection.return_value.document.return_value = (
+ mock_doc_ref
+ )
+ # Should not raise even if doc doesn't exist
+ db_mod.delete_staged_task('uid-1', 'nonexistent-id')
+ mock_doc_ref.delete.assert_called_once()
diff --git a/backend/utils/desktop/__init__.py b/backend/utils/desktop/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/backend/utils/desktop/advice.py b/backend/utils/desktop/advice.py
new file mode 100644
index 0000000000..c73a7ae251
--- /dev/null
+++ b/backend/utils/desktop/advice.py
@@ -0,0 +1,115 @@
+import logging
+from typing import Optional
+
+from langchain_core.messages import HumanMessage, SystemMessage
+from pydantic import BaseModel, Field
+
+from database.goals import get_user_goals
+from database.action_items import get_action_items
+from utils.llm.clients import llm_gemini_flash
+
+logger = logging.getLogger(__name__)
+
+ADVICE_SYSTEM_PROMPT = """\
+You are a proactive assistant that offers brief, actionable advice based on what the user \
+is currently doing on their screen. Your advice should be contextual and helpful.
+
+ADVICE RULES:
+- Only offer advice when you can provide genuinely useful, specific guidance
+- Advice must relate to what's visible on screen
+- Keep it short (1-2 sentences max)
+- Be actionable — tell the user something they can DO, not just observe
+- Consider the user's goals and tasks when forming advice
+- ~70% of screenshots need NO advice — return null when nothing useful to say
+
+TONE:
+- Direct and casual, not formal
+- Helpful, not preachy
+- Specific to what you see, not generic productivity tips
+
+CATEGORIES:
+- productivity: efficiency tips, workflow improvements
+- mistake_prevention: catching potential errors or oversights
+- learning: suggesting resources or approaches
+- health: break reminders, posture, eye strain (only if clearly needed)
+- goal_alignment: connecting current activity to stated goals"""
+
+
+class AdviceResult(BaseModel):
+ has_advice: bool = Field(description="Whether advice is warranted")
+ content: Optional[str] = Field(default=None, description="The advice (1-2 sentences, null if none)")
+ category: Optional[str] = Field(
+ default=None, description="productivity|mistake_prevention|learning|health|goal_alignment"
+ )
+ confidence: float = Field(ge=0.0, le=1.0, description="Confidence this advice is useful")
+
+
+def _build_advice_context(uid: str) -> str:
+ """Build user context for advice generation."""
+ parts = []
+
+ try:
+ goals = get_user_goals(uid, limit=5)
+ if goals:
+ goal_lines = [f"- {g.get('title', g.get('description', ''))}" for g in goals]
+ parts.append("User's goals:\n" + "\n".join(goal_lines))
+ except Exception as e:
+ logger.warning(f"Failed to fetch goals for advice: {e}")
+
+ try:
+ tasks = get_action_items(uid, completed=False, limit=10)
+ if tasks:
+ task_lines = [f"- {t.get('description', '')}" for t in tasks[:10]]
+ parts.append("Current tasks:\n" + "\n".join(task_lines))
+ except Exception as e:
+ logger.warning(f"Failed to fetch tasks for advice: {e}")
+
+ return "\n\n".join(parts) if parts else ""
+
+
+async def generate_advice(
+ uid: str,
+ image_b64: str,
+ app_name: str = "",
+ window_title: str = "",
+) -> dict:
+ """Generate contextual advice from a screenshot using vision LLM.
+
+ Returns:
+ Dict with has_advice, content, category, confidence (or nulls if no advice)
+ """
+ advice_context = _build_advice_context(uid)
+
+ prompt_parts = []
+ if advice_context:
+ prompt_parts.append(advice_context)
+ if app_name or window_title:
+ prompt_parts.append(f"Current app: {app_name}, Window: {window_title}")
+ prompt_parts.append("Based on this screenshot, do you have any specific, actionable advice?")
+
+ prompt_text = "\n\n".join(prompt_parts)
+
+ with_parser = llm_gemini_flash.with_structured_output(AdviceResult)
+ result = await with_parser.ainvoke(
+ [
+ SystemMessage(content=ADVICE_SYSTEM_PROMPT),
+ HumanMessage(
+ content=[
+ {"type": "text", "text": prompt_text},
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
+ ]
+ ),
+ ]
+ )
+
+ if not result.has_advice:
+ return {"has_advice": False, "advice": None}
+
+ return {
+ "has_advice": True,
+ "advice": {
+ "content": result.content,
+ "category": result.category,
+ "confidence": result.confidence,
+ },
+ }
diff --git a/backend/utils/desktop/focus.py b/backend/utils/desktop/focus.py
new file mode 100644
index 0000000000..6807eeded9
--- /dev/null
+++ b/backend/utils/desktop/focus.py
@@ -0,0 +1,149 @@
+import logging
+from typing import Optional
+
+from langchain_core.messages import HumanMessage, SystemMessage
+from pydantic import BaseModel, Field
+
+from database.goals import get_user_goals
+from database.action_items import get_action_items
+from database.memories import get_memories
+from utils.llm.clients import llm_gemini_flash
+
+logger = logging.getLogger(__name__)
+
+# Match the desktop FocusAssistant's ScreenAnalysis schema
+FOCUS_SYSTEM_PROMPT = """You are a focus coach. Analyze the PRIMARY/MAIN window in screenshots to determine \
+if the user is focused or distracted.
+
+IMPORTANT: Look at the MAIN APPLICATION WINDOW, not log text or terminal output. \
+If you see a code editor with logs that mention "YouTube" - that's just log text, \
+the user is CODING, not on YouTube. Text in logs/terminals mentioning a site does \
+NOT mean the user is on that site.
+
+CONTEXT-AWARE ANALYSIS:
+Each request may include the user's active goals, current tasks, recent memories, \
+and analysis history. Use this context when available, but DO NOT let it prevent you \
+from flagging obvious distractions.
+
+- GOALS & TASKS: If the user's screen activity clearly relates to their active \
+goals or current tasks, they are FOCUSED.
+- HISTORY: Use recent analysis history to notice patterns, acknowledge transitions, \
+and vary your responses.
+
+Set status to "distracted" if the PRIMARY window is:
+- YouTube, Twitch, Netflix, TikTok (actual video site visible, not just text mentioning it)
+- Social media feeds: Twitter/X, Instagram, Facebook, Reddit (casual browsing, not researching)
+- News sites, entertainment sites, games
+- Any content consumption with no clear work purpose
+
+Set status to "focused" if the PRIMARY window is:
+- Code editors, IDEs, terminals, command line
+- Documents, spreadsheets, slides, design tools
+- Email, work chat (Slack, Teams), research
+- Browsing that is clearly work-related (Stack Overflow, docs, PRs, Jira, etc.)
+
+When in doubt, lean toward "distracted" — it's better to nudge the user once too \
+often than to silently let them drift.
+
+Always provide a short coaching message (100 characters max for notification banner):
+- If distracted: Create a unique nudge to refocus. Vary your approach — be playful, \
+direct, or motivational.
+- If focused: Acknowledge their work with variety — don't just say "Nice focus!" \
+every time."""
+
+
+class FocusResult(BaseModel):
+ status: str = Field(description='Focus status: "focused" or "distracted"')
+ app_or_site: str = Field(description="Primary app or site in focus")
+ description: str = Field(description="Brief description of what the user is doing")
+ message: Optional[str] = Field(default=None, description="Short coaching message (max 100 chars)")
+
+
+def _build_context(uid: str) -> str:
+ """Build context from user's goals, tasks, and memories (server-side)."""
+ parts = []
+
+ # Goals (up to 10)
+ try:
+ goals = get_user_goals(uid, limit=10)
+ if goals:
+ goal_lines = [f"- {g.get('title', g.get('description', ''))}" for g in goals]
+ parts.append("Active Goals:\n" + "\n".join(goal_lines))
+ except Exception as e:
+ logger.warning(f"Failed to fetch goals for context: {e}")
+
+ # Tasks (up to 50, not completed)
+ try:
+ tasks = get_action_items(uid, completed=False, limit=50)
+ if tasks:
+ task_lines = [f"- {t.get('description', '')}" for t in tasks[:50]]
+ parts.append("Current Tasks:\n" + "\n".join(task_lines))
+ except Exception as e:
+ logger.warning(f"Failed to fetch tasks for context: {e}")
+
+ # Recent memories (up to 20, core category)
+ try:
+ memories = get_memories(uid, limit=20, categories=['core'])
+ if memories:
+ mem_lines = [f"- {m.get('structured', {}).get('title', m.get('content', ''))}" for m in memories[:20]]
+ parts.append("Recent Memories:\n" + "\n".join(mem_lines))
+ except Exception as e:
+ logger.warning(f"Failed to fetch memories for context: {e}")
+
+ return "\n\n".join(parts) if parts else ""
+
+
+async def analyze_focus(
+ uid: str,
+ image_b64: str,
+ app_name: str = "",
+ window_title: str = "",
+ history: str = "",
+) -> dict:
+ """Analyze a screenshot for focus status using vision LLM.
+
+ Args:
+ uid: User ID for fetching context
+ image_b64: Base64-encoded JPEG screenshot
+ app_name: Name of the foreground app
+ window_title: Window title
+ history: Formatted recent analysis history
+
+ Returns:
+ Dict with type, frame_id, status, app_or_site, description, message
+ """
+ # Build context from user data
+ context = _build_context(uid)
+
+ # Assemble prompt
+ prompt_parts = []
+ if context:
+ prompt_parts.append(context)
+ if history:
+ prompt_parts.append(f"Recent activity (oldest to newest):\n{history}")
+ if app_name or window_title:
+ prompt_parts.append(f"Current app: {app_name}, Window: {window_title}")
+ prompt_parts.append("Now analyze this screenshot:")
+
+ prompt_text = "\n\n".join(prompt_parts)
+
+ # Call vision LLM with structured output
+ with_parser = llm_gemini_flash.with_structured_output(FocusResult)
+ result = await with_parser.ainvoke(
+ [
+ SystemMessage(content=FOCUS_SYSTEM_PROMPT),
+ HumanMessage(
+ content=[
+ {"type": "text", "text": prompt_text},
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
+ ]
+ ),
+ ]
+ )
+
+ return {
+ "status": result.status,
+ "app_or_site": result.app_or_site,
+ "description": result.description,
+ "message": result.message,
+ }
diff --git a/backend/utils/desktop/live_notes.py b/backend/utils/desktop/live_notes.py
new file mode 100644
index 0000000000..4c88b878b4
--- /dev/null
+++ b/backend/utils/desktop/live_notes.py
@@ -0,0 +1,55 @@
+import logging
+
+from langchain_core.messages import HumanMessage, SystemMessage
+from pydantic import BaseModel, Field
+
+from utils.llm.clients import llm_mini
+
+logger = logging.getLogger(__name__)
+
+LIVE_NOTES_SYSTEM_PROMPT = """\
+You are a live note-taking assistant. Given a transcript segment, generate a concise, \
+well-structured note that captures the key information.
+
+RULES:
+- Condense transcript into clear, readable notes
+- Preserve important details: names, numbers, decisions, action items
+- Remove filler words, repetition, and hesitation
+- Use bullet points for multiple items
+- Keep notes under 200 words
+- If the transcript is too short or contains no meaningful content, return empty string"""
+
+
+class LiveNoteResult(BaseModel):
+ text: str = Field(description="The generated note (empty string if no meaningful content)")
+
+
+async def generate_live_note(
+ text: str,
+ session_context: str = "",
+) -> dict:
+ """Generate a live note from transcript text.
+
+ Args:
+ text: Transcript text to summarize
+ session_context: Optional session context
+
+ Returns:
+ Dict with text field (the note)
+ """
+ prompt_parts = []
+ if session_context:
+ prompt_parts.append(f"Session context: {session_context}")
+ prompt_parts.append(f"Transcript:\n{text}")
+
+ prompt_text = "\n\n".join(prompt_parts)
+
+ with_parser = llm_mini.with_structured_output(LiveNoteResult)
+ result = await with_parser.ainvoke(
+ [
+ SystemMessage(content=LIVE_NOTES_SYSTEM_PROMPT),
+ HumanMessage(content=prompt_text),
+ ]
+ )
+
+ return {"text": result.text}
diff --git a/backend/utils/desktop/memories.py b/backend/utils/desktop/memories.py
new file mode 100644
index 0000000000..c1bbbcdbda
--- /dev/null
+++ b/backend/utils/desktop/memories.py
@@ -0,0 +1,100 @@
+import logging
+from typing import List, Optional
+
+from langchain_core.messages import HumanMessage, SystemMessage
+from pydantic import BaseModel, Field
+
+from database.memories import get_memories
+from utils.llm.clients import llm_gemini_flash
+
+logger = logging.getLogger(__name__)
+
+MEMORY_SYSTEM_PROMPT = """\
+You are a memory extraction assistant. Analyze screenshots to identify facts, insights, \
+or noteworthy information worth remembering about the user or their context.
+
+EXTRACTION RULES:
+- Extract facts ABOUT the user: preferences, projects, people they work with, decisions, realizations
+- Extract useful external information: advice, tips, insights from what they're reading
+- Maximum 3 memories per screenshot
+- Each memory should be a concise, standalone fact
+- Skip trivial or transient information (UI state, loading screens, timestamps)
+- ~80% of screenshots contain NO memorable information — return empty list when nothing stands out
+
+DEDUPLICATION:
+- Compare against existing memories provided in context
+- If a fact is already known, skip it
+- Only extract genuinely NEW information
+
+CATEGORIES:
+- system: Facts about the user (preferences, opinions, network, projects, habits)
+- interesting: External wisdom or advice from others (articles, conversations, tips)"""
+
+
+class ExtractedMemory(BaseModel):
+ content: str = Field(description="Concise statement of the fact or insight")
+ category: str = Field(description="system or interesting")
+ confidence: float = Field(ge=0.0, le=1.0, description="Extraction confidence")
+
+
+class MemoryExtractionResult(BaseModel):
+ memories: List[ExtractedMemory] = Field(default_factory=list, description="Extracted memories (empty if none)")
+
+
+def _build_memory_context(uid: str) -> str:
+ """Build existing memories context for deduplication."""
+ try:
+ existing = get_memories(uid, limit=30, categories=['system', 'interesting'])
+ if existing:
+ lines = []
+ for m in existing:
+ content = m.get('structured', {}).get('content', m.get('content', ''))
+ if content:
+ lines.append(f"- {content}")
+ if lines:
+ return "Existing memories (DO NOT extract duplicates):\n" + "\n".join(lines)
+ except Exception as e:
+ logger.warning(f"Failed to fetch existing memories: {e}")
+ return ""
+
+
+async def extract_memories(
+ uid: str,
+ image_b64: str,
+ app_name: str = "",
+ window_title: str = "",
+) -> dict:
+ """Extract memories from a screenshot using vision LLM.
+
+ Returns:
+ Dict with memories list (each has content, category, confidence)
+ """
+ memory_context = _build_memory_context(uid)
+
+ prompt_parts = []
+ if memory_context:
+ prompt_parts.append(memory_context)
+ if app_name or window_title:
+ prompt_parts.append(f"Current app: {app_name}, Window: {window_title}")
+ prompt_parts.append("Analyze this screenshot for noteworthy facts or insights:")
+
+ prompt_text = "\n\n".join(prompt_parts)
+
+ with_parser = llm_gemini_flash.with_structured_output(MemoryExtractionResult)
+ result = await with_parser.ainvoke(
+ [
+ SystemMessage(content=MEMORY_SYSTEM_PROMPT),
+ HumanMessage(
+ content=[
+ {"type": "text", "text": prompt_text},
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
+ ]
+ ),
+ ]
+ )
+
+ return {
+ "memories": [
+ {"content": m.content, "category": m.category, "confidence": m.confidence} for m in result.memories
+ ]
+ }
diff --git a/backend/utils/desktop/profile.py b/backend/utils/desktop/profile.py
new file mode 100644
index 0000000000..84fa697b97
--- /dev/null
+++ b/backend/utils/desktop/profile.py
@@ -0,0 +1,79 @@
+import logging
+
+from langchain_core.messages import HumanMessage, SystemMessage
+from pydantic import BaseModel, Field
+
+from database.memories import get_memories
+from database.action_items import get_action_items
+from database.goals import get_user_goals
+from utils.llm.clients import llm_mini
+
+logger = logging.getLogger(__name__)
+
+PROFILE_SYSTEM_PROMPT = """\
+You are generating a concise user profile summary based on their data (goals, tasks, memories). \
+This profile helps other AI assistants understand who the user is and what they care about.
+
+FORMAT:
+- Write in third person ("The user...")
+- Include: professional focus, key projects, communication style, preferences
+- Keep under 300 words
+- Be factual — only include what's supported by the data
+- If data is sparse, keep the profile short rather than speculating"""
+
+
+class ProfileResult(BaseModel):
+ profile_text: str = Field(description="The generated user profile summary")
+
+
+async def generate_profile(uid: str) -> dict:
+ """Generate a user profile from their goals, tasks, and memories.
+
+ Returns:
+ Dict with profile_text
+ """
+ parts = []
+
+ try:
+ goals = get_user_goals(uid, limit=10)
+ if goals:
+ goal_lines = [f"- {g.get('title', g.get('description', ''))}" for g in goals]
+ parts.append("Goals:\n" + "\n".join(goal_lines))
+ except Exception as e:
+ logger.warning(f"Failed to fetch goals for profile: {e}")
+
+ try:
+ tasks = get_action_items(uid, completed=False, limit=30)
+ if tasks:
+ task_lines = [f"- {t.get('description', '')}" for t in tasks[:30]]
+ parts.append("Active tasks:\n" + "\n".join(task_lines))
+ except Exception as e:
+ logger.warning(f"Failed to fetch tasks for profile: {e}")
+
+ try:
+ memories = get_memories(uid, limit=30, categories=['system'])
+ if memories:
+ mem_lines = []
+ for m in memories:
+ content = m.get('structured', {}).get('content', m.get('content', ''))
+ if content:
+ mem_lines.append(f"- {content}")
+ if mem_lines:
+ parts.append("Known facts:\n" + "\n".join(mem_lines))
+ except Exception as e:
+ logger.warning(f"Failed to fetch memories for profile: {e}")
+
+ if not parts:
+ return {"profile_text": "No data available to generate profile."}
+
+ data_text = "\n\n".join(parts)
+
+ with_parser = llm_mini.with_structured_output(ProfileResult)
+ result = await with_parser.ainvoke(
+ [
+ SystemMessage(content=PROFILE_SYSTEM_PROMPT),
+ HumanMessage(content=f"Generate a user profile from this data:\n\n{data_text}"),
+ ]
+ )
+
+ return {"profile_text": result.profile_text}
diff --git a/backend/utils/desktop/task_ops.py b/backend/utils/desktop/task_ops.py
new file mode 100644
index 0000000000..3e6b7506c0
--- /dev/null
+++ b/backend/utils/desktop/task_ops.py
@@ -0,0 +1,141 @@
+import logging
+from typing import List
+
+from langchain_core.messages import HumanMessage, SystemMessage
+from pydantic import BaseModel, Field
+
+from database.action_items import get_action_items
+from utils.llm.clients import llm_mini
+
+logger = logging.getLogger(__name__)
+
+# --- Task Reranking ---
+
+RERANK_SYSTEM_PROMPT = """\
+You are a task prioritization assistant. Given a list of tasks, rerank them by importance \
+and urgency. Consider deadlines, dependencies, and impact.
+
+RULES:
+- Most important/urgent tasks first
+- Tasks with approaching deadlines rank higher
+- Blocking tasks rank higher than blocked tasks
+- Return the same task IDs in new order"""
+
+
+class RankedTask(BaseModel):
+ id: str = Field(description="Task ID")
+ new_position: int = Field(description="New position (1 = most important)")
+
+
+class RerankResult(BaseModel):
+ updated_tasks: List[RankedTask] = Field(description="Tasks in new priority order")
+
+
+async def rerank_tasks(uid: str) -> dict:
+ """Rerank user's active tasks by priority.
+
+ Returns:
+ Dict with updated_tasks list
+ """
+ try:
+ tasks = get_action_items(uid, completed=False, limit=50)
+ except Exception as e:
+ logger.error(f"Failed to fetch tasks for reranking: {e}")
+ return {"updated_tasks": []}
+
+ if not tasks:
+ return {"updated_tasks": []}
+
+ task_lines = []
+ for t in tasks:
+ tid = t.get('id', '')
+ desc = t.get('description', '')
+ due = t.get('due_at', '')
+ priority = t.get('priority', 'medium')
+ due_str = f", Due: {due}" if due else ""
+ task_lines.append(f"- ID: {tid} | {desc} | Priority: {priority}{due_str}")
+
+ task_text = "\n".join(task_lines)
+
+ with_parser = llm_mini.with_structured_output(RerankResult)
+ result = await with_parser.ainvoke(
+ [
+ SystemMessage(content=RERANK_SYSTEM_PROMPT),
+ HumanMessage(content=f"Rerank these tasks by importance:\n\n{task_text}"),
+ ]
+ )
+
+ return {"updated_tasks": [{"id": t.id, "new_position": t.new_position} for t in result.updated_tasks]}
+
+
+# --- Task Deduplication ---
+
+DEDUP_SYSTEM_PROMPT = """\
+You are a task deduplication assistant. Identify semantically duplicate tasks and decide \
+which to keep and which to delete.
+
+RULES:
+- Two tasks are duplicates if they describe the same action, even with different wording
+- "Call John" and "Phone John" are duplicates
+- "Review PR #42" and "Look at pull request 42" are duplicates
+- Keep the more specific/detailed version
+- Keep the one with a deadline if only one has one
+- Keep the more recently created one if equally specific
+- Only flag true duplicates — similar but distinct tasks should both be kept"""
+
+
+class DedupGroup(BaseModel):
+ keep_id: str = Field(description="ID of the task to keep")
+ delete_ids: List[str] = Field(description="IDs of duplicate tasks to remove")
+ reason: str = Field(description="Why these are duplicates")
+
+
+class DedupResult(BaseModel):
+ groups: List[DedupGroup] = Field(default_factory=list, description="Duplicate groups (empty if no duplicates)")
+
+
+async def dedup_tasks(uid: str) -> dict:
+ """Find and resolve duplicate tasks.
+
+ Returns:
+ Dict with deleted_ids and reason
+ """
+ try:
+ tasks = get_action_items(uid, completed=False, limit=100)
+ except Exception as e:
+ logger.error(f"Failed to fetch tasks for dedup: {e}")
+ return {"deleted_ids": [], "reason": "Failed to fetch tasks"}
+
+ if len(tasks) < 2:
+ return {"deleted_ids": [], "reason": "Not enough tasks to deduplicate"}
+
+ task_lines = []
+ for t in tasks:
+ tid = t.get('id', '')
+ desc = t.get('description', '')
+ due = t.get('due_at', '')
+ created = t.get('created_at', '')
+ due_str = f", Due: {due}" if due else ""
+ created_str = f", Created: {created}" if created else ""
+ task_lines.append(f"- ID: {tid} | {desc}{due_str}{created_str}")
+
+ task_text = "\n".join(task_lines)
+
+ with_parser = llm_mini.with_structured_output(DedupResult)
+ result = await with_parser.ainvoke(
+ [
+ SystemMessage(content=DEDUP_SYSTEM_PROMPT),
+ HumanMessage(content=f"Find duplicate tasks:\n\n{task_text}"),
+ ]
+ )
+
+ all_deleted = []
+ reasons = []
+ for group in result.groups:
+ all_deleted.extend(group.delete_ids)
+ reasons.append(group.reason)
+
+ return {
+ "deleted_ids": all_deleted,
+ "reason": "; ".join(reasons) if reasons else "No duplicates found",
+ }
diff --git a/backend/utils/desktop/tasks.py b/backend/utils/desktop/tasks.py
new file mode 100644
index 0000000000..85b297e633
--- /dev/null
+++ b/backend/utils/desktop/tasks.py
@@ -0,0 +1,156 @@
+import logging
+from typing import List, Optional
+
+from langchain_core.messages import HumanMessage, SystemMessage
+from pydantic import BaseModel, Field
+
+from database.action_items import get_action_items
+from utils.llm.clients import llm_gemini_flash
+
+logger = logging.getLogger(__name__)
+
+TASK_SYSTEM_PROMPT = """\
+You are a task extraction assistant. Analyze screenshots to identify actionable tasks, \
+requests, or to-dos visible on screen.
+
+EXTRACTION RULES:
+- Only extract tasks that are clearly visible and actionable
+- Title must be 6+ words, verb-first, naming a specific person/project/artifact + concrete action
+- Skip vague or generic items ("do something", "check this")
+- ~90% of screenshots contain NO new task — use no_tasks when nothing actionable is found
+
+DEDUPLICATION:
+- Compare against the user's existing tasks provided in context
+- If a task is semantically similar to an existing one (even with different wording), skip it
+- "Call John" and "Phone John" are duplicates
+- "Finish report by Friday" and "Complete report by end of week" are duplicates
+- When in doubt, err on treating as duplicate (DON'T extract)
+
+PRIORITY GUIDELINES:
+- high: urgent deadlines, blocking requests, error fixes
+- medium: normal work tasks, follow-ups
+- low: nice-to-haves, ideas, non-urgent items
+
+SOURCE CATEGORIES:
+- direct_request: someone asked the user to do something (message, meeting, mention)
+- self_generated: user's own idea, reminder, or goal subtask
+- calendar_driven: event preparation, recurring task, deadline
+- reactive: error response, notification, observation
+- external_system: from project tools, alerts, documentation"""
+
+
+class ExtractedTask(BaseModel):
+ title: str = Field(description="Verb-first title, 6+ words, specific person/project + concrete action")
+ description: str = Field(default="", description="Additional context if needed")
+ priority: str = Field(description="high, medium, or low")
+ tags: List[str] = Field(default_factory=list, description="1-3 relevant tags")
+ source_app: str = Field(default="", description="App where task was found")
+ inferred_deadline: Optional[str] = Field(default=None, description="yyyy-MM-dd format or null")
+ confidence: float = Field(ge=0.0, le=1.0, description="Extraction confidence")
+ source_category: str = Field(
+ default="reactive", description="direct_request|self_generated|calendar_driven|reactive|external_system"
+ )
+
+
+class TaskExtractionResult(BaseModel):
+ has_new_tasks: bool = Field(description="Whether any new tasks were found")
+ tasks: List[ExtractedTask] = Field(default_factory=list, description="Extracted tasks (empty if none)")
+ context_summary: str = Field(default="", description="Brief summary of what user is viewing")
+ current_activity: str = Field(default="", description="What user is actively doing")
+
+
+def _build_task_context(uid: str) -> str:
+ """Build existing tasks context for deduplication."""
+ parts = []
+
+ try:
+ # Active tasks (not completed) for dedup
+ active_tasks = get_action_items(uid, completed=False, limit=50)
+ if active_tasks:
+ task_lines = []
+ for t in active_tasks:
+ desc = t.get('description', '')
+ due = t.get('due_at', '')
+ due_str = f" (Due: {due})" if due else ""
+ task_lines.append(f"- {desc}{due_str} [Pending]")
+ parts.append("Existing active tasks (DO NOT extract duplicates):\n" + "\n".join(task_lines))
+ except Exception as e:
+ logger.warning(f"Failed to fetch active tasks for dedup: {e}")
+
+ try:
+ # Recently completed tasks (last 10) for dedup
+ completed_tasks = get_action_items(uid, completed=True, limit=10)
+ if completed_tasks:
+ task_lines = [f"- {t.get('description', '')} [Completed]" for t in completed_tasks[:10]]
+ parts.append("Recently completed tasks:\n" + "\n".join(task_lines))
+ except Exception as e:
+ logger.warning(f"Failed to fetch completed tasks: {e}")
+
+ return "\n\n".join(parts) if parts else ""
+
+
+async def extract_tasks(
+ uid: str,
+ image_b64: str,
+ app_name: str = "",
+ window_title: str = "",
+) -> dict:
+ """Extract tasks from a screenshot using vision LLM.
+
+ Args:
+ uid: User ID for fetching existing tasks (dedup)
+ image_b64: Base64-encoded JPEG screenshot
+ app_name: Name of the foreground app
+ window_title: Window title
+
+ Returns:
+ Dict with has_new_tasks, tasks list, context_summary, current_activity
+ """
+ # Pre-fetch existing tasks for dedup context
+ task_context = _build_task_context(uid)
+
+ # Assemble prompt
+ prompt_parts = []
+ if task_context:
+ prompt_parts.append(task_context)
+ if app_name or window_title:
+ prompt_parts.append(f"Current app: {app_name}, Window: {window_title}")
+ prompt_parts.append("Analyze this screenshot for actionable tasks:")
+
+ prompt_text = "\n\n".join(prompt_parts)
+
+ # Call vision LLM with structured output
+ with_parser = llm_gemini_flash.with_structured_output(TaskExtractionResult)
+ result = await with_parser.ainvoke(
+ [
+ SystemMessage(content=TASK_SYSTEM_PROMPT),
+ HumanMessage(
+ content=[
+ {"type": "text", "text": prompt_text},
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
+ ]
+ ),
+ ]
+ )
+
+ tasks_list = []
+ for task in result.tasks:
+ tasks_list.append(
+ {
+ "title": task.title,
+ "description": task.description,
+ "priority": task.priority,
+ "tags": task.tags,
+ "source_app": task.source_app or app_name,
+ "inferred_deadline": task.inferred_deadline,
+ "confidence": task.confidence,
+ "source_category": task.source_category,
+ }
+ )
+
+ return {
+ "has_new_tasks": result.has_new_tasks and len(tasks_list) > 0,
+ "tasks": tasks_list,
+ "context_summary": result.context_summary,
+ "current_activity": result.current_activity,
+ }
diff --git a/desktop/.env.example b/desktop/.env.example
index 87d0a94fb7..6c25ff1689 100644
--- a/desktop/.env.example
+++ b/desktop/.env.example
@@ -17,9 +17,6 @@
# Production: https://api.omi.me
OMI_API_URL=http://localhost:8080
-# DeepGram API key — required for real-time transcription
-DEEPGRAM_API_KEY=
-
# ─── AI (optional) ──────────────────────────────────────────────────
# Gemini API key for proactive assistants and embeddings
# Falls back to backend-side processing if not set
diff --git a/desktop/CHANGELOG.json b/desktop/CHANGELOG.json
index 72d7549d3c..93b85d73ac 100644
--- a/desktop/CHANGELOG.json
+++ b/desktop/CHANGELOG.json
@@ -1,5 +1,7 @@
{
- "unreleased": [],
+ "unreleased": [
+ "Removed client-side Deepgram API key — transcription now routes securely through the Omi backend"
+ ],
"releases": [
{
"version": "0.11.91",
diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift
index 533f273750..741996ffcc 100644
--- a/desktop/Desktop/Sources/APIClient.swift
+++ b/desktop/Desktop/Sources/APIClient.swift
@@ -405,14 +405,10 @@ extension APIClient {
return response.count
}
- /// Gets the count of AI chat messages from PostHog
+ /// Gets the count of AI chat messages
func getChatMessageCount() async throws -> Int {
- struct CountResponse: Decodable {
- let count: Int
- }
-
- let response: CountResponse = try await get("v1/users/stats/chat-messages")
- return response.count
+ // No-op: chat-messages stats endpoint not available in Python backend
+ return 0
}
/// Merges multiple conversations into a new conversation
@@ -581,7 +577,7 @@ struct ServerConversation: Codable, Identifiable, Equatable {
let container = try decoder.container(keyedBy: CodingKeys.self)
id = try container.decode(String.self, forKey: .id)
- createdAt = try container.decode(Date.self, forKey: .createdAt)
+ createdAt = try container.decodeIfPresent(Date.self, forKey: .createdAt) ?? Date()
startedAt = try container.decodeIfPresent(Date.self, forKey: .startedAt)
finishedAt = try container.decodeIfPresent(Date.self, forKey: .finishedAt)
structured = try container.decode(Structured.self, forKey: .structured)
@@ -1187,8 +1183,6 @@ extension APIClient {
let startedAt: String
let finishedAt: String
let language: String
- let timezone: String
- let inputDeviceName: String?
enum CodingKeys: String, CodingKey {
case transcriptSegments = "transcript_segments"
@@ -1196,8 +1190,6 @@ extension APIClient {
case startedAt = "started_at"
case finishedAt = "finished_at"
case language
- case timezone
- case inputDeviceName = "input_device_name"
}
}
@@ -1233,16 +1225,12 @@ extension APIClient {
/// - finishedAt: When the recording finished
/// - source: Source of the conversation (e.g., "desktop", "omi", "bee")
/// - language: Language code for transcription
- /// - timezone: User's timezone
- /// - inputDeviceName: Name of the input device (microphone or BLE device)
func createConversationFromSegments(
segments: [TranscriptSegmentRequest],
startedAt: Date,
finishedAt: Date,
source: ConversationSource = .desktop,
- language: String = "en",
- timezone: String = "UTC",
- inputDeviceName: String? = nil
+ language: String = "en"
) async throws -> CreateConversationResponse {
let formatter = ISO8601DateFormatter()
formatter.formatOptions = [.withInternetDateTime, .withFractionalSeconds]
@@ -1252,9 +1240,7 @@ extension APIClient {
source: source.rawValue,
startedAt: formatter.string(from: startedAt),
finishedAt: formatter.string(from: finishedAt),
- language: language,
- timezone: timezone,
- inputDeviceName: inputDeviceName
+ language: language
)
return try await post("v1/conversations/from-segments", body: request)
@@ -1464,14 +1450,26 @@ struct UserProfile: Codable {
// MARK: - Action Items API
/// Response wrapper for paginated action items list
-struct ActionItemsListResponse: Codable {
+struct ActionItemsListResponse: Decodable {
let items: [TaskActionItem]
let hasMore: Bool
enum CodingKeys: String, CodingKey {
+ case actionItems = "action_items"
case items
case hasMore = "has_more"
}
+
+ init(from decoder: Decoder) throws {
+ let container = try decoder.container(keyedBy: CodingKeys.self)
+ hasMore = try container.decodeIfPresent(Bool.self, forKey: .hasMore) ?? false
+ // Python action_items endpoint returns "action_items"; staged-tasks returns "items"
+ if let actionItems = try container.decodeIfPresent([TaskActionItem].self, forKey: .actionItems) {
+ items = actionItems
+ } else {
+ items = try container.decodeIfPresent([TaskActionItem].self, forKey: .items) ?? []
+ }
+ }
}
extension APIClient {
@@ -1927,17 +1925,6 @@ extension APIClient {
return try await post("v1/staged-tasks/promote")
}
- /// One-time migration of existing AI tasks to staged_tasks
- func migrateStagedTasks() async throws {
- struct StatusResponse: Decodable { let status: String }
- let _: StatusResponse = try await post("v1/staged-tasks/migrate")
- }
-
- /// Migrate conversation-extracted action items (no source field) to staged_tasks
- func migrateConversationItemsToStaged() async throws {
- struct MigrateResponse: Decodable { let status: String; let migrated: Int; let deleted: Int }
- let _: MigrateResponse = try await post("v1/staged-tasks/migrate-conversation-items")
- }
}
/// Response for staged task promotion
@@ -3185,13 +3172,12 @@ extension APIClient {
/// Regenerates persona prompt from current public memories
func regeneratePersonaPrompt() async throws -> GeneratePromptResponse {
- struct EmptyRequest: Encodable {}
- return try await post("v1/personas/generate-prompt", body: EmptyRequest())
+ return try await get("v1/app/generate-prompts")
}
/// Checks if a username is available
func checkPersonaUsername(_ username: String) async throws -> UsernameAvailableResponse {
- return try await get("v1/personas/check-username?username=\(username)")
+ return try await get("v1/apps/check-username?username=\(username)")
}
}
@@ -3289,9 +3275,27 @@ struct GeneratePromptResponse: Codable {
}
/// Response for username availability check
-struct UsernameAvailableResponse: Codable {
+struct UsernameAvailableResponse: Decodable {
let available: Bool
- let username: String
+ let username: String?
+ let isTaken: Bool?
+
+ enum CodingKeys: String, CodingKey {
+ case available, username
+ case isTaken = "is_taken"
+ }
+
+ init(from decoder: Decoder) throws {
+ let container = try decoder.container(keyedBy: CodingKeys.self)
+ username = try container.decodeIfPresent(String.self, forKey: .username)
+ isTaken = try container.decodeIfPresent(Bool.self, forKey: .isTaken)
+ // Python returns is_taken; Rust returned available. Support both.
+ if let isTaken = isTaken {
+ available = !isTaken
+ } else {
+ available = try container.decodeIfPresent(Bool.self, forKey: .available) ?? true
+ }
+ }
}
// MARK: - User Settings API
@@ -3967,7 +3971,7 @@ extension APIClient {
let metadata: String?
}
let body = SaveRequest(text: text, sender: sender, app_id: appId, session_id: sessionId, metadata: metadata)
- return try await post("v2/messages", body: body)
+ return try await post("v2/messages/save", body: body)
}
/// Fetch chat message history
@@ -4129,7 +4133,7 @@ extension APIClient {
}
let body = InitialMessageRequest(sessionId: sessionId, appId: appId)
- return try await post("v2/chat/initial-message", body: body)
+ return try await post("v2/initial-message", body: body)
}
/// Generate a title for a chat session based on its messages
@@ -4278,31 +4282,51 @@ extension APIClient {
// MARK: - Agent VM
struct AgentProvisionResponse: Decodable {
- let status: String
- let vmName: String
+ let hasVm: Bool
+ let status: String?
+ let vmName: String?
let ip: String?
- let authToken: String
- let agentStatus: String
+ let authToken: String?
+ let agentStatus: String?
+
+ enum CodingKeys: String, CodingKey {
+ case hasVm = "has_vm"
+ case status
+ case vmName = "vm_name"
+ case ip
+ case authToken = "auth_token"
+ case agentStatus = "agent_status"
+ }
}
/// Provision a cloud agent VM for the current user (fire-and-forget)
func provisionAgentVM() async throws -> AgentProvisionResponse {
- return try await post("v2/agent/provision")
+ return try await post("v1/agent/vm-ensure")
}
struct AgentStatusResponse: Decodable {
- let vmName: String
- let zone: String
+ let hasVm: Bool
+ let vmName: String?
+ let zone: String?
let ip: String?
- let status: String
- let authToken: String
- let createdAt: String
+ let status: String?
+ let authToken: String?
+ let createdAt: String?
let lastQueryAt: String?
+
+ enum CodingKeys: String, CodingKey {
+ case hasVm = "has_vm"
+ case vmName = "vm_name"
+ case zone, ip, status
+ case authToken = "auth_token"
+ case createdAt = "created_at"
+ case lastQueryAt = "last_query_at"
+ }
}
/// Get current agent VM status
func getAgentStatus() async throws -> AgentStatusResponse? {
- return try await get("v2/agent/status")
+ return try await get("v1/agent/vm-status")
}
}
@@ -4407,41 +4431,13 @@ extension APIClient {
costUsd: Double,
account: String = "omi"
) async {
- struct Req: Encodable {
- let input_tokens: Int
- let output_tokens: Int
- let cache_read_tokens: Int
- let cache_write_tokens: Int
- let total_tokens: Int
- let cost_usd: Double
- let account: String
- }
- struct Res: Decodable { let status: String }
- do {
- let _: Res = try await post("v1/users/me/llm-usage", body: Req(
- input_tokens: inputTokens,
- output_tokens: outputTokens,
- cache_read_tokens: cacheReadTokens,
- cache_write_tokens: cacheWriteTokens,
- total_tokens: totalTokens,
- cost_usd: costUsd,
- account: account
- ))
- } catch {
- log("APIClient: LLM usage record failed: \(error.localizedDescription)")
- }
+ // No-op: LLM usage tracking endpoint not available in Python backend
+ log("APIClient: recordLlmUsage no-op (endpoint removed)")
}
func fetchTotalOmiAICost() async -> Double? {
- struct Res: Decodable { let total_cost_usd: Double }
- do {
- log("APIClient: Fetching total Omi AI cost from backend")
- let res: Res = try await get("v1/users/me/llm-usage/total")
- log("APIClient: Total Omi AI cost from backend: $\(String(format: "%.4f", res.total_cost_usd))")
- return res.total_cost_usd
- } catch {
- log("APIClient: LLM total cost fetch failed: \(error.localizedDescription)")
- return nil
- }
+ // No-op: LLM usage total endpoint not available in Python backend
+ log("APIClient: fetchTotalOmiAICost no-op (endpoint removed)")
+ return nil
}
}
diff --git a/desktop/Desktop/Sources/AgentVMService.swift b/desktop/Desktop/Sources/AgentVMService.swift
index ceec6ca19f..985be922cb 100644
--- a/desktop/Desktop/Sources/AgentVMService.swift
+++ b/desktop/Desktop/Sources/AgentVMService.swift
@@ -23,26 +23,28 @@ actor AgentVMService {
do {
let status = try await APIClient.shared.getAgentStatus()
if let status = status, status.status == "ready", let ip = status.ip {
- log("AgentVMService: VM already ready — vmName=\(status.vmName) ip=\(ip)")
+ let token = status.authToken ?? ""
+ log("AgentVMService: VM already ready — vmName=\(status.vmName ?? "unknown") ip=\(ip)")
// Only upload if the VM doesn't have a database yet
- if await checkVMNeedsDatabase(vmIP: ip, authToken: status.authToken) {
- await uploadDatabase(vmIP: ip, authToken: status.authToken)
+ if await checkVMNeedsDatabase(vmIP: ip, authToken: token) {
+ await uploadDatabase(vmIP: ip, authToken: token)
} else {
log("AgentVMService: VM already has database, skipping upload")
}
- await startIncrementalSync(vmIP: ip, authToken: status.authToken)
+ await startIncrementalSync(vmIP: ip, authToken: token)
return
}
if let status = status,
status.status == "provisioning" || status.status == "stopped" {
- log("AgentVMService: VM is \(status.status), polling until ready...")
+ log("AgentVMService: VM is \(status.status ?? "unknown"), polling until ready...")
if let result = await pollUntilReady(maxAttempts: 30, intervalSeconds: 5),
let ip = result.ip {
+ let token = result.authToken ?? ""
log("AgentVMService: VM became ready — ip=\(ip)")
- if await checkVMNeedsDatabase(vmIP: ip, authToken: result.authToken) {
- await uploadDatabase(vmIP: ip, authToken: result.authToken)
+ if await checkVMNeedsDatabase(vmIP: ip, authToken: token) {
+ await uploadDatabase(vmIP: ip, authToken: token)
}
- await startIncrementalSync(vmIP: ip, authToken: result.authToken)
+ await startIncrementalSync(vmIP: ip, authToken: token)
}
return
}
@@ -76,7 +78,7 @@ actor AgentVMService {
let provisionResult: APIClient.AgentProvisionResponse
do {
provisionResult = try await APIClient.shared.provisionAgentVM()
- log("AgentVMService: Provision response — vmName=\(provisionResult.vmName) status=\(provisionResult.status) ip=\(provisionResult.ip ?? "none")")
+ log("AgentVMService: Provision response — vmName=\(provisionResult.vmName ?? "unknown") status=\(provisionResult.status ?? "unknown") ip=\(provisionResult.ip ?? "none")")
} catch {
log("AgentVMService: Provision failed — \(error.localizedDescription)")
return
@@ -84,14 +86,14 @@ actor AgentVMService {
// Step 2: Poll until VM is ready with an IP
var vmIP = provisionResult.ip
- var authToken = provisionResult.authToken
+ var authToken = provisionResult.authToken ?? ""
- if vmIP == nil || provisionResult.agentStatus == "provisioning" {
+ if vmIP == nil || provisionResult.status == "provisioning" {
log("AgentVMService: Waiting for VM to be ready...")
let pollResult = await pollUntilReady(maxAttempts: 30, intervalSeconds: 5)
if let result = pollResult {
vmIP = result.ip
- authToken = result.authToken
+ authToken = result.authToken ?? ""
log("AgentVMService: VM ready — ip=\(vmIP ?? "none")")
} else {
log("AgentVMService: VM did not become ready in time")
@@ -111,7 +113,7 @@ actor AgentVMService {
await startIncrementalSync(vmIP: ip, authToken: authToken)
}
- /// Poll GET /v2/agent/status until status is "ready" and IP is available.
+ /// Poll GET /v1/agent/vm-status until status is "ready" and IP is available.
private func pollUntilReady(maxAttempts: Int, intervalSeconds: UInt64) async -> APIClient.AgentStatusResponse? {
for attempt in 1...maxAttempts {
do {
diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift
index 862bfaf070..f136e4ed1f 100644
--- a/desktop/Desktop/Sources/AppState.swift
+++ b/desktop/Desktop/Sources/AppState.swift
@@ -135,13 +135,16 @@ class AppState: ObservableObject {
// Transcription services
private var audioCaptureService: AudioCaptureService?
- private var transcriptionService: TranscriptionService?
+ private var transcriptionService: BackendTranscriptionService?
private var systemAudioCaptureService: Any? // SystemAudioCaptureService (macOS 14.4+)
private var audioMixer: AudioMixer?
private var vadGateService: VADGateService?
- // Batch transcription mode
+ // Batch transcription mode (disabled — backend handles everything via /v4/listen)
private var useBatchTranscription: Bool = false
+ // When true, backend owns conversation creation via /v4/listen lifecycle manager.
+ // Desktop skips createConversationFromSegments() to avoid duplicates.
+ private var backendOwnsConversation: Bool = false
private var recordingStartCATime: Double = 0 // CACurrentMediaTime at recording start
// Speaker segments for diarized transcription (sliding window — older segments are in SQLite)
@@ -430,12 +433,7 @@ class AppState: ObservableObject {
}
}
- // Log final state of important keys
- if getenv("DEEPGRAM_API_KEY") != nil {
- log("DEEPGRAM_API_KEY is set")
- } else {
- log("WARNING: DEEPGRAM_API_KEY is NOT set")
- }
+ // DEEPGRAM_API_KEY no longer needed — STT routed through backend /v4/listen
}
private func shouldSkipBundledAnthropicKey(key: String, sourcePath: String, bundledEnvPath: String?) -> Bool {
@@ -1152,165 +1150,143 @@ class AppState: ObservableObject {
}
}
- do {
- // Get effective language from settings (handles auto-detect vs single language)
- let effectiveLanguage = AssistantSettings.shared.effectiveTranscriptionLanguage
- let vocabulary = AssistantSettings.shared.effectiveVocabulary
- log("Transcription: Using language=\(effectiveLanguage) (autoDetect=\(AssistantSettings.shared.transcriptionAutoDetect), selected=\(AssistantSettings.shared.transcriptionLanguage))")
- log("Transcription: Custom vocabulary: \(vocabulary.joined(separator: ", "))")
-
- // Determine transcription mode
- useBatchTranscription = AssistantSettings.shared.batchTranscriptionEnabled && effectiveSource == .microphone
-
- if !useBatchTranscription {
- // Streaming mode: initialize WebSocket transcription service
- transcriptionService = try TranscriptionService(language: effectiveLanguage, vocabulary: vocabulary)
+ // Get effective language from settings (handles auto-detect vs single language)
+ let effectiveLanguage = AssistantSettings.shared.effectiveTranscriptionLanguage
+ let vocabulary = AssistantSettings.shared.effectiveVocabulary
+ log("Transcription: Using language=\(effectiveLanguage) (autoDetect=\(AssistantSettings.shared.transcriptionAutoDetect), selected=\(AssistantSettings.shared.transcriptionLanguage))")
+ log("Transcription: Custom vocabulary: \(vocabulary.joined(separator: ", "))")
+
+ // Always use streaming mode through the backend — batch mode not needed
+ // (backend handles STT, diarization, and memory creation server-side)
+ useBatchTranscription = false
+ // Backend owns conversation creation via /v4/listen lifecycle manager
+ backendOwnsConversation = true
+
+ // Set conversation source based on audio source
+ let sourceValue: String
+ if effectiveSource == .bleDevice, let device = DeviceProvider.shared.connectedDevice {
+ currentConversationSource = ConversationSource.from(deviceType: device.type)
+ recordingInputDeviceName = device.displayName
+ sourceValue = currentConversationSource.rawValue
+ } else {
+ currentConversationSource = .desktop
+ recordingInputDeviceName = AudioCaptureService.getCurrentMicrophoneName()
+ sourceValue = "desktop"
+ }
+
+ transcriptionService = BackendTranscriptionService(language: effectiveLanguage, source: sourceValue)
+
+ // Initialize audio services based on source
+ if effectiveSource == .microphone {
+ // Initialize audio capture service
+ audioCaptureService = AudioCaptureService()
+
+ // Initialize audio mixer for combining mic and system audio
+ audioMixer = AudioMixer()
+
+ // VAD gate is optional for streaming mode (silence gating)
+ if AssistantSettings.shared.vadGateEnabled {
+ let gate = VADGateService()
+ vadGateService = gate
+ log("Transcription: VAD gate enabled")
} else {
- log("Transcription: Batch mode enabled — skipping WebSocket")
+ vadGateService = nil
+ }
+
+ // Initialize system audio capture if supported (macOS 14.4+)
+ // Can be disabled via: defaults write com.omi.desktop-dev disableSystemAudioCapture -bool true
+ // or: defaults write com.omi.computer-macos disableSystemAudioCapture -bool true
+ let systemAudioDisabled = UserDefaults.standard.bool(forKey: "disableSystemAudioCapture")
+ if systemAudioDisabled {
+ log("Transcription: System audio capture DISABLED by user preference (disableSystemAudioCapture)")
+ } else if #available(macOS 14.4, *) {
+ systemAudioCaptureService = SystemAudioCaptureService()
+ log("Transcription: System audio capture initialized (macOS 14.4+)")
+ } else {
+ log("Transcription: System audio capture not available (requires macOS 14.4+)")
}
+ }
+ // For BLE device, BleAudioService will be used in startAudioCapture
- // Set conversation source based on audio source
- if effectiveSource == .bleDevice, let device = DeviceProvider.shared.connectedDevice {
- currentConversationSource = ConversationSource.from(deviceType: device.type)
- recordingInputDeviceName = device.displayName
- } else {
- currentConversationSource = .desktop
- recordingInputDeviceName = AudioCaptureService.getCurrentMicrophoneName()
- }
-
- // Initialize audio services based on source
- if effectiveSource == .microphone {
- // Initialize audio capture service
- audioCaptureService = AudioCaptureService()
-
- // Initialize audio mixer for combining mic and system audio
- audioMixer = AudioMixer()
-
- // VAD gate is always needed for batch mode (chunk boundaries),
- // and optional for streaming mode (silence gating)
- if useBatchTranscription || AssistantSettings.shared.vadGateEnabled {
- let gate = VADGateService()
- if useBatchTranscription && !gate.modelAvailable {
- // Batch mode requires working VAD — fall back to streaming
- log("Transcription: VAD models unavailable, falling back from batch to streaming mode")
- useBatchTranscription = false
- vadGateService = nil
- transcriptionService = try TranscriptionService(language: effectiveLanguage, vocabulary: vocabulary)
- } else {
- vadGateService = gate
- log("Transcription: VAD gate enabled\(useBatchTranscription ? " (batch mode)" : "")")
- }
- } else {
- vadGateService = nil
+ // Start backend transcription service, then audio on connect
+ transcriptionService?.start(
+ onTranscript: { [weak self] segment in
+ Task { @MainActor in
+ self?.handleTranscriptSegment(segment)
}
-
- // Initialize system audio capture if supported (macOS 14.4+)
- // Can be disabled via: defaults write com.omi.desktop-dev disableSystemAudioCapture -bool true
- // or: defaults write com.omi.computer-macos disableSystemAudioCapture -bool true
- let systemAudioDisabled = UserDefaults.standard.bool(forKey: "disableSystemAudioCapture")
- if systemAudioDisabled {
- log("Transcription: System audio capture DISABLED by user preference (disableSystemAudioCapture)")
- } else if #available(macOS 14.4, *) {
- systemAudioCaptureService = SystemAudioCaptureService()
- log("Transcription: System audio capture initialized (macOS 14.4+)")
- } else {
- log("Transcription: System audio capture not available (requires macOS 14.4+)")
+ },
+ onError: { [weak self] error in
+ Task { @MainActor in
+ logError("Transcription error", error: error)
+ AnalyticsManager.shared.recordingError(error: error.localizedDescription)
+ self?.stopTranscription()
}
- }
- // For BLE device, BleAudioService will be used in startAudioCapture
-
- if useBatchTranscription {
- // Batch mode: start audio capture directly (no WebSocket to wait for)
- recordingStartCATime = CACurrentMediaTime()
- Task { @MainActor [weak self] in
+ },
+ onConnected: { [weak self] in
+ Task { @MainActor in
+ log("Transcription: Connected to backend")
+ // Start audio capture once connected
await self?.startAudioCapture(source: effectiveSource)
}
- } else {
- // Streaming mode: start transcription service first, then audio on connect
- transcriptionService?.start(
- onTranscript: { [weak self] segment in
- Task { @MainActor in
- self?.handleTranscriptSegment(segment)
- }
- },
- onError: { [weak self] error in
- Task { @MainActor in
- logError("Transcription error", error: error)
- AnalyticsManager.shared.recordingError(error: error.localizedDescription)
- self?.stopTranscription()
- }
- },
- onConnected: { [weak self] in
- Task { @MainActor in
- log("Transcription: Connected to DeepGram")
- // Start audio capture once connected
- await self?.startAudioCapture(source: effectiveSource)
- }
- },
- onDisconnected: {
- log("Transcription: Disconnected from DeepGram")
- }
- )
+ },
+ onDisconnected: {
+ log("Transcription: Disconnected from backend")
}
+ )
- isTranscribing = true
- AssistantSettings.shared.transcriptionEnabled = true
- audioSource = effectiveSource
- currentTranscript = ""
- speakerSegments = []
- totalSegmentCount = 0
- totalWordCount = 0
- liveSpeakerPersonMap = [:]
- LiveTranscriptMonitor.shared.clear()
- recordingStartTime = Date()
- AudioLevelMonitor.shared.reset()
- RecordingTimer.shared.start()
+ isTranscribing = true
+ AssistantSettings.shared.transcriptionEnabled = true
+ audioSource = effectiveSource
+ currentTranscript = ""
+ speakerSegments = []
+ totalSegmentCount = 0
+ totalWordCount = 0
+ liveSpeakerPersonMap = [:]
+ LiveTranscriptMonitor.shared.clear()
+ recordingStartTime = Date()
+ AudioLevelMonitor.shared.reset()
+ RecordingTimer.shared.start()
- log("Transcription: Using source: \(effectiveSource.rawValue), device: \(recordingInputDeviceName ?? "Unknown")")
+ log("Transcription: Using source: \(effectiveSource.rawValue), device: \(recordingInputDeviceName ?? "Unknown")")
- // Create crash-safe DB session for persistence
- Task {
- do {
- let sessionId = try await TranscriptionStorage.shared.startSession(
- source: currentConversationSource.rawValue,
- language: effectiveLanguage,
- timezone: TimeZone.current.identifier,
- inputDeviceName: recordingInputDeviceName
- )
- await MainActor.run {
- self.currentSessionId = sessionId
- // Start live notes session
- LiveNotesMonitor.shared.startSession(sessionId: sessionId)
- }
- log("Transcription: Created DB session \(sessionId)")
- } catch {
- logError("Transcription: Failed to create DB session", error: error)
- // Non-fatal - continue recording even if DB fails
+ // Create crash-safe DB session for persistence
+ Task {
+ do {
+ let sessionId = try await TranscriptionStorage.shared.startSession(
+ source: currentConversationSource.rawValue,
+ language: effectiveLanguage,
+ timezone: TimeZone.current.identifier,
+ inputDeviceName: recordingInputDeviceName
+ )
+ await MainActor.run {
+ self.currentSessionId = sessionId
+ // Start live notes session
+ LiveNotesMonitor.shared.startSession(sessionId: sessionId)
}
+ log("Transcription: Created DB session \(sessionId)")
+ } catch {
+ logError("Transcription: Failed to create DB session", error: error)
+ // Non-fatal - continue recording even if DB fails
}
+ }
- // Start 4-hour max recording timer
- maxRecordingTimer = Timer.scheduledTimer(withTimeInterval: maxRecordingDuration, repeats: false) { [weak self] _ in
- Task { @MainActor in
- guard let self = self, self.isTranscribing else { return }
- log("Transcription: 4-hour limit reached - finalizing conversation")
- _ = await self.finalizeConversation()
- // Start a new recording session automatically
- self.stopAudioCapture()
- self.clearTranscriptionState()
- self.startTranscription()
- }
+ // Start 4-hour max recording timer
+ maxRecordingTimer = Timer.scheduledTimer(withTimeInterval: maxRecordingDuration, repeats: false) { [weak self] _ in
+ Task { @MainActor in
+ guard let self = self, self.isTranscribing else { return }
+ log("Transcription: 4-hour limit reached - finalizing conversation")
+ _ = await self.finalizeConversation()
+ // Start a new recording session automatically
+ self.stopAudioCapture()
+ self.clearTranscriptionState()
+ self.startTranscription()
}
+ }
- // Track transcription started
- AnalyticsManager.shared.transcriptionStarted()
-
- log("Transcription: Starting...")
+ // Track transcription started
+ AnalyticsManager.shared.transcriptionStarted()
- } catch {
- AnalyticsManager.shared.recordingError(error: error.localizedDescription)
- showAlert(title: "Transcription Error", message: error.localizedDescription)
- }
+ log("Transcription: Starting...")
}
/// Start audio capture and pipe to transcription service
@@ -1330,23 +1306,12 @@ class AppState: ObservableObject {
guard let audioCaptureService = audioCaptureService,
let audioMixer = audioMixer else { return }
- // Start the audio mixer - it will send stereo audio to transcription service
- // Branch on batch vs streaming mode
- audioMixer.start { [weak self] stereoData in
+ // Start the audio mixer in mono mode — backend handles diarization server-side
+ audioMixer.start(outputMode: .mono) { [weak self] monoData in
guard let self = self else { return }
- if self.useBatchTranscription {
- // Batch mode: accumulate audio in VAD gate, transcribe on silence
- guard let gate = self.vadGateService else { return }
- let output = gate.processAudioBatch(stereoData)
- if output.isComplete, let audioBuffer = output.audioBuffer {
- let wallStartTime = output.speechStartWallTime
- Task { @MainActor [weak self] in
- await self?.batchTranscribeChunk(audioBuffer: audioBuffer, wallStartTime: wallStartTime)
- }
- }
- } else if let gate = self.vadGateService {
+ if let gate = self.vadGateService {
// Streaming mode with VAD gate
- let output = gate.processAudio(stereoData)
+ let output = gate.processAudio(monoData)
if !output.audioToSend.isEmpty {
self.transcriptionService?.sendAudio(output.audioToSend)
} else if gate.needsKeepalive() {
@@ -1357,7 +1322,7 @@ class AppState: ObservableObject {
}
} else {
// Streaming mode without VAD gate
- self.transcriptionService?.sendAudio(stereoData)
+ self.transcriptionService?.sendAudio(monoData)
}
}
@@ -1411,10 +1376,12 @@ class AppState: ObservableObject {
return
}
- // Start BLE audio processing and pipe directly to transcription
+ // Start BLE audio processing and pipe mono PCM directly to backend transcription
await BleAudioService.shared.startProcessing(
from: connection,
- transcriptionService: transcriptionService,
+ audioSink: { [weak transcriptionService] pcmData in
+ transcriptionService?.sendAudio(pcmData)
+ },
audioDataHandler: { _ in
// Audio level is updated by BleAudioService
Task { @MainActor in
@@ -2095,6 +2062,19 @@ class AppState: ObservableObject {
log("Transcription: Finalizing conversation with \(segmentsToUpload.count) segments")
+ // When backend owns conversation creation (via /v4/listen lifecycle manager),
+ // skip client-side createConversationFromSegments() to avoid duplicates.
+ // The backend already has all segments from the live stream and will process
+ // the conversation on timeout or next connection.
+ if backendOwnsConversation {
+ log("Transcription: Backend owns conversation — skipping client-side upload (\(segmentsToUpload.count) segments streamed)")
+ if let sessionId = sessionId {
+ // Mark session as completed — no retry needed since backend has the data
+ try? await TranscriptionStorage.shared.markSessionCompleted(id: sessionId, backendId: "backend-owned")
+ }
+ return .saved
+ }
+
// Convert SpeakerSegment to API request format (include person_id from live naming)
let speakerPersonMap = liveSpeakerPersonMap
let apiSegments = segmentsToUpload.map { segment in
@@ -2119,8 +2099,7 @@ class AppState: ObservableObject {
segments: apiSegments,
startedAt: startTime,
finishedAt: endTime,
- source: currentConversationSource,
- inputDeviceName: recordingInputDeviceName
+ source: currentConversationSource
)
log("Transcription: Conversation saved - id=\(response.id), status=\(response.status), discarded=\(response.discarded), source=\(currentConversationSource.rawValue), device=\(recordingInputDeviceName ?? "Unknown")")
diff --git a/desktop/Desktop/Sources/Audio/AudioSourceManager.swift b/desktop/Desktop/Sources/Audio/AudioSourceManager.swift
index 9444be7d99..3247adc62c 100644
--- a/desktop/Desktop/Sources/Audio/AudioSourceManager.swift
+++ b/desktop/Desktop/Sources/Audio/AudioSourceManager.swift
@@ -301,7 +301,6 @@ final class AudioSourceManager: ObservableObject {
// Start BLE audio processing with direct audio callback and WAL recording
await bleAudioService.startProcessing(
from: connection,
- transcriptionService: nil, // We'll handle routing ourselves
audioDataHandler: { [weak self] pcmData in
// Convert decoded PCM mono to stereo and forward
self?.handleBleAudio(pcmData)
diff --git a/desktop/Desktop/Sources/Audio/BleAudioService.swift b/desktop/Desktop/Sources/Audio/BleAudioService.swift
index 0cb9bb527f..9078653668 100644
--- a/desktop/Desktop/Sources/Audio/BleAudioService.swift
+++ b/desktop/Desktop/Sources/Audio/BleAudioService.swift
@@ -27,7 +27,7 @@ final class BleAudioService: ObservableObject {
private var cancellables = Set()
// Audio delivery
- private var transcriptionService: TranscriptionService?
+ private var audioSink: ((Data) -> Void)?
private var audioDataHandler: ((Data) -> Void)?
private var rawFrameHandler: ((Data) -> Void)?
@@ -44,12 +44,12 @@ final class BleAudioService: ObservableObject {
/// Start processing audio from a device connection
/// - Parameters:
/// - connection: The device connection to get audio from
- /// - transcriptionService: Optional transcription service to send audio to
+ /// - audioSink: Optional closure to receive decoded mono PCM audio (e.g., send to transcription service)
/// - audioDataHandler: Optional handler for decoded PCM data (alternative to transcription)
/// - rawFrameHandler: Optional handler for raw encoded frames (for WAL recording)
func startProcessing(
from connection: DeviceConnection,
- transcriptionService: TranscriptionService? = nil,
+ audioSink: ((Data) -> Void)? = nil,
audioDataHandler: ((Data) -> Void)? = nil,
rawFrameHandler: ((Data) -> Void)? = nil
) async {
@@ -58,7 +58,7 @@ final class BleAudioService: ObservableObject {
return
}
- self.transcriptionService = transcriptionService
+ self.audioSink = audioSink
self.audioDataHandler = audioDataHandler
self.rawFrameHandler = rawFrameHandler
@@ -126,7 +126,7 @@ final class BleAudioService: ObservableObject {
cancellables.removeAll()
isProcessing = false
- transcriptionService = nil
+ audioSink = nil
audioDataHandler = nil
rawFrameHandler = nil
@@ -194,37 +194,13 @@ final class BleAudioService: ObservableObject {
// Calculate audio level
updateAudioLevel(from: pcmData)
- // Send to transcription service (mono channel)
- if let transcription = transcriptionService {
- // TranscriptionService expects stereo (2 channels) for multichannel transcription
- // For BLE device audio, we duplicate to both channels (device is the "user")
- let stereoData = convertToStereo(pcmData)
- transcription.sendAudio(stereoData)
- }
+ // Send decoded mono PCM to audio sink (e.g., transcription service)
+ audioSink?(pcmData)
// Send to custom handler
audioDataHandler?(pcmData)
}
- /// Convert mono PCM to stereo (duplicate to both channels)
- private func convertToStereo(_ monoData: Data) -> Data {
- // Mono: [S0, S1, S2, ...]
- // Stereo: [S0, S0, S1, S1, S2, S2, ...] (interleaved)
- var stereoData = Data(capacity: monoData.count * 2)
-
- monoData.withUnsafeBytes { bytes in
- let samples = bytes.bindMemory(to: Int16.self)
- for i in 0.. Void
// MARK: - Properties
private var onStereoChunk: StereoAudioHandler?
private var isRunning = false
+ private(set) var outputMode: OutputMode = .stereo
// Audio buffers (16kHz mono Int16 PCM)
private var micBuffer = Data()
@@ -29,15 +36,18 @@ class AudioMixer {
// MARK: - Public Methods
/// Start the mixer
- /// - Parameter onStereoChunk: Callback receiving interleaved stereo 16-bit PCM at 16kHz
- func start(onStereoChunk: @escaping StereoAudioHandler) {
+ /// - Parameters:
+ /// - outputMode: `.stereo` for interleaved multichannel, `.mono` for averaged single-channel
+ /// - onStereoChunk: Callback receiving mixed 16-bit PCM at 16kHz
+ func start(outputMode: OutputMode = .stereo, onStereoChunk: @escaping StereoAudioHandler) {
bufferLock.lock()
+ self.outputMode = outputMode
self.onStereoChunk = onStereoChunk
self.isRunning = true
micBuffer = Data()
systemBuffer = Data()
bufferLock.unlock()
- log("AudioMixer: Started")
+ log("AudioMixer: Started (output=\(outputMode))")
}
/// Stop the mixer and flush remaining audio
@@ -105,12 +115,17 @@ class AudioMixer {
if flush {
// When flushing, process whatever is available
bytesToProcess = max(micBuffer.count, systemBuffer.count)
+ } else if micBuffer.count >= minBufferBytes && systemBuffer.count >= minBufferBytes {
+ // Both buffers have data — use shorter to stay in sync
+ bytesToProcess = (min(micBuffer.count, systemBuffer.count) / 2) * 2
+ } else if micBuffer.count >= minBufferBytes {
+ // Only mic has data (system audio disabled/unavailable) — pad system with silence
+ bytesToProcess = (micBuffer.count / 2) * 2
+ } else if systemBuffer.count >= minBufferBytes {
+ // Only system has data — pad mic with silence
+ bytesToProcess = (systemBuffer.count / 2) * 2
} else {
- // Normal operation: process when both have data
- let minAvailable = min(micBuffer.count, systemBuffer.count)
- guard minAvailable >= minBufferBytes else { return }
- // Align to sample boundary (2 bytes per Int16 sample)
- bytesToProcess = (minAvailable / 2) * 2
+ return
}
guard bytesToProcess >= 2 else { return }
@@ -137,11 +152,17 @@ class AudioMixer {
systemBuffer = Data()
}
- // Interleave into stereo
- let stereoData = interleave(mic: micData, system: sysData)
+ // Mix according to output mode
+ let mixedData: Data
+ switch outputMode {
+ case .stereo:
+ mixedData = interleave(mic: micData, system: sysData)
+ case .mono:
+ mixedData = mixToMono(mic: micData, system: sysData)
+ }
// Send to callback
- onStereoChunk?(stereoData)
+ onStereoChunk?(mixedData)
}
/// Interleave two mono Int16 streams into stereo
@@ -174,4 +195,32 @@ class AudioMixer {
Data(buffer: buffer)
}
}
+
+ /// Average two mono Int16 streams into a single mono stream
+ /// Output format: [(mic0+sys0)/2, (mic1+sys1)/2, ...]
+ private func mixToMono(mic: Data, system: Data) -> Data {
+ let sampleCount = mic.count / 2
+
+ var monoSamples = [Int16]()
+ monoSamples.reserveCapacity(sampleCount)
+
+ mic.withUnsafeBytes { micPtr in
+ system.withUnsafeBytes { sysPtr in
+ let micSamples = micPtr.bindMemory(to: Int16.self)
+ let sysSamples = sysPtr.bindMemory(to: Int16.self)
+
+ for i in 0.. Void
+ typealias ErrorHandler = (Error) -> Void
+ typealias ConnectionHandler = () -> Void
+
+ enum BackendTranscriptionError: LocalizedError {
+ case notSignedIn
+ case connectionFailed(Error)
+ case invalidResponse
+ case webSocketError(String)
+
+ var errorDescription: String? {
+ switch self {
+ case .notSignedIn:
+ return "Not signed in — cannot connect to backend"
+ case .connectionFailed(let error):
+ return "Connection failed: \(error.localizedDescription)"
+ case .invalidResponse:
+ return "Invalid response from backend"
+ case .webSocketError(let message):
+ return "WebSocket error: \(message)"
+ }
+ }
+ }
+
+ // MARK: - Properties
+
+ private var webSocketTask: URLSessionWebSocketTask?
+ private var urlSession: URLSession?
+ private var isConnected = false
+ private var shouldReconnect = false
+
+ // Callbacks
+ private var onTranscript: TranscriptHandler?
+ private var onError: ErrorHandler?
+ private var onConnected: ConnectionHandler?
+ private var onDisconnected: ConnectionHandler?
+
+ // Configuration
+ private let language: String
+ private let sampleRate = 16000
+ private let codec = "pcm16"
+ private let channels = 1 // Always mono — backend handles diarization
+ private let source: String
+ private let conversationTimeout: Int
+
+ // Reconnection
+ private var reconnectAttempts = 0
+ private let maxReconnectAttempts = 10
+ private var reconnectTask: Task?
+
+ // Keepalive — send empty data periodically to prevent timeout
+ private var keepaliveTask: Task?
+ private let keepaliveInterval: TimeInterval = 8.0
+
+ // Watchdog: detect stale connections where WebSocket dies silently
+ private var watchdogTask: Task?
+ private var lastDataReceivedAt: Date?
+ private var lastKeepaliveSuccessAt: Date?
+ private let watchdogInterval: TimeInterval = 30.0
+ private let staleThreshold: TimeInterval = 60.0
+
+ // Audio buffering
+ private var audioBuffer = Data()
+ private let audioBufferSize = 3200 // ~100ms of 16kHz 16-bit mono (16000 * 2 * 0.1)
+ private let audioBufferLock = NSLock()
+
+ // MARK: - Initialization
+
+ /// Initialize the backend transcription service
+ /// - Parameters:
+ /// - language: Language code for transcription (e.g., "en", "multi")
+ /// - source: Audio source identifier for backend analytics (e.g., "desktop", "omi", "bee")
+ /// - conversationTimeout: Seconds of silence before the backend creates a memory
+ init(language: String = "en", source: String = "desktop", conversationTimeout: Int = 120) {
+ self.language = language
+ self.source = source
+ self.conversationTimeout = conversationTimeout
+ log("BackendTranscriptionService: Initialized with language=\(language), source=\(source)")
+ }
+
+ // MARK: - Public Methods
+
+ /// Start the transcription service
+ func start(
+ onTranscript: @escaping TranscriptHandler,
+ onError: ErrorHandler? = nil,
+ onConnected: ConnectionHandler? = nil,
+ onDisconnected: ConnectionHandler? = nil
+ ) {
+ self.onTranscript = onTranscript
+ self.onError = onError
+ self.onConnected = onConnected
+ self.onDisconnected = onDisconnected
+ self.shouldReconnect = true
+ self.reconnectAttempts = 0
+
+ connect()
+ }
+
+ /// Stop the transcription service
+ func stop() {
+ shouldReconnect = false
+ reconnectTask?.cancel()
+ reconnectTask = nil
+ keepaliveTask?.cancel()
+ keepaliveTask = nil
+ watchdogTask?.cancel()
+ watchdogTask = nil
+
+ flushAudioBuffer()
+ disconnect()
+ }
+
+ /// Signal the backend that no more audio will be sent, but keep connection open
+ /// to receive final transcription results. Call stop() later to fully disconnect.
+ func finishStream() {
+ shouldReconnect = false
+ reconnectTask?.cancel()
+ reconnectTask = nil
+ keepaliveTask?.cancel()
+ keepaliveTask = nil
+ watchdogTask?.cancel()
+ watchdogTask = nil
+
+ flushAudioBuffer()
+
+ // Backend doesn't have a CloseStream message like Deepgram.
+ // The connection will be closed when stop() is called.
+ log("BackendTranscriptionService: finishStream called, waiting for final results")
+ }
+
+ /// Send audio data to the backend (buffered for efficiency)
+ func sendAudio(_ data: Data) {
+ guard isConnected else { return }
+
+ audioBufferLock.lock()
+ audioBuffer.append(data)
+
+ if audioBuffer.count >= audioBufferSize {
+ let chunk = audioBuffer
+ audioBuffer = Data()
+ audioBufferLock.unlock()
+ sendAudioChunk(chunk)
+ } else {
+ audioBufferLock.unlock()
+ }
+ }
+
+ /// Flush any remaining audio in the buffer
+ private func flushAudioBuffer() {
+ audioBufferLock.lock()
+ let chunk = audioBuffer
+ audioBuffer = Data()
+ audioBufferLock.unlock()
+
+ if !chunk.isEmpty {
+ sendAudioChunk(chunk)
+ }
+ }
+
+ /// Actually send an audio chunk over the WebSocket
+ private func sendAudioChunk(_ data: Data) {
+ guard isConnected, let webSocketTask = webSocketTask else { return }
+
+ let message = URLSessionWebSocketTask.Message.data(data)
+ webSocketTask.send(message) { [weak self] error in
+ if let error = error {
+ logError("BackendTranscriptionService: Send error", error: error)
+ self?.handleDisconnection()
+ }
+ }
+ }
+
+ /// No-op for backend (Deepgram-specific Finalize message not needed)
+ func sendFinalize() {
+ // Backend handles segmentation server-side
+ }
+
+ /// Public keepalive for VAD gate to call during extended silence
+ func sendKeepalivePublic() {
+ sendKeepalive()
+ }
+
+ /// Check if connected
+ var connected: Bool {
+ return isConnected
+ }
+
+ // MARK: - Connection
+
+ private func connect() {
+ Task {
+ do {
+ let token = try await AuthService.shared.getIdToken()
+ let baseURL = await APIClient.shared.baseURL
+ self.connectWithToken(token, baseURL: baseURL)
+ } catch {
+ logError("BackendTranscriptionService: Failed to get auth token", error: error)
+ self.onError?(BackendTranscriptionError.notSignedIn)
+ }
+ }
+ }
+
+ private func connectWithToken(_ token: String, baseURL: String) {
+
+ // Convert http(s) to ws(s)
+ let wsBaseURL: String
+ if baseURL.hasPrefix("https://") {
+ wsBaseURL = "wss://" + baseURL.dropFirst("https://".count)
+ } else if baseURL.hasPrefix("http://") {
+ wsBaseURL = "ws://" + baseURL.dropFirst("http://".count)
+ } else {
+ wsBaseURL = "wss://" + baseURL
+ }
+
+ // Strip trailing slash before appending path
+ let cleanBase = wsBaseURL.hasSuffix("/") ? String(wsBaseURL.dropLast()) : wsBaseURL
+
+ var components = URLComponents(string: cleanBase + "/v4/listen")!
+ components.queryItems = [
+ URLQueryItem(name: "language", value: language),
+ URLQueryItem(name: "sample_rate", value: String(sampleRate)),
+ URLQueryItem(name: "codec", value: codec),
+ URLQueryItem(name: "channels", value: String(channels)),
+ URLQueryItem(name: "source", value: source),
+ URLQueryItem(name: "include_speech_profile", value: "true"),
+ URLQueryItem(name: "speaker_auto_assign", value: "enabled"),
+ URLQueryItem(name: "conversation_timeout", value: String(conversationTimeout)),
+ ]
+
+ guard let url = components.url else {
+ onError?(BackendTranscriptionError.connectionFailed(NSError(domain: "Invalid URL", code: -1)))
+ return
+ }
+
+ log("BackendTranscriptionService: Connecting to \(url.absoluteString)")
+
+ // Create URL request with Bearer auth header (same as mobile app)
+ var request = URLRequest(url: url)
+ request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
+
+ // Create URLSession and WebSocket task
+ let configuration = URLSessionConfiguration.default
+ configuration.timeoutIntervalForRequest = 30
+ configuration.timeoutIntervalForResource = 0 // No resource timeout for long-lived WebSocket
+ urlSession = URLSession(configuration: configuration)
+ webSocketTask = urlSession?.webSocketTask(with: request)
+
+ // Start the connection
+ webSocketTask?.resume()
+
+ // Start receiving messages
+ receiveMessage()
+
+ // Mark as connected after a short delay (backend doesn't send a connect confirmation)
+ DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) { [weak self] in
+ guard let self = self, self.webSocketTask?.state == .running else { return }
+ self.isConnected = true
+ self.reconnectAttempts = 0
+ self.lastDataReceivedAt = Date()
+ self.lastKeepaliveSuccessAt = Date()
+ log("BackendTranscriptionService: Connected")
+ self.startKeepalive()
+ self.startWatchdog()
+ self.onConnected?()
+ }
+ }
+
+ // MARK: - Keepalive
+
+ private func startKeepalive() {
+ keepaliveTask?.cancel()
+ keepaliveTask = Task { [weak self] in
+ while !Task.isCancelled {
+ try? await Task.sleep(nanoseconds: UInt64(self?.keepaliveInterval ?? 8.0) * 1_000_000_000)
+ guard !Task.isCancelled, let self = self, self.isConnected else { break }
+ self.sendKeepalive()
+ }
+ }
+ }
+
+ private func sendKeepalive() {
+ guard isConnected, let webSocketTask = webSocketTask else { return }
+
+ // Send a small chunk of silence as keepalive (2 bytes of zero = 1 silent sample)
+ let silence = Data(repeating: 0, count: 2)
+ let message = URLSessionWebSocketTask.Message.data(silence)
+ webSocketTask.send(message) { [weak self] error in
+ if let error = error {
+ logError("BackendTranscriptionService: Keepalive error", error: error)
+ self?.handleDisconnection()
+ } else {
+ self?.lastKeepaliveSuccessAt = Date()
+ }
+ }
+ }
+
+ // MARK: - Watchdog
+
+ private func startWatchdog() {
+ watchdogTask?.cancel()
+ watchdogTask = Task { [weak self] in
+ while !Task.isCancelled {
+ try? await Task.sleep(nanoseconds: UInt64(self?.watchdogInterval ?? 30.0) * 1_000_000_000)
+ guard !Task.isCancelled, let self = self, self.isConnected else { break }
+
+ if let lastData = self.lastDataReceivedAt,
+ Date().timeIntervalSince(lastData) > self.staleThreshold {
+ if let lastKeepalive = self.lastKeepaliveSuccessAt,
+ Date().timeIntervalSince(lastKeepalive) < self.staleThreshold {
+ continue
+ }
+ log("BackendTranscriptionService: Watchdog detected stale connection — forcing reconnect")
+ self.handleDisconnection()
+ }
+ }
+ }
+ }
+
+ // MARK: - Disconnect / Reconnect
+
+ private func disconnect() {
+ isConnected = false
+ keepaliveTask?.cancel()
+ keepaliveTask = nil
+ watchdogTask?.cancel()
+ watchdogTask = nil
+ webSocketTask?.cancel(with: .normalClosure, reason: nil)
+ webSocketTask = nil
+ urlSession?.invalidateAndCancel()
+ urlSession = nil
+ log("BackendTranscriptionService: Disconnected")
+ onDisconnected?()
+ }
+
+ private func handleDisconnection() {
+ guard isConnected else { return }
+
+ isConnected = false
+ keepaliveTask?.cancel()
+ keepaliveTask = nil
+ watchdogTask?.cancel()
+ watchdogTask = nil
+ webSocketTask = nil
+ urlSession?.invalidateAndCancel()
+ urlSession = nil
+ onDisconnected?()
+
+ if shouldReconnect && reconnectAttempts < maxReconnectAttempts {
+ reconnectAttempts += 1
+ let delay = min(pow(2.0, Double(reconnectAttempts)), 32.0)
+ log("BackendTranscriptionService: Reconnecting in \(delay)s (attempt \(reconnectAttempts))")
+
+ reconnectTask = Task {
+ try? await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000))
+ guard !Task.isCancelled, self.shouldReconnect else { return }
+ self.connect()
+ }
+ } else if reconnectAttempts >= maxReconnectAttempts {
+ log("BackendTranscriptionService: Max reconnect attempts reached")
+ onError?(BackendTranscriptionError.webSocketError("Max reconnect attempts reached"))
+ }
+ }
+
+ // MARK: - Message Handling
+
+ private func receiveMessage() {
+ webSocketTask?.receive { [weak self] result in
+ guard let self = self else { return }
+
+ switch result {
+ case .success(let message):
+ self.handleMessage(message)
+ self.receiveMessage()
+
+ case .failure(let error):
+ guard self.isConnected else { return }
+ logError("BackendTranscriptionService: Receive error", error: error)
+ self.handleDisconnection()
+ }
+ }
+ }
+
+ private func handleMessage(_ message: URLSessionWebSocketTask.Message) {
+ lastDataReceivedAt = Date()
+
+ switch message {
+ case .string(let text):
+ parseResponse(text)
+ case .data(let data):
+ if let text = String(data: data, encoding: .utf8) {
+ parseResponse(text)
+ }
+ @unknown default:
+ break
+ }
+ }
+
+ private func parseResponse(_ text: String) {
+ // Handle heartbeat ping from backend
+ if text == "ping" {
+ return
+ }
+
+ guard let data = text.data(using: .utf8) else { return }
+
+ // Try parsing as array of transcript segments (main response format)
+ if let segments = try? JSONDecoder().decode([BackendSegment].self, from: data) {
+ for segment in segments {
+ // Map backend is_user to channel index:
+ // is_user=true → channelIndex=0 (mic/user)
+ // is_user=false → channelIndex=1 (system/others)
+ let channelIndex = segment.is_user ? 0 : 1
+
+ let transcriptSegment = TranscriptSegment(
+ text: segment.text,
+ isFinal: true,
+ speechFinal: true,
+ confidence: 1.0,
+ words: [TranscriptSegment.Word(
+ word: segment.text,
+ start: segment.start,
+ end: segment.end,
+ confidence: 1.0,
+ speaker: segment.speaker_id,
+ punctuatedWord: segment.text
+ )],
+ channelIndex: channelIndex
+ )
+ onTranscript?(transcriptSegment)
+ }
+ return
+ }
+
+ // Try parsing as event object (memory_created, service_status, etc.)
+ if let event = try? JSONDecoder().decode(BackendEvent.self, from: data) {
+ switch event.type {
+ case "memory_created":
+ log("BackendTranscriptionService: Memory created")
+ case "service_status":
+ log("BackendTranscriptionService: Service status: \(event.status ?? "unknown")")
+ default:
+ log("BackendTranscriptionService: Event: \(event.type)")
+ }
+ return
+ }
+
+ // Unknown message — log for debugging
+ log("BackendTranscriptionService: Unknown message: \(text.prefix(200))")
+ }
+}
+
+// MARK: - Backend Response Models
+
+/// Transcript segment from the OMI backend
+private struct BackendSegment: Decodable {
+ let text: String
+ let speaker: String?
+ let speaker_id: Int?
+ let is_user: Bool
+ let start: Double
+ let end: Double
+ let person_id: String?
+}
+
+/// Event message from the OMI backend
+private struct BackendEvent: Decodable {
+ let type: String
+ let status: String?
+
+ enum CodingKeys: String, CodingKey {
+ case type
+ case status
+ }
+
+ init(from decoder: Decoder) throws {
+ let container = try decoder.container(keyedBy: CodingKeys.self)
+ type = try container.decode(String.self, forKey: .type)
+ status = try container.decodeIfPresent(String.self, forKey: .status)
+ }
+}
diff --git a/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift b/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift
index aee2f956d7..4156578928 100644
--- a/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift
+++ b/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift
@@ -34,7 +34,7 @@ class PushToTalkManager: ObservableObject {
private let doubleTapThreshold: TimeInterval = 0.4
// Transcription
- private var transcriptionService: TranscriptionService?
+ private var transcriptionService: BackendTranscriptionService?
private var audioCaptureService: AudioCaptureService?
private var transcriptSegments: [String] = []
private var lastInterimText: String = ""
@@ -302,58 +302,20 @@ class PushToTalkManager: ObservableObject {
sound?.play()
}
- let isBatchMode = ShortcutSettings.shared.pttTranscriptionMode == .batch
+ // Flush remaining audio and wait for final transcript from backend
+ transcriptionService?.finishStream()
+ log("PushToTalkManager: finalizing — mic stopped, waiting for final transcript")
- if isBatchMode {
- // Batch mode: send accumulated audio to pre-recorded API
- log("PushToTalkManager: finalizing (batch) — mic stopped, transcribing recorded audio")
- batchAudioLock.lock()
- let audioData = batchAudioBuffer
- batchAudioBuffer = Data()
- batchAudioLock.unlock()
-
- // Stop streaming service (was not used in batch mode, but clean up)
- stopAudioTranscription()
-
- guard !audioData.isEmpty else {
- log("PushToTalkManager: batch mode — no audio recorded")
- sendTranscript()
- return
- }
-
- barState?.voiceTranscript = "Transcribing..."
-
- Task {
- do {
- let language = AssistantSettings.shared.effectiveTranscriptionLanguage
- let transcript = try await TranscriptionService.batchTranscribe(
- audioData: audioData,
- language: language
- )
- if let transcript, !transcript.isEmpty {
- self.transcriptSegments = [transcript]
- }
- } catch {
- logError("PushToTalkManager: batch transcription failed", error: error)
- }
+ // Safety timeout: if backend doesn't send a final segment within 3s, send what we have
+ let timeout = DispatchWorkItem { [weak self] in
+ Task { @MainActor in
+ guard let self, self.state == .finalizing else { return }
+ log("PushToTalkManager: finalization timeout — sending transcript")
self.sendTranscript()
}
- } else {
- // Live mode: flush remaining audio and wait for final transcript from Deepgram
- transcriptionService?.finishStream()
- log("PushToTalkManager: finalizing (live) — mic stopped, waiting for final transcript")
-
- // Safety timeout: if Deepgram doesn't send a final segment within 3s, send what we have
- let timeout = DispatchWorkItem { [weak self] in
- Task { @MainActor in
- guard let self, self.state == .finalizing else { return }
- log("PushToTalkManager: live finalization timeout — sending transcript")
- self.sendTranscript()
- }
- }
- liveFinalizationTimeout = timeout
- DispatchQueue.main.asyncAfter(deadline: .now() + 3.0, execute: timeout)
}
+ liveFinalizationTimeout = timeout
+ DispatchQueue.main.asyncAfter(deadline: .now() + 3.0, execute: timeout)
}
private func sendTranscript() {
@@ -421,50 +383,34 @@ class PushToTalkManager: ObservableObject {
return
}
- let isBatchMode = ShortcutSettings.shared.pttTranscriptionMode == .batch
+ // Always use live streaming through the backend (no client-side batch mode)
+ startMicCapture()
- if isBatchMode {
- // Batch mode: just capture audio into buffer, no streaming connection
- batchAudioLock.lock()
- batchAudioBuffer = Data()
- batchAudioLock.unlock()
- startMicCapture(batchMode: true)
- log("PushToTalkManager: started audio capture (batch mode)")
- } else {
- // Live mode: start mic capture and stream to Deepgram
- startMicCapture()
+ let language = AssistantSettings.shared.effectiveTranscriptionLanguage
+ let service = BackendTranscriptionService(language: language)
+ transcriptionService = service
- do {
- let language = AssistantSettings.shared.effectiveTranscriptionLanguage
- let service = try TranscriptionService(language: language, channels: 1)
- transcriptionService = service
-
- service.start(
- onTranscript: { [weak self] segment in
- Task { @MainActor in
- self?.handleTranscript(segment)
- }
- },
- onError: { [weak self] error in
- Task { @MainActor in
- logError("PushToTalkManager: transcription error", error: error)
- self?.stopListening()
- }
- },
- onConnected: {
- Task { @MainActor in
- log("PushToTalkManager: DeepGram connected")
- }
- }
- )
- } catch {
- logError("PushToTalkManager: failed to create TranscriptionService", error: error)
- stopListening()
+ service.start(
+ onTranscript: { [weak self] segment in
+ Task { @MainActor in
+ self?.handleTranscript(segment)
+ }
+ },
+ onError: { [weak self] error in
+ Task { @MainActor in
+ logError("PushToTalkManager: transcription error", error: error)
+ self?.stopListening()
+ }
+ },
+ onConnected: {
+ Task { @MainActor in
+ log("PushToTalkManager: backend connected")
+ }
}
- }
+ )
}
- private func startMicCapture(batchMode: Bool = false) {
+ private func startMicCapture() {
if audioCaptureService == nil {
audioCaptureService = AudioCaptureService()
}
@@ -475,20 +421,12 @@ class PushToTalkManager: ObservableObject {
do {
try await capture.startCapture(
onAudioChunk: { [weak self] audioData in
- guard let self else { return }
- if batchMode {
- // Batch mode: accumulate audio in buffer
- self.batchAudioLock.lock()
- self.batchAudioBuffer.append(audioData)
- self.batchAudioLock.unlock()
- } else {
- // Live mode: stream to Deepgram
- self.transcriptionService?.sendAudio(audioData)
- }
+ // Stream mono audio to backend
+ self?.transcriptionService?.sendAudio(audioData)
},
onAudioLevel: { _ in }
)
- log("PushToTalkManager: mic capture started (batch=\(batchMode))")
+ log("PushToTalkManager: mic capture started")
} catch {
logError("PushToTalkManager: mic capture failed", error: error)
self.stopListening()
diff --git a/desktop/Desktop/Sources/GoogleService-Info-Dev.plist b/desktop/Desktop/Sources/GoogleService-Info-Dev.plist
index 9602a49423..c761a117c6 100644
--- a/desktop/Desktop/Sources/GoogleService-Info-Dev.plist
+++ b/desktop/Desktop/Sources/GoogleService-Info-Dev.plist
@@ -9,17 +9,17 @@
ANDROID_CLIENT_ID
208440318997-1ek8tj5oa9ljmnh8tgehk27nqpivivbf.apps.googleusercontent.com
API_KEY
- AIzaSyD9dzBdglc7IO9pPDIOvqnCoTis_xKkkC8
+ AIzaSyBK-G7KmEoC72mR10gmQyb2NFBbZyDvcqM
GCM_SENDER_ID
- 208440318997
+ 1031333818730
PLIST_VERSION
1
BUNDLE_ID
com.omi.desktop-dev
PROJECT_ID
- based-hardware
+ based-hardware-dev
STORAGE_BUCKET
- based-hardware.firebasestorage.app
+ based-hardware-dev.firebasestorage.app
IS_ADS_ENABLED
IS_ANALYTICS_ENABLED
@@ -31,6 +31,6 @@
IS_SIGNIN_ENABLED
GOOGLE_APP_ID
- 1:208440318997:ios:a1906bb92fe244810e421c
+ 1:1031333818730:ios:3bea63d8e4f41dbfafb513
\ No newline at end of file
diff --git a/desktop/Desktop/Sources/LiveNotes/LiveNotesMonitor.swift b/desktop/Desktop/Sources/LiveNotes/LiveNotesMonitor.swift
index f859f973a4..fdcc7944a9 100644
--- a/desktop/Desktop/Sources/LiveNotes/LiveNotesMonitor.swift
+++ b/desktop/Desktop/Sources/LiveNotes/LiveNotesMonitor.swift
@@ -45,23 +45,12 @@ class LiveNotesMonitor: ObservableObject {
/// Existing notes for context (to avoid repetition)
private var existingNotesContext: [String] = []
- /// GeminiClient for AI generation (lazily initialized)
- private var geminiClient: GeminiClient?
+ /// Backend service for AI generation (injected via configure())
+ private var backendService: BackendProactiveService?
/// Cancellables for subscriptions
private var cancellables = Set()
- /// AI prompt for note generation (from m13v/meeting)
- private let noteGenerationPrompt = """
- generate a single, concise note about what happened in this segment.
- be factual and specific.
- focus on the key point or action item.
- keep it a few word sentence.
- do not use quotes.
- do not use wrapping words like "discussion on", jump straight into note.
- avoid repeating information from existing notes.
- """
-
private init() {
// Subscribe to transcript changes
LiveTranscriptMonitor.shared.$segments
@@ -72,6 +61,12 @@ class LiveNotesMonitor: ObservableObject {
.store(in: &cancellables)
}
+ /// Configure with backend service (call before startSession)
+ func configure(backendService: BackendProactiveService) {
+ self.backendService = backendService
+ log("LiveNotesMonitor: Configured with BackendProactiveService")
+ }
+
// MARK: - Session Lifecycle
/// Start a new notes session
@@ -85,15 +80,8 @@ class LiveNotesMonitor: ObservableObject {
lastProcessedSegmentEnd = nil
existingNotesContext = []
- // Initialize Gemini client if not already done
- if geminiClient == nil {
- do {
- // Use Gemini 3 Pro for better note generation quality
- geminiClient = try GeminiClient(model: "gemini-pro-latest")
- log("LiveNotesMonitor: GeminiClient initialized with gemini-pro-latest")
- } catch {
- logError("LiveNotesMonitor: Failed to initialize GeminiClient", error: error)
- }
+ if backendService == nil {
+ log("LiveNotesMonitor: WARNING — backendService not configured, AI notes disabled")
}
// Load any existing notes from DB (for crash recovery)
@@ -252,10 +240,10 @@ class LiveNotesMonitor: ObservableObject {
}
}
- /// Generate an AI note from recent transcript
+ /// Generate an AI note from recent transcript via backend
private func generateNote(from segments: [SpeakerSegment]) {
guard let sessionId = currentSessionId,
- let client = geminiClient,
+ let service = backendService,
!isGenerating else { return }
isGenerating = true
@@ -265,34 +253,21 @@ class LiveNotesMonitor: ObservableObject {
let segmentStartOrder = max(0, currentSegmentOrder - 3)
let segmentEndOrder = currentSegmentOrder
- // Build context from existing notes
- let existingNotesText = existingNotesContext.isEmpty
- ? "No existing notes yet."
+ // Build session context from existing notes
+ let sessionContext = existingNotesContext.isEmpty
+ ? ""
: "Existing notes:\n" + existingNotesContext.map { "- \($0)" }.joined(separator: "\n")
- let prompt = """
- Transcript segment:
- \(recentText)
-
- \(existingNotesText)
-
- \(noteGenerationPrompt)
- """
-
Task {
do {
- let response = try await client.sendTextRequest(
- prompt: prompt,
- systemPrompt: "You are a concise note-taker. Generate a single short note (3-10 words) about the key point in the transcript. Do not use quotes. Be direct and specific."
- )
+ let noteText = try await service.generateLiveNote(text: recentText, sessionContext: sessionContext)
- // Clean up the response
- let noteText = response
+ let cleaned = noteText
.trimmingCharacters(in: .whitespacesAndNewlines)
.replacingOccurrences(of: "\"", with: "")
.replacingOccurrences(of: "'", with: "")
- guard !noteText.isEmpty else {
+ guard !cleaned.isEmpty else {
await MainActor.run { self.isGenerating = false }
return
}
@@ -300,7 +275,7 @@ class LiveNotesMonitor: ObservableObject {
// Save to DB
let record = try await NoteStorage.shared.createNote(
sessionId: sessionId,
- text: noteText,
+ text: cleaned,
isAiGenerated: true,
segmentStartOrder: segmentStartOrder,
segmentEndOrder: segmentEndOrder
@@ -309,8 +284,7 @@ class LiveNotesMonitor: ObservableObject {
if let note = record.toLiveNote() {
await MainActor.run {
self.notes.append(note)
- self.existingNotesContext.append(noteText)
- // Trim context to prevent unbounded growth (keep most recent notes)
+ self.existingNotesContext.append(cleaned)
if self.existingNotesContext.count > self.maxExistingNotesContext {
self.existingNotesContext.removeFirst(self.existingNotesContext.count - self.maxExistingNotesContext)
}
@@ -321,7 +295,6 @@ class LiveNotesMonitor: ObservableObject {
await MainActor.run { self.isGenerating = false }
}
} catch let dbError as DatabaseError where dbError.resultCode == .SQLITE_CONSTRAINT {
- // Session was deleted during async AI generation — not an error
log("LiveNotesMonitor: Session \(sessionId) deleted during note generation, skipping")
await MainActor.run { self.isGenerating = false }
} catch {
diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Advice/AdviceAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Advice/AdviceAssistant.swift
index 87ca8c4b0a..e1c2701817 100644
--- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Advice/AdviceAssistant.swift
+++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Advice/AdviceAssistant.swift
@@ -1,4 +1,3 @@
-import AppKit
import Foundation
import GRDB
@@ -19,7 +18,7 @@ actor AdviceAssistant: ProactiveAssistant {
// MARK: - Properties
- private let geminiClient: GeminiClient
+ private let backendService: BackendProactiveService
private var isRunning = false
private var lastAnalysisTime: Date = .distantPast
private var previousAdvice: [ExtractedAdvice] = [] // Dedup window for advice context
@@ -33,15 +32,6 @@ actor AdviceAssistant: ProactiveAssistant {
private let frameSignal: AsyncStream
private let frameSignalContinuation: AsyncStream.Continuation
- /// Get the current system prompt from settings (accessed on MainActor for thread safety)
- private var systemPrompt: String {
- get async {
- await MainActor.run {
- AdviceAssistantSettings.shared.analysisPrompt
- }
- }
- }
-
/// Get the extraction interval from settings
private var extractionInterval: TimeInterval {
get async {
@@ -62,9 +52,8 @@ actor AdviceAssistant: ProactiveAssistant {
// MARK: - Initialization
- init(apiKey: String? = nil) throws {
- // Use Gemini 3.1 Pro for better advice quality (3-pro-preview retires March 9, 2026)
- self.geminiClient = try GeminiClient(apiKey: apiKey, model: "gemini-pro-latest")
+ init(backendService: BackendProactiveService) {
+ self.backendService = backendService
let (stream, continuation) = AsyncStream.makeStream(of: Void.self, bufferingPolicy: .bufferingNewest(1))
self.frameSignal = stream
@@ -140,25 +129,6 @@ actor AdviceAssistant: ProactiveAssistant {
log("Advice assistant stopped")
}
- // MARK: - Test Analysis (for test runner)
-
- /// Run the extraction pipeline on arbitrary JPEG data without side effects (no saving, no events).
- /// Used by the test runner to replay past screenshots.
- /// `screenshotTime` anchors the activity summary to the screenshot's actual timestamp.
- /// Returns (result, sqlQueryCount) where sqlQueryCount is the number of execute_sql tool calls made.
- func testAnalyze(jpegData: Data, appName: String, windowTitle: String? = nil, screenshotTime: Date) async throws -> (AdviceExtractionResult?, Int) {
- let interval = await extractionInterval
- let lookbackStart = screenshotTime.addingTimeInterval(-interval)
- return try await runAdviceExtraction(
- jpegData: nil,
- appName: appName,
- windowTitle: windowTitle,
- referenceTime: screenshotTime,
- lookbackStart: lookbackStart,
- trackSqlCount: true
- )
- }
-
// MARK: - ProactiveAssistant Protocol Methods
func shouldAnalyze(frameNumber: Int, timeSinceLastAnalysis: TimeInterval) -> Bool {
@@ -379,62 +349,34 @@ actor AdviceAssistant: ProactiveAssistant {
pendingFrame = nil
}
- // MARK: - Image Processing
-
- /// Resize and compress an image for Gemini analysis (max 1280px wide, JPEG quality 0.4)
- private static func compressForGemini(_ data: Data) -> Data? {
- guard let source = CGImageSourceCreateWithData(data as CFData, nil),
- let cgImage = CGImageSourceCreateImageAtIndex(source, 0, nil) else { return nil }
-
- let maxWidth = 1280
- let width = cgImage.width
- let height = cgImage.height
- let scale = width > maxWidth ? Double(maxWidth) / Double(width) : 1.0
- let newWidth = Int(Double(width) * scale)
- let newHeight = Int(Double(height) * scale)
-
- guard let context = CGContext(
- data: nil, width: newWidth, height: newHeight,
- bitsPerComponent: 8, bytesPerRow: 0,
- space: CGColorSpaceCreateDeviceRGB(),
- bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue
- ) else { return nil }
-
- context.interpolationQuality = .high
- context.draw(cgImage, in: CGRect(x: 0, y: 0, width: newWidth, height: newHeight))
-
- guard let resized = context.makeImage() else { return nil }
-
- let mutableData = NSMutableData()
- guard let dest = CGImageDestinationCreateWithData(mutableData as CFMutableData, "public.jpeg" as CFString, 1, nil) else { return nil }
- CGImageDestinationAddImage(dest, resized, [kCGImageDestinationLossyCompressionQuality: 0.4] as CFDictionary)
- guard CGImageDestinationFinalize(dest) else { return nil }
- return mutableData as Data
- }
-
- // MARK: - Helpers
+ // MARK: - Test Analysis (for test runner)
- /// Get user's preferred language, cached for 1 hour
- private func getUserLanguage() async -> String? {
- // Return cached value if fresh (< 1 hour)
- if let cached = cachedLanguage, Date().timeIntervalSince(languageFetchedAt) < 3600 {
- return cached
+ /// Run extraction via backend for test runner. Returns (result, 0) for compatibility.
+ func testAnalyze(jpegData: Data, appName: String, windowTitle: String? = nil, screenshotTime: Date) async throws -> (AdviceExtractionResult?, Int) {
+ let base64 = autoreleasepool { jpegData.base64EncodedString() }
+ let backendResult = try await backendService.generateAdvice(
+ imageBase64: base64, appName: appName, windowTitle: windowTitle ?? ""
+ )
+ guard let adviceDict = backendResult.advice as? [String: Any] else {
+ return (AdviceExtractionResult(hasAdvice: false, advice: nil, contextSummary: "Analyzed \(appName)", currentActivity: ""), 0)
}
-
- do {
- let response = try await APIClient.shared.getUserLanguage()
- let lang = response.language
- cachedLanguage = lang
- languageFetchedAt = Date()
- return lang.isEmpty ? nil : lang
- } catch {
- // Fall back to transcription language setting
- let fallback = await MainActor.run { AssistantSettings.shared.transcriptionLanguage }
- return fallback.isEmpty || fallback == "en" ? nil : fallback
+ let hasAdvice = adviceDict["has_advice"] as? Bool ?? !adviceDict.isEmpty
+ guard hasAdvice, let adviceText = adviceDict["content"] as? String ?? adviceDict["advice"] as? String, !adviceText.isEmpty else {
+ return (AdviceExtractionResult(hasAdvice: false, advice: nil, contextSummary: "Analyzed \(appName)", currentActivity: ""), 0)
}
+ let categoryStr = adviceDict["category"] as? String ?? "other"
+ let category = AdviceCategory(rawValue: categoryStr) ?? .other
+ let confidence = adviceDict["confidence"] as? Double ?? 0.5
+ let advice = ExtractedAdvice(
+ advice: adviceText, headline: adviceDict["headline"] as? String,
+ reasoning: adviceDict["reasoning"] as? String, category: category,
+ sourceApp: appName, confidence: confidence
+ )
+ let result = AdviceExtractionResult(hasAdvice: true, advice: advice, contextSummary: "Analyzed \(appName)", currentActivity: "")
+ return (result, 0)
}
- // MARK: - Analysis
+ // MARK: - Backend Analysis (Phase 2 thin client)
private func processFrame(_ frame: CapturedFrame) async {
guard await isEnabled else { return }
@@ -443,7 +385,6 @@ actor AdviceAssistant: ProactiveAssistant {
return
}
- // Handle the result with screenshot ID for SQLite storage
await handleResultWithScreenshot(result, screenshotId: frame.screenshotId, windowTitle: frame.windowTitle) { type, data in
Task { @MainActor in
AssistantCoordinator.shared.sendEvent(type: type, data: data)
@@ -455,549 +396,68 @@ actor AdviceAssistant: ProactiveAssistant {
}
private func extractAdvice(from frame: CapturedFrame) async throws -> AdviceExtractionResult? {
- let now = Date()
- // Cap lookback: since last analysis or max 1 hour ago
- let lookbackStart = max(lastAnalysisTime, now.addingTimeInterval(-3600))
- let (result, _) = try await runAdviceExtraction(
- jpegData: nil,
+ let base64 = autoreleasepool { frame.jpegData.base64EncodedString() }
+ let backendResult = try await backendService.generateAdvice(
+ imageBase64: base64,
appName: frame.appName,
- windowTitle: frame.windowTitle,
- referenceTime: now,
- lookbackStart: lookbackStart,
- trackSqlCount: false
+ windowTitle: frame.windowTitle ?? ""
)
- return result
- }
- // MARK: - Core Extraction (shared by production + test)
-
- /// Two-phase advice extraction:
- /// Phase 1 (text-only): Activity summary + SQL investigation loop. Model investigates via
- /// execute_sql, then calls `request_screenshot` with an ID and its findings so far.
- /// Phase 2 (single vision call): Load the chosen screenshot + Phase 1 findings → single
- /// Gemini call with image → provide_advice or no_advice.
- /// Returns (result, sqlQueryCount).
- private func runAdviceExtraction(
- jpegData: Data?,
- appName: String,
- windowTitle: String?,
- referenceTime: Date,
- lookbackStart: Date,
- trackSqlCount: Bool
- ) async throws -> (AdviceExtractionResult?, Int) {
- var sqlCount = 0
-
- // Build prompt with current context
- let timeFormatter = DateFormatter()
- timeFormatter.dateFormat = "h:mm a, EEEE"
- var prompt = "CURRENT APP: \(appName)."
- if let windowTitle = windowTitle, !windowTitle.isEmpty {
- prompt += " Window: \"\(windowTitle)\"."
- }
- prompt += " Time: \(timeFormatter.string(from: referenceTime))."
-
- // Add activity summary from database, anchored to the reference time
- let elapsed = referenceTime.timeIntervalSince(lookbackStart)
- log("Advice: Activity lookback: \(String(format: "%.0f", elapsed))s (\(lookbackStart) to \(referenceTime))")
- let activitySummary = await buildActivitySummary(from: lookbackStart, to: referenceTime)
- if !activitySummary.isEmpty {
- prompt += "\n\n" + activitySummary
- log("Advice: --- ACTIVITY SUMMARY ---\n\(activitySummary)")
- } else {
- log("Advice: --- ACTIVITY SUMMARY --- (empty, no screenshots in range)")
- }
-
- // Add user profile for context
- if let profile = await AIUserProfileService.shared.getLatestProfile() {
- prompt += "\n\nUSER PROFILE (who this user is):\n"
- prompt += profile.profileText + "\n"
- }
-
- // Add previous advice for dedup
- if !previousAdvice.isEmpty {
- prompt += "\n\nPREVIOUSLY PROVIDED ADVICE (do not repeat these or semantically similar):\n"
- let adviceToInclude = previousAdvice.prefix(maxAdviceInPrompt)
- for (index, advice) in adviceToInclude.enumerated() {
- prompt += "\(index + 1). \(advice.advice)"
- if let reasoning = advice.reasoning {
- prompt += " (Reasoning: \(reasoning))"
- }
- prompt += "\n"
- }
- prompt += "\nOnly provide advice if there's a genuinely NEW non-obvious insight not covered above."
- } else {
- prompt += "\n\nOnly provide advice if there's something specific and non-obvious that would help."
- }
-
- prompt += "\n\nInvestigate the activity summary. Scan OCR from the TOP 3-5 apps (not just the dominant one) — the best insights often come from browsers, communication apps, and notes, not just the app with the most screenshots. Skip apps with < 10 screenshots. When you've identified the most interesting screenshot, call request_screenshot with the ID and your findings. Or call no_advice if nothing qualifies."
-
- log("Advice: --- PROMPT ---\n\(prompt)")
-
- // Build system prompt
- var currentSystemPrompt = await systemPrompt
- if let language = await getUserLanguage(), language != "en" {
- currentSystemPrompt += "\n\nIMPORTANT: Respond in the user's preferred language: \(language)"
- }
- currentSystemPrompt += "\n\nDATABASE SCHEMA for execute_sql:\nscreenshots table columns: id INTEGER, timestamp TEXT, appName TEXT, windowTitle TEXT, ocrText TEXT, focusStatus TEXT"
-
- // =============================================
- // PHASE 1: Text-only investigation loop
- // =============================================
-
- let phase1Tools = buildPhase1Tools()
- var contents: [GeminiImageToolRequest.Content] = [
- GeminiImageToolRequest.Content(
- role: "user",
- parts: [GeminiImageToolRequest.Part(text: prompt)]
+ // Parse backend response into AdviceExtractionResult
+ guard let adviceDict = backendResult.advice as? [String: Any] else {
+ return AdviceExtractionResult(
+ hasAdvice: false,
+ advice: nil,
+ contextSummary: "Analyzed \(frame.appName)",
+ currentActivity: ""
)
- ]
-
- let client = self.geminiClient
- var chosenScreenshotId: Int64?
- var investigationFindings: String?
-
- for iteration in 0..<7 {
- let iterContents = contents
- let iterSystemPrompt = currentSystemPrompt
- let iterTools = [phase1Tools]
- let iterForce = iteration == 0
- let result: ToolChatResult
- do {
- result = try await withThrowingTimeout(seconds: 120) {
- try await client.sendImageToolLoop(
- contents: iterContents,
- systemPrompt: iterSystemPrompt,
- tools: iterTools,
- forceToolCall: iterForce
- )
- }
- } catch {
- log("Advice: Phase 1 failed on iteration \(iteration): \(error.localizedDescription)")
- throw error
- }
-
- guard let toolCall = result.toolCalls.first else {
- log("Advice: Phase 1 — no tool call on iteration \(iteration), breaking")
- break
- }
-
- switch toolCall.name {
- case "execute_sql":
- let query = toolCall.arguments["query"] as? String ?? ""
- sqlCount += 1
- log("Advice: P1 execute_sql iter \(iteration): \(query)")
- let sqlToolCall = ToolCall(name: "execute_sql", arguments: ["query": query], thoughtSignature: nil)
- let resultStr = await ChatToolExecutor.execute(sqlToolCall)
- let truncated = resultStr.count > 2000 ? String(resultStr.prefix(2000)) + "... (truncated)" : resultStr
- log("Advice: P1 sql result (\(resultStr.count) chars): \(truncated)")
-
- contents.append(GeminiImageToolRequest.Content(
- role: "model",
- parts: [GeminiImageToolRequest.Part(
- functionCall: .init(name: toolCall.name, args: ["query": query]),
- thoughtSignature: toolCall.thoughtSignature
- )]
- ))
- contents.append(GeminiImageToolRequest.Content(
- role: "user",
- parts: [GeminiImageToolRequest.Part(functionResponse: .init(
- name: toolCall.name,
- response: .init(result: resultStr)
- ))]
- ))
- continue
-
- case "request_screenshot":
- let findings = toolCall.arguments["findings"] as? String ?? ""
- investigationFindings = findings
- if let idInt = toolCall.arguments["screenshot_id"] as? Int {
- chosenScreenshotId = Int64(idInt)
- } else if let idInt64 = toolCall.arguments["screenshot_id"] as? Int64 {
- chosenScreenshotId = idInt64
- } else if let idStr = toolCall.arguments["screenshot_id"] as? String, let parsed = Int64(idStr) {
- chosenScreenshotId = parsed
- } else if let idDouble = toolCall.arguments["screenshot_id"] as? Double {
- chosenScreenshotId = Int64(idDouble)
- }
- log("Advice: P1 request_screenshot iter \(iteration): id=\(chosenScreenshotId ?? 0), findings=\(findings.prefix(200))")
- break // Exit phase 1
-
- case "no_advice":
- let contextSummary = toolCall.arguments["context_summary"] as? String ?? "No context"
- let currentActivity = toolCall.arguments["current_activity"] as? String ?? "Unknown"
- log("Advice: P1 no_advice — \(contextSummary)")
- return (AdviceExtractionResult(
- hasAdvice: false,
- advice: nil,
- contextSummary: contextSummary,
- currentActivity: currentActivity
- ), sqlCount)
-
- default:
- log("Advice: P1 unknown tool: \(toolCall.name), breaking")
- break
- }
-
- // Break out of loop if request_screenshot was called
- if chosenScreenshotId != nil { break }
- }
-
- // If Phase 1 exhausted without choosing a screenshot, no advice
- guard let screenshotId = chosenScreenshotId, let findings = investigationFindings else {
- log("Advice: Phase 1 exhausted without request_screenshot")
- return (nil, sqlCount)
}
- // =============================================
- // PHASE 2: Single vision call with chosen screenshot
- // =============================================
-
- log("Advice: Phase 2 — loading screenshot \(screenshotId)")
-
- // Load the screenshot image
- let imageData: Data
- do {
- guard let screenshot = try await RewindDatabase.shared.getScreenshot(id: screenshotId) else {
- log("Advice: P2 screenshot not in DB: \(screenshotId)")
- return (nil, sqlCount)
- }
- // Check active chunk
- if screenshot.usesVideoStorage, let chunk = screenshot.videoChunkPath {
- let activeChunk = await VideoChunkEncoder.shared.currentChunkPath
- if chunk == activeChunk {
- log("Advice: P2 screenshot is in active chunk, skipping")
- return (nil, sqlCount)
- }
- }
- let rawData = try await RewindStorage.shared.loadScreenshotData(for: screenshot)
- imageData = Self.compressForGemini(rawData) ?? rawData
- log("Advice: P2 loaded \(imageData.count) bytes (\(rawData.count) raw) from \(screenshot.appName)")
- } catch {
- log("Advice: P2 screenshot load failed: \(error.localizedDescription)")
- return (nil, sqlCount)
- }
-
- // Build Phase 2 prompt — compact findings + image + cross-reference instruction
- let phase2Prompt = """
- INVESTIGATION FINDINGS:
- \(findings)
-
- The screenshot below is from the app/window identified during investigation.
-
- Before giving advice, CROSS-REFERENCE your findings:
- - Use execute_sql to check if this issue was resolved in later screenshots
- - Check if the user moved on to something else (the issue may be stale)
- - Verify the context is still relevant by looking at nearby timestamps
-
- Then call provide_advice if the insight is still valid, or no_advice if it was resolved or is no longer relevant.
- """
-
- let phase2Tools = buildPhase2Tools()
- let base64 = imageData.base64EncodedString()
- var phase2Contents: [GeminiImageToolRequest.Content] = [
- GeminiImageToolRequest.Content(
- role: "user",
- parts: [
- GeminiImageToolRequest.Part(text: phase2Prompt),
- GeminiImageToolRequest.Part(mimeType: "image/jpeg", data: base64),
- ]
+ let hasAdvice = adviceDict["has_advice"] as? Bool ?? !adviceDict.isEmpty
+ guard hasAdvice else {
+ return AdviceExtractionResult(
+ hasAdvice: false,
+ advice: nil,
+ contextSummary: "Analyzed \(frame.appName)",
+ currentActivity: ""
)
- ]
-
- // Phase 2 loop — model can cross-reference via SQL before deciding
- for p2Iteration in 0..<5 {
- let p2Contents = phase2Contents
- let p2SystemPrompt = currentSystemPrompt
- let p2Tools = [phase2Tools]
- let p2Force = p2Iteration == 0
- let phase2Result: ToolChatResult
- do {
- phase2Result = try await withThrowingTimeout(seconds: 120) {
- try await client.sendImageToolLoop(
- contents: p2Contents,
- systemPrompt: p2SystemPrompt,
- tools: p2Tools,
- forceToolCall: p2Force
- )
- }
- } catch {
- log("Advice: Phase 2 failed on iteration \(p2Iteration): \(error.localizedDescription)")
- throw error
- }
-
- guard let toolCall = phase2Result.toolCalls.first else {
- log("Advice: Phase 2 — no tool call on iteration \(p2Iteration), breaking")
- break
- }
-
- switch toolCall.name {
- case "execute_sql":
- let query = toolCall.arguments["query"] as? String ?? ""
- sqlCount += 1
- log("Advice: P2 execute_sql iter \(p2Iteration): \(query)")
- let sqlToolCall = ToolCall(name: "execute_sql", arguments: ["query": query], thoughtSignature: nil)
- let resultStr = await ChatToolExecutor.execute(sqlToolCall)
- let truncated = resultStr.count > 2000 ? String(resultStr.prefix(2000)) + "... (truncated)" : resultStr
- log("Advice: P2 sql result (\(resultStr.count) chars): \(truncated)")
-
- phase2Contents.append(GeminiImageToolRequest.Content(
- role: "model",
- parts: [GeminiImageToolRequest.Part(
- functionCall: .init(name: toolCall.name, args: ["query": query]),
- thoughtSignature: toolCall.thoughtSignature
- )]
- ))
- phase2Contents.append(GeminiImageToolRequest.Content(
- role: "user",
- parts: [GeminiImageToolRequest.Part(functionResponse: .init(
- name: toolCall.name,
- response: .init(result: resultStr)
- ))]
- ))
- continue
-
- case "provide_advice":
- log("Advice: P2 provide_advice (after \(p2Iteration) cross-reference iterations)")
- return (parseProvideAdvice(toolCall), sqlCount)
-
- case "no_advice":
- let contextSummary = toolCall.arguments["context_summary"] as? String ?? "No context"
- let currentActivity = toolCall.arguments["current_activity"] as? String ?? "Unknown"
- log("Advice: P2 no_advice — \(contextSummary)")
- return (AdviceExtractionResult(
- hasAdvice: false,
- advice: nil,
- contextSummary: contextSummary,
- currentActivity: currentActivity
- ), sqlCount)
-
- default:
- log("Advice: P2 unexpected tool: \(toolCall.name)")
- break
- }
- break // Break on unexpected tool
- }
- return (nil, sqlCount)
- }
-
- // MARK: - Activity Summary
-
- /// Query the screenshots table to build a summary of recent activity.
- /// - `from`: lower bound (e.g. last analysis time or screenshot.timestamp - interval)
- /// - `to`: upper bound (e.g. now or the screenshot's timestamp)
- private func buildActivitySummary(from lookbackStart: Date, to referenceTime: Date) async -> String {
- guard let dbQueue = await RewindDatabase.shared.getDatabaseQueue() else {
- return ""
}
- do {
- return try await dbQueue.read { db in
- // Pass Date objects directly — GRDB encodes them as UTC strings
- // matching the stored format. Manual DateFormatter uses local timezone
- // which causes mismatches.
- let rows = try Row.fetchAll(db, sql: """
- SELECT appName, windowTitle, COUNT(*) as count,
- MIN(timestamp) as first_seen, MAX(timestamp) as last_seen
- FROM screenshots
- WHERE timestamp >= ? AND timestamp <= ?
- AND appName IS NOT NULL AND appName != ''
- GROUP BY appName, windowTitle
- ORDER BY count DESC
- LIMIT 30
- """, arguments: [lookbackStart, referenceTime])
-
- if rows.isEmpty {
- return ""
- }
-
- let totalScreenshots = rows.reduce(0) { $0 + (($1["count"] as? Int64).map(Int.init) ?? ($1["count"] as? Int) ?? 0) }
- let elapsedMin = referenceTime.timeIntervalSince(lookbackStart) / 60.0
-
- let timeOnlyFormatter = DateFormatter()
- timeOnlyFormatter.dateFormat = "HH:mm:ss"
-
- var lines: [String] = []
- lines.append("ACTIVITY SUMMARY (last \(Int(elapsedMin)) min, \(totalScreenshots) screenshots):")
- lines.append("Time range: \(timeOnlyFormatter.string(from: lookbackStart)) – \(timeOnlyFormatter.string(from: referenceTime))")
- lines.append("")
- lines.append("App | Window | Screenshots | Est. Duration")
- lines.append(String(repeating: "-", count: 60))
-
- for row in rows {
- let app = row["appName"] as? String ?? "Unknown"
- let window = row["windowTitle"] as? String ?? ""
- let count = (row["count"] as? Int64).map(Int.init) ?? (row["count"] as? Int) ?? 0
- let estMinutes = String(format: "%.1f", Double(count) / 60.0)
- let windowDisplay = window.isEmpty ? "(no title)" : String(window.prefix(50))
- lines.append("\(app) | \(windowDisplay) | \(count) | \(estMinutes) min")
- }
-
- let summary = lines.joined(separator: "\n")
- log("Advice: Activity summary (last \(Int(elapsedMin)) min, \(totalScreenshots) screenshots)")
- return summary
- }
- } catch {
- logError("Advice: Failed to build activity summary", error: error)
- return ""
+ let adviceText = adviceDict["content"] as? String ?? adviceDict["advice"] as? String ?? ""
+ guard !adviceText.isEmpty else {
+ return AdviceExtractionResult(
+ hasAdvice: false,
+ advice: nil,
+ contextSummary: "Analyzed \(frame.appName)",
+ currentActivity: ""
+ )
}
- }
- // MARK: - Tool Definitions
-
- /// Phase 1 tools: text-only investigation (execute_sql, request_screenshot, no_advice)
- private func buildPhase1Tools() -> GeminiTool {
- GeminiTool(functionDeclarations: [
- GeminiTool.FunctionDeclaration(
- name: "execute_sql",
- description: "Execute a SQL query on the local database to investigate screen activity. The screenshots table has: id INTEGER, timestamp TEXT, appName TEXT, windowTitle TEXT, ocrText TEXT, focusStatus TEXT. Use this to read OCR text from interesting windows, check what the user was doing, etc. SELECT queries only. Auto-limited to 200 rows.",
- parameters: GeminiTool.FunctionDeclaration.Parameters(
- type: "object",
- properties: [
- "query": .init(type: "string", description: "SQL SELECT query to execute on the screenshots table")
- ],
- required: ["query"]
- )
- ),
- GeminiTool.FunctionDeclaration(
- name: "request_screenshot",
- description: "Request to view a specific screenshot. Call this when you've found something interesting via SQL and want to see the actual screen. Provide the screenshot ID and a summary of your findings so far. The screenshot will be shown to you for final analysis.",
- parameters: GeminiTool.FunctionDeclaration.Parameters(
- type: "object",
- properties: [
- "screenshot_id": .init(type: "integer", description: "The screenshot ID from the screenshots table"),
- "findings": .init(type: "string", description: "Summary of what you found during investigation — what app, what OCR text caught your attention, and what you suspect might be worth advising about")
- ],
- required: ["screenshot_id", "findings"]
- )
- ),
- GeminiTool.FunctionDeclaration(
- name: "no_advice",
- description: "Call this when there is nothing worth advising about. Nothing qualifies as a specific, non-obvious insight. This ends the analysis.",
- parameters: GeminiTool.FunctionDeclaration.Parameters(
- type: "object",
- properties: [
- "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"),
- "current_activity": .init(type: "string", description: "High-level description of user's activity")
- ],
- required: ["context_summary", "current_activity"]
- )
- ),
- ])
- }
-
- /// Phase 2 tools: vision call with screenshot + SQL cross-referencing (execute_sql, provide_advice, no_advice)
- private func buildPhase2Tools() -> GeminiTool {
- GeminiTool(functionDeclarations: [
- GeminiTool.FunctionDeclaration(
- name: "execute_sql",
- description: "Cross-reference your findings by querying the database. Use this to check if an issue was resolved in later screenshots, verify context across time, or look up related activity. The screenshots table has: id INTEGER, timestamp TEXT, appName TEXT, windowTitle TEXT, ocrText TEXT, focusStatus TEXT. SELECT queries only.",
- parameters: GeminiTool.FunctionDeclaration.Parameters(
- type: "object",
- properties: [
- "query": .init(type: "string", description: "SQL SELECT query to execute on the screenshots table")
- ],
- required: ["query"]
- )
- ),
- GeminiTool.FunctionDeclaration(
- name: "provide_advice",
- description: "Call this when you have a specific, non-obvious insight for the user based on the screenshot and your investigation findings. You should cross-reference first using execute_sql to verify the issue is still relevant.",
- parameters: GeminiTool.FunctionDeclaration.Parameters(
- type: "object",
- properties: [
- "advice": .init(type: "string", description: "The advice text (1-2 sentences, max 100 chars). Start with what you noticed, then why it matters."),
- "headline": .init(type: "string", description: "Ultra-short observation (max 5 words) for notification preview. E.g. 'Draft saved in /tmp', 'Credentials visible in terminal'"),
- "reasoning": .init(type: "string", description: "Brief explanation of why this advice is relevant"),
- "category": .init(type: "string", description: "Category of advice", enumValues: ["productivity", "communication", "learning", "other"]),
- "source_app": .init(type: "string", description: "App where context was observed"),
- "confidence": .init(type: "number", description: "Confidence score 0.0-1.0. 0.90+: preventing clear mistake. 0.75-0.89: highly relevant non-obvious tip. 0.60-0.74: useful but user might know."),
- "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"),
- "current_activity": .init(type: "string", description: "High-level description of user's activity")
- ],
- required: ["advice", "headline", "category", "source_app", "confidence", "context_summary", "current_activity"]
- )
- ),
- GeminiTool.FunctionDeclaration(
- name: "no_advice",
- description: "Call this when the screenshot doesn't reveal anything worth advising about, or when cross-referencing shows the issue was already resolved.",
- parameters: GeminiTool.FunctionDeclaration.Parameters(
- type: "object",
- properties: [
- "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"),
- "current_activity": .init(type: "string", description: "High-level description of user's activity")
- ],
- required: ["context_summary", "current_activity"]
- )
- ),
- ])
- }
-
- // MARK: - Parse Tool Results
-
- /// Parse the provide_advice tool call into an AdviceExtractionResult
- private func parseProvideAdvice(_ toolCall: ToolCall) -> AdviceExtractionResult {
- let adviceText = toolCall.arguments["advice"] as? String ?? ""
- let headline = toolCall.arguments["headline"] as? String
- let reasoning = toolCall.arguments["reasoning"] as? String
- let categoryStr = toolCall.arguments["category"] as? String ?? "other"
+ let categoryStr = adviceDict["category"] as? String ?? "other"
let category = AdviceCategory(rawValue: categoryStr) ?? .other
- let sourceApp = toolCall.arguments["source_app"] as? String ?? ""
- let contextSummary = toolCall.arguments["context_summary"] as? String ?? ""
- let currentActivity = toolCall.arguments["current_activity"] as? String ?? ""
-
let confidence: Double
- if let confValue = toolCall.arguments["confidence"] as? Double {
+ if let confValue = adviceDict["confidence"] as? Double {
confidence = confValue
- } else if let confInt = toolCall.arguments["confidence"] as? Int {
+ } else if let confInt = adviceDict["confidence"] as? Int {
confidence = Double(confInt)
- } else if let confStr = toolCall.arguments["confidence"] as? String, let parsed = Double(confStr) {
- confidence = parsed
} else {
confidence = 0.5
}
let advice = ExtractedAdvice(
advice: adviceText,
- headline: headline,
- reasoning: reasoning,
+ headline: adviceDict["headline"] as? String,
+ reasoning: adviceDict["reasoning"] as? String,
category: category,
- sourceApp: sourceApp,
+ sourceApp: frame.appName,
confidence: confidence
)
- log("Advice: --- PROVIDE_ADVICE ---")
- log("Advice: advice: \(adviceText)")
- log("Advice: headline: \(headline ?? "(none)")")
- log("Advice: reasoning: \(reasoning ?? "(none)")")
- log("Advice: category: \(categoryStr)")
- log("Advice: source_app: \(sourceApp)")
- log("Advice: confidence: \(confidence)")
- log("Advice: context: \(contextSummary)")
- log("Advice: activity: \(currentActivity)")
return AdviceExtractionResult(
hasAdvice: true,
advice: advice,
- contextSummary: contextSummary,
- currentActivity: currentActivity
+ contextSummary: adviceDict["context_summary"] as? String ?? "Analyzed \(frame.appName)",
+ currentActivity: adviceDict["current_activity"] as? String ?? ""
)
}
}
-
-// MARK: - Timeout Helper
-
-/// Run an async operation with a timeout. Throws `CancellationError` if the timeout expires.
-private func withThrowingTimeout(seconds: Double, operation: @escaping @Sendable () async throws -> T) async throws -> T {
- try await withThrowingTaskGroup(of: T.self) { group in
- group.addTask {
- try await operation()
- }
- group.addTask {
- try await Task.sleep(nanoseconds: UInt64(seconds * 1_000_000_000))
- throw CancellationError()
- }
- // First task to complete wins; cancel the other
- let result = try await group.next()!
- group.cancelAll()
- return result
- }
-}
diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Focus/FocusAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Focus/FocusAssistant.swift
index 33a355f56b..88eaddfc7c 100644
--- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Focus/FocusAssistant.swift
+++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Focus/FocusAssistant.swift
@@ -17,7 +17,7 @@ actor FocusAssistant: ProactiveAssistant {
// MARK: - Properties
- private let geminiClient: GeminiClient
+ private let backendService: BackendProactiveService
private let onAlert: (String) -> Void
private let onStatusChange: ((FocusStatus) -> Void)?
private let onRefocus: (() -> Void)?
@@ -35,12 +35,6 @@ actor FocusAssistant: ProactiveAssistant {
private let maxPendingTasks = 3
private var currentApp: String?
- // MARK: - Context Cache
- // Cached context from local DB (goals, tasks, memories) to enrich focus analysis
- private var cachedContextString: String?
- private var contextCacheTime: Date?
- private let contextCacheDuration: TimeInterval = 120 // 2 minutes
-
// MARK: - Smart Analysis Filtering
// Skip analysis when user is focused on the same context (app + window title)
// Also skip during cooldown period after distraction (unless context changes)
@@ -58,25 +52,16 @@ actor FocusAssistant: ProactiveAssistant {
private var consecutiveErrorCount = 0
private var errorBackoffEndTime: Date?
- /// Get the current system prompt from settings (accessed on MainActor for thread safety)
- private var systemPrompt: String {
- get async {
- await MainActor.run {
- FocusAssistantSettings.shared.analysisPrompt
- }
- }
- }
-
// MARK: - Initialization
init(
- apiKey: String? = nil,
+ backendService: BackendProactiveService,
onAlert: @escaping (String) -> Void = { _ in },
onStatusChange: ((FocusStatus) -> Void)? = nil,
onRefocus: (() -> Void)? = nil,
onDistraction: (() -> Void)? = nil
- ) throws {
- self.geminiClient = try GeminiClient(apiKey: apiKey)
+ ) {
+ self.backendService = backendService
self.onAlert = onAlert
self.onStatusChange = onStatusChange
self.onRefocus = onRefocus
@@ -299,8 +284,6 @@ actor FocusAssistant: ProactiveAssistant {
analysisCooldownEndTime = nil
consecutiveErrorCount = 0
errorBackoffEndTime = nil
- cachedContextString = nil
- contextCacheTime = nil
// Clear cooldown in UI
await MainActor.run {
@@ -353,99 +336,26 @@ actor FocusAssistant: ProactiveAssistant {
/// Run analysis on a screenshot with no side effects (no saving, no state updates, no notifications).
/// Used by the test runner GUI and CLI.
func testAnalyze(jpegData: Data, appName: String) async throws -> ScreenAnalysis? {
- return try await analyzeScreenshot(jpegData: jpegData)
+ return try await analyzeScreenshot(jpegData: jpegData, appName: appName, windowTitle: nil)
}
/// Reset test history — call before starting a test run to get a clean slate.
func resetTestHistory() {
- testAnalysisHistory.removeAll()
+ // History is now tracked server-side; no-op on client
}
/// Run analysis with accumulating history across calls (simulates production behavior).
- /// Each result is appended to a separate test history buffer so the model sees prior decisions.
+ /// History is tracked server-side per WebSocket session, so this is equivalent to testAnalyze.
func testAnalyzeWithHistory(jpegData: Data, appName: String) async throws -> ScreenAnalysis? {
- let result = try await analyzeScreenshotWithHistory(jpegData: jpegData, history: testAnalysisHistory)
- if let result = result {
- testAnalysisHistory.append(result)
- if testAnalysisHistory.count > maxHistorySize {
- testAnalysisHistory.removeFirst()
- }
- }
- return result
- }
-
- /// Separate history buffer for test runs (doesn't pollute production history)
- private var testAnalysisHistory: [ScreenAnalysis] = []
-
- /// Variant of analyzeScreenshot that accepts an explicit history array
- private func analyzeScreenshotWithHistory(jpegData: Data, history: [ScreenAnalysis]) async throws -> ScreenAnalysis? {
- let context = await refreshContext()
-
- // Format provided history
- var historyText = ""
- if !history.isEmpty {
- var lines = ["Recent activity (oldest to newest):"]
- for (i, past) in history.enumerated() {
- lines.append("\(i + 1). [\(past.status.rawValue)] \(past.appOrSite): \(past.description)")
- if let message = past.message {
- lines.append(" Message: \(message)")
- }
- }
- historyText = lines.joined(separator: "\n")
- }
-
- var promptParts: [String] = []
- if !context.isEmpty {
- promptParts.append(context)
- }
- if !historyText.isEmpty {
- promptParts.append(historyText)
- }
- promptParts.append("Now analyze this new screenshot:")
- let prompt = promptParts.joined(separator: "\n\n")
-
- let currentSystemPrompt = await systemPrompt
-
- let responseSchema = GeminiRequest.GenerationConfig.ResponseSchema(
- type: "object",
- properties: [
- "status": .init(type: "string", enum: ["focused", "distracted"], description: "Whether the user is focused or distracted"),
- "app_or_site": .init(type: "string", enum: nil, description: "The app or website visible"),
- "description": .init(type: "string", enum: nil, description: "Brief description of what's on screen"),
- "message": .init(type: "string", enum: nil, description: "Coaching message")
- ],
- required: ["status", "app_or_site", "description"]
- )
-
- let responseText = try await geminiClient.sendRequest(
- prompt: prompt,
- imageData: jpegData,
- systemPrompt: currentSystemPrompt,
- responseSchema: responseSchema
- )
-
- return try JSONDecoder().decode(ScreenAnalysis.self, from: Data(responseText.utf8))
+ return try await analyzeScreenshot(jpegData: jpegData, appName: appName, windowTitle: nil)
}
// MARK: - Analysis
- private func formatHistory() -> String {
- guard !analysisHistory.isEmpty else { return "" }
-
- var lines = ["Recent activity (oldest to newest):"]
- for (i, past) in analysisHistory.enumerated() {
- lines.append("\(i + 1). [\(past.status.rawValue)] \(past.appOrSite): \(past.description)")
- if let message = past.message {
- lines.append(" Message: \(message)")
- }
- }
- return lines.joined(separator: "\n")
- }
-
private func processFrame(_ frame: CapturedFrame) async {
guard await isEnabled else { return }
do {
- guard let analysis = try await analyzeScreenshot(jpegData: frame.jpegData) else {
+ guard let analysis = try await analyzeScreenshot(jpegData: frame.jpegData, appName: frame.appName, windowTitle: frame.windowTitle) else {
return
}
@@ -585,118 +495,14 @@ actor FocusAssistant: ProactiveAssistant {
}
}
- /// Refresh context from local DB (goals, tasks, memories) with caching
- private func refreshContext() async -> String {
- // Return cached context if fresh
- if let cached = cachedContextString,
- let cacheTime = contextCacheTime,
- Date().timeIntervalSince(cacheTime) < contextCacheDuration {
- return cached
- }
-
- var sections: [String] = []
-
- // AI User Profile
- do {
- if let profile = await AIUserProfileService.shared.getLatestProfile() {
- sections.append("USER PROFILE (who this user is):\n\(profile.profileText)")
- }
- }
-
- // Time context
- let formatter = DateFormatter()
- formatter.dateFormat = "EEEE, MMMM d, yyyy 'at' h:mm a"
- sections.append("TIME CONTEXT:\n\(formatter.string(from: Date()))")
-
- // Active goals
- do {
- let goals = try await GoalStorage.shared.getLocalGoals(activeOnly: true)
- if !goals.isEmpty {
- var lines = ["ACTIVE GOALS:"]
- for (i, goal) in goals.prefix(10).enumerated() {
- let desc = goal.description.map { " - \($0)" } ?? ""
- lines.append("\(i + 1). \(goal.title)\(desc)")
- }
- sections.append(lines.joined(separator: "\n"))
- }
- } catch {
- logError("Focus: Failed to load goals for context", error: error)
- }
-
- // Top tasks by importance
- do {
- let tasks = try await ActionItemStorage.shared.getTopRelevanceTasks(limit: 50)
- if !tasks.isEmpty {
- var lines = ["CURRENT TASKS (by importance):"]
- for (i, task) in tasks.enumerated() {
- let priority = task.priority ?? "medium"
- lines.append("\(i + 1). [\(priority)] \(task.description)")
- }
- sections.append(lines.joined(separator: "\n"))
- }
- } catch {
- logError("Focus: Failed to load tasks for context", error: error)
- }
-
- // Recent memories
- do {
- let memories = try await MemoryStorage.shared.getLocalMemories(limit: 50, category: "core")
- if !memories.isEmpty {
- var lines = ["RECENT MEMORIES:"]
- for (i, memory) in memories.enumerated() {
- lines.append("\(i + 1). \(memory.content)")
- }
- sections.append(lines.joined(separator: "\n"))
- }
- } catch {
- logError("Focus: Failed to load memories for context", error: error)
- }
-
- let contextString = sections.joined(separator: "\n\n")
- cachedContextString = contextString
- contextCacheTime = Date()
- return contextString
- }
-
- private func analyzeScreenshot(jpegData: Data) async throws -> ScreenAnalysis? {
- // Refresh context from local DB
- let context = await refreshContext()
-
- // Build prompt with context + history
- let historyText = formatHistory()
- var promptParts: [String] = []
- if !context.isEmpty {
- promptParts.append(context)
- }
- if !historyText.isEmpty {
- promptParts.append(historyText)
- }
- promptParts.append("Now analyze this new screenshot:")
- let prompt = promptParts.joined(separator: "\n\n")
-
- // Get current system prompt from settings
- let currentSystemPrompt = await systemPrompt
-
- // Build response schema
- let responseSchema = GeminiRequest.GenerationConfig.ResponseSchema(
- type: "object",
- properties: [
- "status": .init(type: "string", enum: ["focused", "distracted"], description: "Whether the user is focused or distracted"),
- "app_or_site": .init(type: "string", enum: nil, description: "The app or website visible"),
- "description": .init(type: "string", enum: nil, description: "Brief description of what's on screen"),
- "message": .init(type: "string", enum: nil, description: "Coaching message")
- ],
- required: ["status", "app_or_site", "description"]
+ private func analyzeScreenshot(jpegData: Data, appName: String, windowTitle: String?) async throws -> ScreenAnalysis? {
+ let base64 = jpegData.base64EncodedString()
+ let result = try await backendService.analyzeFocus(
+ imageBase64: base64,
+ appName: appName,
+ windowTitle: windowTitle ?? ""
)
-
- let responseText = try await geminiClient.sendRequest(
- prompt: prompt,
- imageData: jpegData,
- systemPrompt: currentSystemPrompt,
- responseSchema: responseSchema
- )
-
- return try JSONDecoder().decode(ScreenAnalysis.self, from: Data(responseText.utf8))
+ return result
}
// MARK: - Storage
diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift
index bcc6a6f1e6..2e0c671820 100644
--- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift
+++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift
@@ -17,7 +17,7 @@ actor MemoryAssistant: ProactiveAssistant {
// MARK: - Properties
- private let geminiClient: GeminiClient
+ private let backendService: BackendProactiveService
private var isRunning = false
private var lastAnalysisTime: Date = .distantPast
private var previousMemories: [ExtractedMemory] = [] // Last 20 extracted memories for deduplication
@@ -28,15 +28,6 @@ actor MemoryAssistant: ProactiveAssistant {
private let frameSignal: AsyncStream
private let frameSignalContinuation: AsyncStream.Continuation
- /// Get the current system prompt from settings (accessed on MainActor for thread safety)
- private var systemPrompt: String {
- get async {
- await MainActor.run {
- MemoryAssistantSettings.shared.analysisPrompt
- }
- }
- }
-
/// Get the extraction interval from settings
private var extractionInterval: TimeInterval {
get async {
@@ -57,9 +48,8 @@ actor MemoryAssistant: ProactiveAssistant {
// MARK: - Initialization
- init(apiKey: String? = nil) throws {
- // Use Gemini 3 Pro for better memory extraction quality
- self.geminiClient = try GeminiClient(apiKey: apiKey, model: "gemini-pro-latest")
+ init(backendService: BackendProactiveService) {
+ self.backendService = backendService
let (stream, continuation) = AsyncStream.makeStream(of: Void.self, bufferingPolicy: .bufferingNewest(1))
self.frameSignal = stream
@@ -340,61 +330,40 @@ actor MemoryAssistant: ProactiveAssistant {
}
private func extractMemories(from jpegData: Data, appName: String) async throws -> MemoryExtractionResult? {
- // Build context with previous memories for deduplication
- var prompt = "Analyze this screenshot from \(appName).\n\n"
+ let base64 = autoreleasepool { jpegData.base64EncodedString() }
+ let backendResult = try await backendService.extractMemories(
+ imageBase64: base64,
+ appName: appName,
+ windowTitle: ""
+ )
- if !previousMemories.isEmpty {
- prompt += "RECENTLY EXTRACTED MEMORIES (do not re-extract these or semantically similar ones):\n"
- for (index, memory) in previousMemories.enumerated() {
- prompt += "\(index + 1). [\(memory.category.rawValue)] \(memory.content)\n"
+ // Parse backend response into MemoryExtractionResult
+ let memories: [ExtractedMemory] = backendResult.memories.compactMap { dict in
+ guard let content = dict["content"] as? String, !content.isEmpty else { return nil }
+ let categoryStr = dict["category"] as? String ?? "system"
+ let category: ExtractedMemoryCategory = categoryStr == "interesting" ? .interesting : .system
+ let sourceApp = dict["source_app"] as? String ?? appName
+ let confidence: Double
+ if let confValue = dict["confidence"] as? Double {
+ confidence = confValue
+ } else if let confInt = dict["confidence"] as? Int {
+ confidence = Double(confInt)
+ } else {
+ confidence = 0.5
}
- prompt += "\nLook for NEW memories that are NOT already in the list above."
- } else {
- prompt += "Look for memories to extract (system facts about the user, or interesting wisdom from others)."
+ return ExtractedMemory(
+ content: content,
+ category: category,
+ sourceApp: sourceApp,
+ confidence: confidence
+ )
}
- // Get current system prompt from settings
- let currentSystemPrompt = await systemPrompt
-
- // Build response schema for memory extraction
- let memoryProperties: [String: GeminiRequest.GenerationConfig.ResponseSchema.Property] = [
- "content": .init(type: "string", description: "The memory content (max 15 words)"),
- "category": .init(type: "string", enum: ["system", "interesting"], description: "Memory category"),
- "source_app": .init(type: "string", description: "App where memory was found"),
- "confidence": .init(type: "number", description: "Confidence score 0.0-1.0")
- ]
-
- let responseSchema = GeminiRequest.GenerationConfig.ResponseSchema(
- type: "object",
- properties: [
- "has_new_memory": .init(type: "boolean", description: "True if new memories were found"),
- "memories": .init(
- type: "array",
- description: "Array of extracted memories (0-3 max)",
- items: .init(
- type: "object",
- properties: memoryProperties,
- required: ["content", "category", "source_app", "confidence"]
- )
- ),
- "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"),
- "current_activity": .init(type: "string", description: "High-level description of user's activity")
- ],
- required: ["has_new_memory", "memories", "context_summary", "current_activity"]
+ return MemoryExtractionResult(
+ hasNewMemory: !memories.isEmpty,
+ memories: memories,
+ contextSummary: "Analyzed \(appName)",
+ currentActivity: ""
)
-
- do {
- let responseText = try await geminiClient.sendRequest(
- prompt: prompt,
- imageData: jpegData,
- systemPrompt: currentSystemPrompt,
- responseSchema: responseSchema
- )
-
- return try JSONDecoder().decode(MemoryExtractionResult.self, from: Data(responseText.utf8))
- } catch {
- logError("Memory analysis error", error: error)
- return nil
- }
}
}
diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift
index 8df5b2fcc7..512c464777 100644
--- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift
+++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift
@@ -1,7 +1,7 @@
import Foundation
-/// Task extraction assistant that identifies tasks and action items from screen content
-/// Uses single-stage Gemini tool calling with vector + FTS5 search for deduplication
+/// Task extraction assistant that identifies tasks and action items from screen content.
+/// Phase 2: sends screenshots to backend via WebSocket, receives structured task results.
actor TaskAssistant: ProactiveAssistant {
// MARK: - ProactiveAssistant Protocol
@@ -18,7 +18,7 @@ actor TaskAssistant: ProactiveAssistant {
// MARK: - Properties
- private let geminiClient: GeminiClient
+ private let backendService: BackendProactiveService
private var isRunning = false
private var previousTasks: [ExtractedTask] = [] // Last 10 extracted tasks for context
private let maxPreviousTasks = 10
@@ -41,11 +41,6 @@ actor TaskAssistant: ProactiveAssistant {
/// Timestamp of last context switch yield, for throttling rapid switches
private var lastContextSwitchYieldTime: Date = .distantPast
- // Cached goals (refreshed every 5 minutes)
- private var cachedGoals: [Goal] = []
- private var lastGoalsRefresh: Date = .distantPast
- private let goalsRefreshInterval: TimeInterval = 300
-
// MARK: - Due Date Helpers
/// Parse an inferred deadline string into a Date, or default to end of today.
@@ -114,15 +109,6 @@ actor TaskAssistant: ProactiveAssistant {
return calendar.date(bySettingHour: 23, minute: 59, second: 0, of: startOfDay) ?? startOfDay
}
- /// Get the current system prompt from settings (accessed on MainActor for thread safety)
- private var systemPrompt: String {
- get async {
- await MainActor.run {
- TaskAssistantSettings.shared.analysisPrompt
- }
- }
- }
-
/// Get the extraction interval from settings
private var extractionInterval: TimeInterval {
get async {
@@ -143,9 +129,8 @@ actor TaskAssistant: ProactiveAssistant {
// MARK: - Initialization
- init(apiKey: String? = nil) throws {
- // Use Gemini 3 Pro for better task extraction quality
- self.geminiClient = try GeminiClient(apiKey: apiKey, model: "gemini-pro-latest")
+ init(backendService: BackendProactiveService) {
+ self.backendService = backendService
let (stream, continuation) = AsyncStream.makeStream(of: TriggerEvent.self, bufferingPolicy: .bufferingNewest(1))
self.triggerStream = stream
@@ -221,11 +206,17 @@ actor TaskAssistant: ProactiveAssistant {
// MARK: - Test Analysis (for test runner)
- /// Run the extraction pipeline on arbitrary JPEG data without side effects (no saving, no events).
- /// Used by the test runner to replay past screenshots.
- /// Returns (result, searchCount) where searchCount is the number of search tool calls made.
+ /// Run extraction via backend for test runner. Returns (result, 0) for compatibility.
func testAnalyze(jpegData: Data, appName: String) async throws -> (TaskExtractionResult?, Int) {
- return try await extractTaskSingleStage(from: jpegData, appName: appName)
+ let base64 = autoreleasepool { jpegData.base64EncodedString() }
+ let backendResult = try await backendService.extractTasks(
+ imageBase64: base64, appName: appName, windowTitle: ""
+ )
+ if backendResult.tasks.isEmpty {
+ return (TaskExtractionResult(hasNewTask: false, task: nil, contextSummary: "Analyzed \(appName)", currentActivity: ""), 0)
+ }
+ let result = parseBackendTask(backendResult.tasks[0], appName: appName)
+ return (result, 0)
}
// MARK: - ProactiveAssistant Protocol Methods
@@ -579,7 +570,7 @@ actor TaskAssistant: ProactiveAssistant {
latestFrame = nil
}
- // MARK: - Single-Stage Analysis with Tool Calling
+ // MARK: - Backend Analysis (Phase 2 thin client)
private func processFrame(_ frame: CapturedFrame) async {
let enabled = await isEnabled
@@ -590,613 +581,88 @@ actor TaskAssistant: ProactiveAssistant {
log("Task: Analyzing frame from \(frame.appName)...")
do {
- let (result, searchCount) = try await extractTaskSingleStage(from: frame.jpegData, appName: frame.appName)
- guard let result = result else {
- log("Task: Analysis returned no result")
- return
- }
-
- log("Task: Analysis complete - hasNewTask: \(result.hasNewTask), context: \(result.contextSummary), searches: \(searchCount)")
+ let base64 = autoreleasepool { frame.jpegData.base64EncodedString() }
+ let backendResult = try await backendService.extractTasks(
+ imageBase64: base64,
+ appName: frame.appName,
+ windowTitle: frame.windowTitle ?? ""
+ )
- await handleResultWithScreenshot(result, screenshotId: frame.screenshotId, appName: frame.appName, windowTitle: frame.windowTitle) { type, data in
+ let sendEvent: (String, [String: Any]) -> Void = { type, data in
Task { @MainActor in
AssistantCoordinator.shared.sendEvent(type: type, data: data)
}
}
- } catch {
- logError("Task extraction error", error: error)
- }
- }
-
- /// Loop-based extraction: image analysis + iterative tool calling for search + terminal tool for decision
- /// Returns (result, searchCount) where searchCount is the number of search tool calls made.
- private func extractTaskSingleStage(from jpegData: Data, appName: String) async throws -> (TaskExtractionResult?, Int) {
- // 1. Gather context
- let context = await refreshContext()
-
- // 2. Build prompt with injected context
- let dateFormatter = DateFormatter()
- dateFormatter.dateFormat = "yyyy-MM-dd (EEEE)"
- let todayStr = dateFormatter.string(from: Date())
-
- var prompt = "Screenshot from \(appName). Today is \(todayStr). Analyze this screenshot for any unaddressed request directed at the user.\n\n"
-
- // For messaging apps, add an extra reminder about conversation analysis
- let messagingApps: Set = ["Telegram", "WhatsApp", "\u{200E}WhatsApp", "Messages", "Slack", "Discord"]
- if messagingApps.contains(appName) {
- prompt += """
- REMINDER — THIS IS A MESSAGING APP:
- - If this screenshot shows a chat sidebar/conversation list rather than an open conversation, SKIP entirely.
- - If it shows an open conversation, read the FULL conversation flow between the user and the other person.
- - LEFT-SIDE messages = from the other person. RIGHT-SIDE/colored = from the user.
- - PRIORITY: Look for where the user AGREED or COMMITTED to doing something the other person asked.
- Example: Other person says "Can you send me the report?" → User replies "Sure, will do" → Extract task: "Send [person] the report"
- - ALSO: Look for incoming requests the user hasn't responded to yet.
- - The task title should describe what was asked for, naming the other person in the conversation.
-
- """
- }
-
- // Inject AI user profile for context
- if let profile = await AIUserProfileService.shared.getLatestProfile() {
- prompt += "USER PROFILE (who this user is — use for context, not as a task source):\n"
- prompt += profile.profileText + "\n\n"
- }
-
- if !context.activeTasks.isEmpty {
- // Get score range for context
- let scoreRange = try? await ActionItemStorage.shared.getRelevanceScoreRange()
- let rangeStr = scoreRange.map { "Score range: \($0.min)–\($0.max). " } ?? ""
-
- prompt += "ACTIVE TASKS (user is already tracking these — each has a relevance_score where 1 = most important, higher numbers = less important):\n"
- prompt += "\(rangeStr)Use these scores to place any new task appropriately.\n"
- for (i, task) in context.activeTasks.enumerated() {
- let pri = task.priority.map { " [\($0)]" } ?? ""
- let score = task.relevanceScore.map { " [score:\($0)]" } ?? ""
- prompt += "\(i + 1).\(score) \(task.description)\(pri)\n"
- }
- prompt += "\n"
- }
-
- if !context.completedTasks.isEmpty {
- prompt += "RECENTLY COMPLETED TASKS (user engaged with these — this is the kind of task the user finds valuable. Extract similar types of tasks, just not exact duplicates of these specific ones):\n"
- for (i, task) in context.completedTasks.enumerated() {
- prompt += "\(i + 1). \(task.description)\n"
- }
- prompt += "\n"
- }
-
- if !context.deletedTasks.isEmpty {
- prompt += "USER-DELETED TASKS (user explicitly rejected these — do not re-extract similar):\n"
- for (i, task) in context.deletedTasks.enumerated() {
- prompt += "\(i + 1). \(task.description)\n"
- }
- prompt += "\n"
- }
-
- if !context.goals.isEmpty {
- prompt += "ACTIVE GOALS:\n"
- for (i, goal) in context.goals.enumerated() {
- prompt += "\(i + 1). \(goal.title)"
- if let desc = goal.description {
- prompt += " — \(desc)"
- }
- prompt += "\n"
- }
- prompt += "\n"
- }
- prompt += """
- Analyze this screenshot. If you see a potential request, search for duplicates first.
- If there is clearly no request on screen (~90% of screenshots), call no_task_found immediately.
- """
-
- // 3. Define 5 tools
- let tools = GeminiTool(functionDeclarations: [
- GeminiTool.FunctionDeclaration(
- name: "search_similar",
- description: "Search for semantically similar existing tasks using vector similarity. Call this when you see a potential request and want to check for duplicates.",
- parameters: GeminiTool.FunctionDeclaration.Parameters(
- type: "object",
- properties: [
- "query": .init(type: "string", description: "A concise description of the potential task to search for")
- ],
- required: ["query"]
+ if backendResult.tasks.isEmpty {
+ let result = TaskExtractionResult(
+ hasNewTask: false, task: nil,
+ contextSummary: "Analyzed \(frame.appName)",
+ currentActivity: ""
)
- ),
- GeminiTool.FunctionDeclaration(
- name: "search_keywords",
- description: "Search for existing tasks matching specific keywords. Use this for precise keyword-based matching complementing vector search.",
- parameters: GeminiTool.FunctionDeclaration.Parameters(
- type: "object",
- properties: [
- "query": .init(type: "string", description: "Keywords to search for in existing tasks")
- ],
- required: ["query"]
- )
- ),
- GeminiTool.FunctionDeclaration(
- name: "no_task_found",
- description: "Call this when there is no actionable request on screen. This is the most common outcome (~90% of screenshots). Use for: code editors, terminals, settings, media players, dashboards, or any screen without a direct request from another person or AI.",
- parameters: GeminiTool.FunctionDeclaration.Parameters(
- type: "object",
- properties: [
- "context_summary": .init(type: "string", description: "Brief summary of what the user is looking at"),
- "current_activity": .init(type: "string", description: "What the user is actively doing")
- ],
- required: ["context_summary", "current_activity"]
- )
- ),
- GeminiTool.FunctionDeclaration(
- name: "extract_task",
- description: "Extract a new task that is not already tracked. Call ONLY after searching for duplicates. All fields are required.",
- parameters: GeminiTool.FunctionDeclaration.Parameters(
- type: "object",
- properties: [
- "title": .init(type: "string", description: "Verb-first task title, 6–15 words. MUST name a specific person/project/artifact and a concrete action. If you can't write 6+ specific words, call no_task_found instead."),
- "description": .init(type: "string", description: "Additional context about the task. Empty string if none."),
- "priority": .init(type: "string", description: "Task priority", enumValues: ["high", "medium", "low"]),
- "tags": .init(type: "array", description: "1-3 relevant tags", items: .init(type: "string")),
- "source_app": .init(type: "string", description: "App where the task was found"),
- "inferred_deadline": .init(type: "string", description: "Deadline in yyyy-MM-dd format (e.g. '2025-10-04'). Resolve relative references like 'Thursday' or 'next week' to an actual date. Empty string if no deadline."),
- "confidence": .init(type: "number", description: "Confidence score 0.0-1.0"),
- "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"),
- "current_activity": .init(type: "string", description: "What the user is actively doing"),
- "source_category": .init(type: "string", description: "Where the task originated", enumValues: ["direct_request", "self_generated", "calendar_driven", "reactive", "external_system", "other"]),
- "source_subcategory": .init(type: "string", description: "Specific origin within category", enumValues: ["message", "meeting", "mention", "commitment", "idea", "reminder", "goal_subtask", "event_prep", "recurring", "deadline", "error", "notification", "observation", "project_tool", "alert", "documentation", "other"]),
- "relevance_score": .init(type: "integer", description: "Where this task ranks relative to existing tasks. Look at the relevance_score values of existing active tasks and assign a score that places this task appropriately. 1 = most important/urgent, higher numbers = less important. Must be a positive integer.")
- ],
- required: ["title", "description", "priority", "tags", "source_app", "inferred_deadline", "confidence", "context_summary", "current_activity", "source_category", "source_subcategory", "relevance_score"]
- )
- ),
- GeminiTool.FunctionDeclaration(
- name: "reject_task",
- description: "Reject task extraction — the potential task is a duplicate, already completed, or was previously rejected by the user. Call after searching confirms this.",
- parameters: GeminiTool.FunctionDeclaration.Parameters(
- type: "object",
- properties: [
- "reason": .init(type: "string", description: "Why this task was rejected (e.g. 'duplicate of existing active task', 'already completed')"),
- "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"),
- "current_activity": .init(type: "string", description: "What the user is actively doing")
- ],
- required: ["reason", "context_summary", "current_activity"]
- )
- )
- ])
-
- // 4. Get system prompt
- let currentSystemPrompt = await systemPrompt
-
- // 5. Build initial contents
- // Wrap base64 encoding in autoreleasepool — Swift concurrency doesn't
- // drain autorelease pools, causing bridged NSString objects to accumulate.
- var contents: [GeminiImageToolRequest.Content] = autoreleasepool {
- let base64Data = jpegData.base64EncodedString()
- return [
- GeminiImageToolRequest.Content(
- role: "user",
- parts: [
- GeminiImageToolRequest.Part(text: prompt),
- GeminiImageToolRequest.Part(mimeType: "image/jpeg", data: base64Data)
- ]
- )
- ]
- }
-
- // 6. Tool-calling loop (max 5 iterations)
- var searchCount = 0
-
- for iteration in 0..<5 {
- let result = try await geminiClient.sendImageToolLoop(
- contents: contents,
- systemPrompt: currentSystemPrompt,
- tools: [tools],
- forceToolCall: iteration == 0
- )
-
- guard let toolCall = result.toolCalls.first else {
- log("Task: No tool call received on iteration \(iteration), breaking")
- break
- }
-
- switch toolCall.name {
- case "no_task_found":
- let contextSummary = toolCall.arguments["context_summary"] as? String ?? "No task on screen"
- let currentActivity = toolCall.arguments["current_activity"] as? String ?? "Unknown"
- log("Task: no_task_found — \(contextSummary)")
- return (TaskExtractionResult(
- hasNewTask: false,
- task: nil,
- contextSummary: contextSummary,
- currentActivity: currentActivity
- ), searchCount)
-
- case "extract_task":
- let title = toolCall.arguments["title"] as? String ?? ""
- let contextSummary = toolCall.arguments["context_summary"] as? String ?? ""
- let currentActivity = toolCall.arguments["current_activity"] as? String ?? ""
-
- // --- Hard validation: reject vague titles and ask the model to retry ---
- let titleWords = title.split(separator: " ").count
- let validationError = Self.validateTaskTitle(title, wordCount: titleWords)
- if let error = validationError {
- log("Task: Title rejected (\(error)): \"\(title)\"")
-
- // Feed rejection back into the loop so the model can retry with more specifics
- contents.append(GeminiImageToolRequest.Content(
- role: "model",
- parts: [GeminiImageToolRequest.Part(
- functionCall: .init(name: toolCall.name, args: toolCall.arguments as? [String: String] ?? ["title": title]),
- thoughtSignature: toolCall.thoughtSignature
- )]
- ))
- contents.append(GeminiImageToolRequest.Content(
- role: "user",
- parts: [GeminiImageToolRequest.Part(functionResponse: .init(
- name: toolCall.name,
- response: .init(result: """
- REJECTED: \(error). \
- Your title was: "\(title)" (\(titleWords) words). \
- Either rewrite with 6+ words including a specific person/project name and concrete action, \
- or call no_task_found if you cannot be more specific.
- """)
- ))]
- ))
- continue
- }
-
- let description = toolCall.arguments["description"] as? String
- let priorityStr = toolCall.arguments["priority"] as? String ?? "medium"
- let priority = TaskPriority(rawValue: priorityStr) ?? .medium
- let tags: [String]
- if let tagArray = toolCall.arguments["tags"] as? [Any] {
- tags = tagArray.compactMap { $0 as? String }
- } else {
- tags = []
- }
- let sourceApp = toolCall.arguments["source_app"] as? String ?? appName
- let inferredDeadline = toolCall.arguments["inferred_deadline"] as? String
- let confidence: Double
- if let confValue = toolCall.arguments["confidence"] as? Double {
- confidence = confValue
- } else if let confInt = toolCall.arguments["confidence"] as? Int {
- confidence = Double(confInt)
- } else {
- confidence = 0.5
- }
- let sourceCategory = toolCall.arguments["source_category"] as? String ?? "other"
- let sourceSubcategory = toolCall.arguments["source_subcategory"] as? String ?? "other"
- let relevanceScore: Int?
- if let scoreValue = toolCall.arguments["relevance_score"] as? Int {
- relevanceScore = scoreValue
- } else if let scoreDouble = toolCall.arguments["relevance_score"] as? Double {
- relevanceScore = Int(scoreDouble)
- } else {
- relevanceScore = nil
- }
-
- let task = ExtractedTask(
- title: title,
- description: description?.isEmpty == true ? nil : description,
- priority: priority,
- sourceApp: sourceApp,
- inferredDeadline: inferredDeadline?.isEmpty == true ? nil : inferredDeadline,
- confidence: confidence,
- tags: tags,
- sourceCategory: sourceCategory,
- sourceSubcategory: sourceSubcategory,
- relevanceScore: relevanceScore
- )
-
- log("Task: extract_task — \"\(title)\" (confidence: \(confidence), priority: \(priorityStr), score: \(relevanceScore.map { String($0) } ?? "nil"))")
- return (TaskExtractionResult(
- hasNewTask: true,
- task: task,
- contextSummary: contextSummary,
- currentActivity: currentActivity
- ), searchCount)
-
- case "reject_task":
- let reason = toolCall.arguments["reason"] as? String ?? "Unknown reason"
- let contextSummary = toolCall.arguments["context_summary"] as? String ?? ""
- let currentActivity = toolCall.arguments["current_activity"] as? String ?? ""
- log("Task: reject_task — \(reason)")
- return (TaskExtractionResult(
- hasNewTask: false,
- task: nil,
- contextSummary: contextSummary,
- currentActivity: currentActivity
- ), searchCount)
-
- case "search_similar":
- let query = toolCall.arguments["query"] as? String ?? ""
- searchCount += 1
- log("Task: search_similar query: \"\(query)\"")
- let searchResults = await executeVectorSearch(query: query)
- log("Task: Vector search returned \(searchResults.count) results")
-
- let searchResultsJson: String
- if let data = try? JSONEncoder().encode(searchResults),
- let json = String(data: data, encoding: .utf8) {
- searchResultsJson = json
- } else {
- searchResultsJson = "[]"
- }
-
- // Append model's tool call + function response to contents
- contents.append(GeminiImageToolRequest.Content(
- role: "model",
- parts: [GeminiImageToolRequest.Part(
- functionCall: .init(name: toolCall.name, args: ["query": query]),
- thoughtSignature: toolCall.thoughtSignature
- )]
- ))
- contents.append(GeminiImageToolRequest.Content(
- role: "user",
- parts: [GeminiImageToolRequest.Part(functionResponse: .init(
- name: toolCall.name,
- response: .init(result: searchResultsJson)
- ))]
- ))
- continue
-
- case "search_keywords":
- let query = toolCall.arguments["query"] as? String ?? ""
- searchCount += 1
- log("Task: search_keywords query: \"\(query)\"")
- let searchResults = await executeKeywordSearch(query: query)
- log("Task: Keyword search returned \(searchResults.count) results")
-
- let searchResultsJson: String
- if let data = try? JSONEncoder().encode(searchResults),
- let json = String(data: data, encoding: .utf8) {
- searchResultsJson = json
- } else {
- searchResultsJson = "[]"
- }
-
- // Append model's tool call + function response to contents
- contents.append(GeminiImageToolRequest.Content(
- role: "model",
- parts: [GeminiImageToolRequest.Part(
- functionCall: .init(name: toolCall.name, args: ["query": query]),
- thoughtSignature: toolCall.thoughtSignature
- )]
- ))
- contents.append(GeminiImageToolRequest.Content(
- role: "user",
- parts: [GeminiImageToolRequest.Part(functionResponse: .init(
- name: toolCall.name,
- response: .init(result: searchResultsJson)
- ))]
- ))
- continue
-
- default:
- log("Task: Unknown tool call: \(toolCall.name), breaking")
- break
- }
- }
-
- log("Task: Completed in \(searchCount) searches (loop exhausted without terminal tool)")
- return (nil, searchCount)
- }
-
- // MARK: - Title Validation
-
- /// Validates a task title for minimum specificity. Returns an error message if invalid, nil if OK.
- private static func validateTaskTitle(_ title: String, wordCount: Int) -> String? {
- let trimmed = title.trimmingCharacters(in: .whitespacesAndNewlines)
-
- // Must not be empty
- if trimmed.isEmpty {
- return "Title is empty"
- }
-
- // Minimum 6 words
- if wordCount < 6 {
- return "Title too short (\(wordCount) words, minimum 6)"
- }
-
- // Reject titles that are purely generic verbs with no specifics
- let genericPatterns: [String] = [
- "investigate", "check logs", "clean up", "look into",
- "look through", "update to", "fix the", "review the",
- "check the", "modify the", "track the"
- ]
- let lowered = trimmed.lowercased()
- for pattern in genericPatterns {
- // If the entire title is just a generic pattern (possibly with 1-2 filler words), reject
- if lowered == pattern || (wordCount <= 4 && lowered.hasPrefix(pattern)) {
- return "Title too generic (matches vague pattern '\(pattern)')"
+ log("Task: Analysis returned no tasks")
+ await handleResultWithScreenshot(result, screenshotId: frame.screenshotId, appName: frame.appName, windowTitle: frame.windowTitle, sendEvent: sendEvent)
+ return
}
- }
- // Must contain at least one capitalized proper noun (person, project, app name)
- // Heuristic: after the first word (verb), there should be at least one word starting with uppercase
- let words = trimmed.split(separator: " ")
- let hasProperNoun = words.dropFirst().contains { word in
- guard let first = word.first else { return false }
- return first.isUppercase
- }
- if !hasProperNoun {
- return "Title lacks a specific name (person, project, or app) — no proper nouns found after the verb"
- }
-
- return nil
- }
-
- // MARK: - Context & Search
+ log("Task: Analysis complete - \(backendResult.tasks.count) task(s)")
- /// Refresh context from local SQLite + cached goals
- private func refreshContext() async -> TaskExtractionContext {
- var topRelevanceTasks: [(id: Int64, description: String, priority: String?, relevanceScore: Int?)] = []
- var recentTasks: [(id: Int64, description: String, priority: String?, relevanceScore: Int?)] = []
- var completedTasks: [(id: Int64, description: String)] = []
- var deletedTasks: [(id: Int64, description: String)] = []
-
- // Query both action_items (promoted + manual) and staged_tasks for full context
- do {
- topRelevanceTasks = try await ActionItemStorage.shared.getTopRelevanceTasks(limit: 30)
- } catch {
- logError("Task: Failed to load top relevance tasks", error: error)
- }
-
- do {
- recentTasks = try await ActionItemStorage.shared.getRecentActiveTasks(limit: 30)
- } catch {
- logError("Task: Failed to load recent tasks", error: error)
- }
-
- // Also include staged tasks for dedup context
- do {
- let stagedTasks = try await StagedTaskStorage.shared.getAllStagedTasks(limit: 30)
- let stagedAsTuples = stagedTasks.map { task in
- (id: Int64(0), description: task.description, priority: task.priority, relevanceScore: task.relevanceScore)
+ for taskDict in backendResult.tasks {
+ let result = parseBackendTask(taskDict, appName: frame.appName)
+ await handleResultWithScreenshot(result, screenshotId: frame.screenshotId, appName: frame.appName, windowTitle: frame.windowTitle, sendEvent: sendEvent)
}
- recentTasks.append(contentsOf: stagedAsTuples)
- } catch {
- logError("Task: Failed to load staged tasks for context", error: error)
- }
-
- // Merge: top relevance tasks first, then recent ones not already included
- let topIds = Set(topRelevanceTasks.map { $0.id })
- let activeTasks = topRelevanceTasks + recentTasks.filter { !topIds.contains($0.id) }
-
- do {
- completedTasks = try await ActionItemStorage.shared.getRecentCompletedTasks(limit: 10)
- } catch {
- logError("Task: Failed to load completed tasks", error: error)
- }
-
- do {
- deletedTasks = try await ActionItemStorage.shared.getRecentDeletedTasks(limit: 10, deletedBy: "user")
} catch {
- logError("Task: Failed to load deleted tasks", error: error)
- }
-
- // Refresh goals if stale
- let timeSinceGoals = Date().timeIntervalSince(lastGoalsRefresh)
- if timeSinceGoals >= goalsRefreshInterval {
- do {
- cachedGoals = try await APIClient.shared.getGoals()
- lastGoalsRefresh = Date()
- log("Task: Refreshed \(cachedGoals.count) goals")
- } catch {
- logError("Task: Failed to refresh goals", error: error)
- }
+ logError("Task extraction error", error: error)
}
-
- return TaskExtractionContext(
- activeTasks: activeTasks,
- completedTasks: completedTasks,
- deletedTasks: deletedTasks,
- goals: cachedGoals
- )
}
- /// Execute vector similarity search
- private func executeVectorSearch(query: String) async -> [TaskSearchResult] {
- var results: [TaskSearchResult] = []
-
- do {
- let queryEmbedding = try await EmbeddingService.shared.embed(text: query)
- let vectorResults = await EmbeddingService.shared.searchSimilar(query: queryEmbedding, topK: 10)
-
- for result in vectorResults where result.similarity > 0.3 {
- if let record = try await ActionItemStorage.shared.getActionItem(id: result.id) {
- let status: String
- if record.deleted { status = "deleted" }
- else if record.completed { status = "completed" }
- else { status = "active" }
-
- results.append(TaskSearchResult(
- id: result.id,
- description: record.description,
- status: status,
- similarity: Double(result.similarity),
- matchType: "vector",
- relevanceScore: record.relevanceScore
- ))
- } else if let staged = try await StagedTaskStorage.shared.getStagedTask(id: result.id) {
- // Fallback: ID belongs to a staged task (shared embedding index)
- let status: String
- if staged.deleted { status = "deleted" }
- else if staged.completed { status = "completed" }
- else { status = "active" }
-
- results.append(TaskSearchResult(
- id: result.id,
- description: staged.description,
- status: status,
- similarity: Double(result.similarity),
- matchType: "vector",
- relevanceScore: staged.relevanceScore
- ))
- }
- }
- } catch {
- logError("Task: Vector search failed", error: error)
- }
+ /// Parse a raw task dict from the backend into a TaskExtractionResult.
+ private func parseBackendTask(_ dict: [String: Any], appName: String) -> TaskExtractionResult {
+ let title = dict["title"] as? String ?? ""
+ let description = dict["description"] as? String
+ let priorityStr = dict["priority"] as? String ?? "medium"
+ let priority = TaskPriority(rawValue: priorityStr) ?? .medium
+ let tags = (dict["tags"] as? [String]) ?? []
+ let sourceApp = dict["source_app"] as? String ?? appName
+ let inferredDeadline = dict["inferred_deadline"] as? String
+ let confidence: Double
+ if let confValue = dict["confidence"] as? Double {
+ confidence = confValue
+ } else if let confInt = dict["confidence"] as? Int {
+ confidence = Double(confInt)
+ } else {
+ confidence = 0.5
+ }
+ let sourceCategory = dict["source_category"] as? String ?? "other"
+ let sourceSubcategory = dict["source_subcategory"] as? String ?? "other"
+ let relevanceScore: Int?
+ if let scoreValue = dict["relevance_score"] as? Int {
+ relevanceScore = scoreValue
+ } else if let scoreDouble = dict["relevance_score"] as? Double {
+ relevanceScore = Int(scoreDouble)
+ } else {
+ relevanceScore = nil
+ }
+
+ let task = ExtractedTask(
+ title: title,
+ description: description?.isEmpty == true ? nil : description,
+ priority: priority,
+ sourceApp: sourceApp,
+ inferredDeadline: inferredDeadline?.isEmpty == true ? nil : inferredDeadline,
+ confidence: confidence,
+ tags: tags,
+ sourceCategory: sourceCategory,
+ sourceSubcategory: sourceSubcategory,
+ relevanceScore: relevanceScore
+ )
- return results.sorted { ($0.similarity ?? 0) > ($1.similarity ?? 0) }
+ return TaskExtractionResult(
+ hasNewTask: true,
+ task: task,
+ contextSummary: dict["context_summary"] as? String ?? "Analyzed \(appName)",
+ currentActivity: dict["current_activity"] as? String ?? ""
+ )
}
- /// Execute FTS5 keyword search (searches both action_items and staged_tasks)
- private func executeKeywordSearch(query: String) async -> [TaskSearchResult] {
- var results: [TaskSearchResult] = []
-
- do {
- let words = query.components(separatedBy: .whitespaces)
- .map { $0.filter { $0.isLetter || $0.isNumber } } // Strip FTS5 special chars (- : * " etc.)
- .filter { $0.count >= 3 }
- let ftsQuery = words.map { "\($0)*" }.joined(separator: " OR ")
-
- if !ftsQuery.isEmpty {
- // Search action_items (promoted + manual)
- let ftsResults = try await ActionItemStorage.shared.searchFTS(
- query: ftsQuery,
- limit: 10,
- includeCompleted: true,
- includeDeleted: true
- )
-
- for result in ftsResults {
- let status: String
- if result.deleted { status = "deleted" }
- else if result.completed { status = "completed" }
- else { status = "active" }
-
- results.append(TaskSearchResult(
- id: result.id,
- description: result.description,
- status: status,
- similarity: nil,
- matchType: "fts",
- relevanceScore: result.relevanceScore
- ))
- }
-
- // Also search staged_tasks
- let stagedResults = try await StagedTaskStorage.shared.searchFTS(
- query: ftsQuery,
- limit: 10
- )
- for result in stagedResults {
- results.append(TaskSearchResult(
- id: result.id,
- description: result.description,
- status: "active",
- similarity: nil,
- matchType: "fts",
- relevanceScore: result.relevanceScore
- ))
- }
- }
- } catch {
- logError("Task: FTS search failed", error: error)
- }
-
- return results
- }
}
diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskDeduplicationService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskDeduplicationService.swift
index 4618f99673..98b38a8761 100644
--- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskDeduplicationService.swift
+++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskDeduplicationService.swift
@@ -6,7 +6,7 @@ import Foundation
actor TaskDeduplicationService {
static let shared = TaskDeduplicationService()
- private var geminiClient: GeminiClient?
+ private var backendService: BackendProactiveService?
private var timer: Task?
private var isRunning = false
private var lastRunTime: Date?
@@ -17,13 +17,11 @@ actor TaskDeduplicationService {
private let cooldownSeconds: TimeInterval = 1800 // 30-min cooldown
private let minimumTaskCount = 3
- private init() {
- do {
- self.geminiClient = try GeminiClient(model: "gemini-pro-latest")
- } catch {
- log("TaskDedup: Failed to initialize GeminiClient: \(error)")
- self.geminiClient = nil
- }
+ private init() {}
+
+ /// Set the backend service for Phase 2 server-side deduplication.
+ func configure(backendService: BackendProactiveService) {
+ self.backendService = backendService
}
// MARK: - Lifecycle
@@ -67,210 +65,42 @@ actor TaskDeduplicationService {
// MARK: - Deduplication Logic
private func runDeduplication() async {
- guard let client = geminiClient else {
- log("TaskDedup: Skipping - Gemini client not initialized")
+ guard let service = backendService else {
+ log("TaskDedup: Skipping - backend service not configured")
return
}
lastRunTime = Date()
- log("TaskDedup: Starting deduplication run on staged tasks")
-
- // 1. Fetch staged tasks (not yet promoted to action items)
- let tasks: [TaskActionItem]
- do {
- let response = try await APIClient.shared.getStagedTasks(limit: 200)
- tasks = response.items
- } catch {
- log("TaskDedup: Failed to fetch staged tasks: \(error)")
- return
- }
-
- guard tasks.count >= minimumTaskCount else {
- log("TaskDedup: Only \(tasks.count) staged tasks, skipping (minimum: \(minimumTaskCount))")
- return
- }
-
- log("TaskDedup: Analyzing \(tasks.count) staged tasks for duplicates")
-
- // 2. Send all tasks to Gemini in a single call
- let totalDeleted = await analyzeAndDeleteDuplicates(tasks: tasks, client: client)
-
- log("TaskDedup: Run complete. Hard-deleted \(totalDeleted) duplicate staged tasks.")
- }
-
- private func analyzeAndDeleteDuplicates(tasks: [TaskActionItem], client: GeminiClient) async -> Int {
- // Build task list for prompt
- let taskDescriptions = tasks.map { task -> String in
- var parts = ["ID: \(task.id)", "Description: \(task.description)"]
- if let due = task.dueAt {
- parts.append("Due: \(ISO8601DateFormatter().string(from: due))")
- }
- if let priority = task.priority {
- parts.append("Priority: \(priority)")
- }
- if let source = task.source {
- parts.append("Source: \(source)")
- }
- parts.append("Created: \(ISO8601DateFormatter().string(from: task.createdAt))")
- return parts.joined(separator: "\n")
- }.joined(separator: "\n")
-
- let prompt = """
- Analyze the following tasks for semantic duplicates. Two tasks are duplicates if they \
- refer to the same action, even if worded differently.
-
- Tasks:
- \(taskDescriptions)
-
- For each group of duplicates, pick the best task to KEEP based on these criteria (in order):
- 1. Most descriptive/specific wording
- 2. Has a due date over one that doesn't
- 3. Higher priority set (high > medium > low > none)
- 4. More reliable source (manual > transcription > screenshot)
- 5. Most recently created
-
- Only flag tasks as duplicates if you are confident they refer to the same action. \
- When in doubt, do NOT flag as duplicates.
- """
-
- let systemPrompt = """
- You are a task deduplication assistant. You identify semantically duplicate tasks \
- and choose the best one to keep. Be conservative - only flag clear duplicates. \
- Return has_duplicates: false if no duplicates are found.
- """
-
- let responseSchema = GeminiRequest.GenerationConfig.ResponseSchema(
- type: "object",
- properties: [
- "has_duplicates": .init(type: "boolean", description: "Whether any duplicate groups were found"),
- "duplicate_groups": .init(
- type: "array",
- description: "Groups of duplicate tasks",
- items: .init(
- type: "object",
- properties: [
- "keep_id": .init(type: "string", description: "ID of the task to keep"),
- "delete_ids": .init(
- type: "array",
- description: "IDs of tasks to delete",
- items: .init(type: "string", properties: nil, required: nil)
- ),
- "reason": .init(type: "string", description: "Why these tasks are duplicates and which was kept")
- ],
- required: ["keep_id", "delete_ids", "reason"]
- )
- )
- ],
- required: ["has_duplicates", "duplicate_groups"]
- )
-
- // Call Gemini
- let responseText: String
- do {
- responseText = try await client.sendRequest(
- prompt: prompt,
- systemPrompt: systemPrompt,
- responseSchema: responseSchema
- )
- } catch {
- log("TaskDedup: Gemini request failed: \(error)")
- return 0
- }
-
- // Parse response
- guard let data = responseText.data(using: .utf8) else {
- log("TaskDedup: Failed to convert response to data")
- return 0
- }
+ log("TaskDedup: Starting server-side deduplication")
- let result: DedupResponse
do {
- result = try JSONDecoder().decode(DedupResponse.self, from: data)
- } catch {
- log("TaskDedup: Failed to parse response: \(error)")
- return 0
- }
-
- guard result.hasDuplicates, !result.duplicateGroups.isEmpty else {
- log("TaskDedup: No duplicates found in batch of \(tasks.count) staged tasks")
- return 0
- }
+ let result = try await service.deduplicateTasks()
- // Validate and delete
- let validTaskIDs = Set(tasks.map { $0.id })
- let taskLookup = Dictionary(tasks.map { ($0.id, $0) }, uniquingKeysWith: { _, latest in latest })
- var deletedCount = 0
-
- for group in result.duplicateGroups {
- // Safety: verify all IDs exist in our input
- guard validTaskIDs.contains(group.keepId) else {
- log("TaskDedup: Skipping group - keep_id '\(group.keepId)' not in input set")
- continue
- }
-
- let validDeleteIds = group.deleteIds.filter { validTaskIDs.contains($0) }
- if validDeleteIds.count != group.deleteIds.count {
- log("TaskDedup: Some delete IDs not in input set, filtering")
+ if result.deletedIds.isEmpty {
+ log("TaskDedup: No duplicates found")
+ return
}
- guard !validDeleteIds.isEmpty else { continue }
-
- let keptTask = taskLookup[group.keepId]
+ log("TaskDedup: Server deleted \(result.deletedIds.count) duplicates. Reason: \(result.reason)")
- for deleteId in validDeleteIds {
- let deletedTask = taskLookup[deleteId]
-
- // Log to SQLite
+ // Log each deletion locally
+ for deleteId in result.deletedIds {
let logRecord = TaskDedupLogRecord(
deletedTaskId: deleteId,
- deletedDescription: deletedTask?.description ?? "unknown",
- keptTaskId: group.keepId,
- keptDescription: keptTask?.description ?? "unknown",
- reason: group.reason,
+ deletedDescription: "server-side dedup",
+ keptTaskId: "",
+ keptDescription: "",
+ reason: result.reason,
deletedAt: Date()
)
-
do {
try await ProactiveStorage.shared.insertDedupLogRecord(logRecord)
} catch {
log("TaskDedup: Failed to log deletion record: \(error)")
}
-
- // Hard-delete staged task from backend
- do {
- try await APIClient.shared.deleteStagedTask(id: deleteId)
- deletedCount += 1
- log("TaskDedup: Hard-deleted staged task '\(deletedTask?.description ?? deleteId)' (kept: '\(keptTask?.description ?? group.keepId)') - \(group.reason)")
- } catch {
- log("TaskDedup: Failed to delete staged task \(deleteId) on backend: \(error)")
- }
}
- }
-
- return deletedCount
- }
-}
-
-// MARK: - Response Models
-
-private struct DedupResponse: Codable {
- let hasDuplicates: Bool
- let duplicateGroups: [DuplicateGroup]
-
- enum CodingKeys: String, CodingKey {
- case hasDuplicates = "has_duplicates"
- case duplicateGroups = "duplicate_groups"
- }
-
- struct DuplicateGroup: Codable {
- let keepId: String
- let deleteIds: [String]
- let reason: String
-
- enum CodingKeys: String, CodingKey {
- case keepId = "keep_id"
- case deleteIds = "delete_ids"
- case reason
+ } catch {
+ log("TaskDedup: Server deduplication failed: \(error)")
}
}
}
diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskPrioritizationService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskPrioritizationService.swift
index a1aeca228f..ac922254ef 100644
--- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskPrioritizationService.swift
+++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskPrioritizationService.swift
@@ -7,7 +7,7 @@ import Foundation
actor TaskPrioritizationService {
static let shared = TaskPrioritizationService()
- private var geminiClient: GeminiClient?
+ private var backendService: BackendProactiveService?
private var timer: Task?
private var isRunning = false
private(set) var isScoringInProgress = false
@@ -29,13 +29,6 @@ actor TaskPrioritizationService {
// Restore persisted timestamps
self.lastFullRunTime = UserDefaults.standard.object(forKey: Self.fullRunKey) as? Date
- do {
- self.geminiClient = try GeminiClient(model: "gemini-pro-latest")
- } catch {
- log("TaskPrioritize: Failed to initialize GeminiClient: \(error)")
- self.geminiClient = nil
- }
-
if let last = self.lastFullRunTime {
let hoursAgo = Int(Date().timeIntervalSince(last) / 3600)
log("TaskPrioritize: Last full rescore was \(hoursAgo)h ago")
@@ -44,6 +37,11 @@ actor TaskPrioritizationService {
}
}
+ /// Set the backend service for Phase 2 server-side reranking.
+ func configure(backendService: BackendProactiveService) {
+ self.backendService = backendService
+ }
+
// MARK: - Lifecycle
func start() {
@@ -101,187 +99,48 @@ actor TaskPrioritizationService {
// MARK: - Full Rescore (Hourly)
- /// Send ALL staged tasks to Gemini, get back only the ones that need re-ranking
+ /// Request server-side reranking via backend WebSocket.
private func runFullRescore() async {
guard !isScoringInProgress else {
log("TaskPrioritize: [FULL] Skipping — scoring already in progress")
return
}
- guard let client = geminiClient else {
- log("TaskPrioritize: Skipping full rescore — Gemini client not initialized")
+ guard let service = backendService else {
+ log("TaskPrioritize: Skipping full rescore — backend service not configured")
return
}
isScoringInProgress = true
defer { isScoringInProgress = false }
- log("TaskPrioritize: [FULL] Starting hourly rescore of staged tasks")
+ log("TaskPrioritize: [FULL] Starting server-side rescore")
- // Get ALL staged tasks (not action_items)
- let allTasks: [TaskActionItem]
do {
- allTasks = try await StagedTaskStorage.shared.getAllStagedTasks(limit: 10000)
- } catch {
- log("TaskPrioritize: [FULL] Failed to fetch staged tasks: \(error)")
- return
- }
-
- log("TaskPrioritize: [FULL] Found \(allTasks.count) staged tasks")
+ let result = try await service.rerankTasks()
- guard allTasks.count >= minimumTaskCount else {
- log("TaskPrioritize: [FULL] Only \(allTasks.count) staged tasks, skipping")
- lastFullRunTime = Date()
- return
- }
-
- // Fetch context
- let (referenceContext, profile, goals) = await fetchContext()
-
- // Build the current ranking: tasks ordered by relevanceScore ASC (1 = top)
- let sortedTasks = allTasks.sorted { a, b in
- let scoreA = a.relevanceScore ?? Int.max
- let scoreB = b.relevanceScore ?? Int.max
- return scoreA < scoreB
- }
-
- // Build task list for the prompt with current positions
- let taskLines = sortedTasks.enumerated().map { (index, task) -> String in
- var parts = ["\(index + 1). [id:\(task.id)] \(task.description)"]
- if let priority = task.priority {
- parts.append("[\(priority)]")
- }
- if let due = task.dueAt {
- let formatter = ISO8601DateFormatter()
- parts.append("[due: \(formatter.string(from: due))]")
+ if result.updatedTasks.isEmpty {
+ log("TaskPrioritize: [FULL] No tasks need re-ranking, current order is good")
+ lastFullRunTime = Date()
+ return
}
- return parts.joined(separator: " ")
- }.joined(separator: "\n")
-
- // Build context sections
- var contextParts: [String] = []
- if let profile = profile, !profile.isEmpty {
- contextParts.append("USER PROFILE:\n\(profile)")
- }
+ // Parse server response into reranking tuples
+ let reranks: [(backendId: String, newPosition: Int)] = result.updatedTasks.compactMap { dict in
+ guard let id = dict["id"] as? String,
+ let newPos = dict["new_position"] as? Int else { return nil }
+ return (backendId: id, newPosition: newPos)
+ }
- if !goals.isEmpty {
- let goalsText = goals.enumerated().map { (i, goal) in
- var text = "\(i + 1). \(goal.title)"
- if let desc = goal.description {
- text += " — \(desc)"
+ if !reranks.isEmpty {
+ do {
+ try await StagedTaskStorage.shared.applySelectiveReranking(reranks)
+ log("TaskPrioritize: [FULL] Applied server re-ranking for \(reranks.count) staged tasks")
+ } catch {
+ log("TaskPrioritize: [FULL] Failed to apply re-ranking: \(error)")
}
- text += " (\(Int(goal.progress))% complete)"
- return text
- }.joined(separator: "\n")
- contextParts.append("ACTIVE GOALS:\n\(goalsText)")
- }
-
- if !referenceContext.isEmpty {
- contextParts.append(referenceContext)
- }
-
- let contextSection = contextParts.isEmpty ? "" : contextParts.joined(separator: "\n\n") + "\n\n"
-
- let prompt = """
- Review the user's staged task list (ranked 1 = most important, \(sortedTasks.count) = least important).
-
- Identify tasks that are MISRANKED — tasks whose current position doesn't match their actual importance.
- Only return tasks that need to move. Do NOT return tasks that are already well-positioned.
-
- Consider:
- 1. Alignment with the user's goals and current priorities
- 2. Time urgency (due date proximity)
- 3. Actionability — specific tasks rank higher than vague ones
- 4. Real-world importance (financial, health, commitments to others)
- 5. Most AI-extracted tasks are noise — push vague/irrelevant tasks down
-
- \(contextSection)CURRENT TASK RANKING (1 = most important):
- \(taskLines)
-
- Return ONLY the tasks that need re-ranking, with their new position numbers.
- New positions should be relative to the current list size (1 to \(sortedTasks.count)).
- """
-
- let systemPrompt = """
- You are a task prioritization assistant. You review a ranked task list and identify \
- tasks that are misranked. Be selective — only return tasks that genuinely need to move. \
- If the ranking looks reasonable, return an empty list. Be decisive about pushing noise \
- and vague tasks down and promoting urgent, goal-aligned tasks up.
- """
-
- let responseSchema = GeminiRequest.GenerationConfig.ResponseSchema(
- type: "object",
- properties: [
- "reranked_tasks": .init(
- type: "array",
- description: "Tasks that need to be moved, with new positions",
- items: .init(
- type: "object",
- properties: [
- "task_id": .init(type: "string", description: "The task ID"),
- "new_position": .init(type: "integer", description: "New rank position (1 = most important)")
- ],
- required: ["task_id", "new_position"]
- )
- ),
- "reasoning": .init(type: "string", description: "Brief explanation of major ranking changes")
- ],
- required: ["reranked_tasks", "reasoning"]
- )
-
- log("TaskPrioritize: [FULL] Sending \(sortedTasks.count) staged tasks to Gemini")
-
- let responseText: String
- do {
- responseText = try await client.sendRequest(
- prompt: prompt,
- systemPrompt: systemPrompt,
- responseSchema: responseSchema
- )
- } catch {
- log("TaskPrioritize: [FULL] Gemini request failed: \(error)")
- return
- }
-
- let truncated = responseText.prefix(500)
- log("TaskPrioritize: [FULL] Gemini response (\(responseText.count) chars): \(truncated)\(responseText.count > 500 ? "..." : "")")
-
- guard let data = responseText.data(using: .utf8) else {
- log("TaskPrioritize: [FULL] Failed to convert response to data")
- return
- }
-
- let result: ReRankingResponse
- do {
- result = try JSONDecoder().decode(ReRankingResponse.self, from: data)
- } catch {
- log("TaskPrioritize: [FULL] Failed to parse re-ranking response: \(error)")
- return
- }
-
- log("TaskPrioritize: [FULL] Gemini returned \(result.rerankedTasks.count) tasks to re-rank")
- if !result.reasoning.isEmpty {
- log("TaskPrioritize: [FULL] Reasoning: \(result.reasoning.prefix(300))")
- }
-
- // Validate: only keep task IDs that exist in our list
- let validIds = Set(allTasks.map { $0.id })
- let validReranks = result.rerankedTasks.filter { validIds.contains($0.taskId) }
-
- if validReranks.count != result.rerankedTasks.count {
- log("TaskPrioritize: [FULL] Filtered out \(result.rerankedTasks.count - validReranks.count) invalid task IDs")
- }
-
- if !validReranks.isEmpty {
- let reranks = validReranks.map { (backendId: $0.taskId, newPosition: $0.newPosition) }
- do {
- try await StagedTaskStorage.shared.applySelectiveReranking(reranks)
- log("TaskPrioritize: [FULL] Applied selective re-ranking for \(validReranks.count) staged tasks")
- } catch {
- log("TaskPrioritize: [FULL] Failed to apply re-ranking: \(error)")
}
- } else {
- log("TaskPrioritize: [FULL] No tasks need re-ranking, current order is good")
+ } catch {
+ log("TaskPrioritize: [FULL] Server reranking failed: \(error)")
}
lastFullRunTime = Date()
@@ -304,68 +163,4 @@ actor TaskPrioritizationService {
}
}
- // MARK: - Shared Context Fetching
-
- private func fetchContext() async -> (referenceContext: String, profile: String?, goals: [Goal]) {
- let userProfile = await AIUserProfileService.shared.getLatestProfile()
-
- let goals: [Goal]
- do {
- goals = try await APIClient.shared.getGoals()
- } catch {
- log("TaskPrioritize: Failed to fetch goals: \(error)")
- goals = []
- }
-
- let referenceTasks: [TaskActionItem]
- do {
- referenceTasks = try await ActionItemStorage.shared.getLocalActionItems(
- limit: 100,
- completed: true
- )
- } catch {
- log("TaskPrioritize: Failed to fetch reference tasks: \(error)")
- referenceTasks = []
- }
- let referenceContext = buildReferenceContext(referenceTasks)
-
- return (referenceContext, userProfile?.profileText, goals)
- }
-
- // MARK: - Context Builders
-
- private func buildReferenceContext(_ tasks: [TaskActionItem]) -> String {
- guard !tasks.isEmpty else { return "" }
-
- let completed = tasks.filter { !($0.description.isEmpty) }.prefix(50)
- guard !completed.isEmpty else { return "" }
-
- let lines = completed.map { task -> String in
- "- [completed] \(task.description)"
- }.joined(separator: "\n")
-
- return "TASKS THE USER HAS COMPLETED (for reference — do NOT rank these):\n\(lines)"
- }
-}
-
-// MARK: - Response Models
-
-private struct ReRankingResponse: Codable {
- let rerankedTasks: [ReRankedTask]
- let reasoning: String
-
- struct ReRankedTask: Codable {
- let taskId: String
- let newPosition: Int
-
- enum CodingKeys: String, CodingKey {
- case taskId = "task_id"
- case newPosition = "new_position"
- }
- }
-
- enum CodingKeys: String, CodingKey {
- case rerankedTasks = "reranked_tasks"
- case reasoning
- }
}
diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift
new file mode 100644
index 0000000000..1d473ba4c8
--- /dev/null
+++ b/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift
@@ -0,0 +1,592 @@
+import Foundation
+
+/// WebSocket client for desktop proactive AI via /v4/listen.
+/// Sends typed JSON messages (screen_frame, etc.) and routes typed responses
+/// (focus_result, etc.) back to callers via async continuations.
+///
+/// This is the Phase 2 replacement for direct GeminiClient calls — all LLM
+/// processing happens server-side; the client just sends screenshots and
+/// receives structured results.
+class BackendProactiveService {
+
+ // MARK: - Types
+
+ enum ServiceError: LocalizedError {
+ case missingAPIURL
+ case authFailed(String)
+ case notConnected
+ case timeout
+ case serverError(String)
+
+ var errorDescription: String? {
+ switch self {
+ case .missingAPIURL: return "OMI_API_URL not set"
+ case .authFailed(let reason): return "Auth failed: \(reason)"
+ case .notConnected: return "Backend WebSocket not connected"
+ case .timeout: return "Request timed out"
+ case .serverError(let msg): return "Server error: \(msg)"
+ }
+ }
+ }
+
+ // MARK: - Properties
+
+ private var webSocketTask: URLSessionWebSocketTask?
+ private var urlSession: URLSession?
+ private(set) var isConnected = false
+ private var shouldReconnect = false
+ private var reconnectAttempts = 0
+ private let maxReconnectAttempts = 10
+ private var reconnectTask: Task?
+
+ // Keepalive
+ private var keepaliveTask: Task?
+ private let keepaliveInterval: TimeInterval = 30.0
+
+ // Pending continuations keyed by frame_id (vision handlers)
+ private var pendingFocusRequests: [String: CheckedContinuation] = [:]
+ private var pendingTasksRequests: [String: CheckedContinuation] = [:]
+ private var pendingMemoriesRequests: [String: CheckedContinuation] = [:]
+ private var pendingAdviceRequests: [String: CheckedContinuation] = [:]
+
+ // Pending continuations for text-only handlers (one outstanding per type)
+ private var pendingLiveNote: CheckedContinuation?
+ private var pendingProfile: CheckedContinuation?
+ private var pendingRerank: CheckedContinuation?
+ private var pendingDedup: CheckedContinuation?
+
+ private let requestLock = NSLock()
+ private let requestTimeout: TimeInterval = 30.0
+ private let textRequestTimeout: TimeInterval = 60.0
+
+ // MARK: - Connection
+
+ func connect() {
+ shouldReconnect = true
+ reconnectAttempts = 0
+ startConnect()
+ }
+
+ func disconnect() {
+ shouldReconnect = false
+ reconnectTask?.cancel()
+ reconnectTask = nil
+ keepaliveTask?.cancel()
+ keepaliveTask = nil
+
+ isConnected = false
+ webSocketTask?.cancel(with: .normalClosure, reason: nil)
+ webSocketTask = nil
+ urlSession?.invalidateAndCancel()
+ urlSession = nil
+
+ cancelAllPending(error: ServiceError.notConnected)
+ log("BackendProactiveService: Disconnected")
+ }
+
+ // MARK: - Vision Handlers (screen_frame)
+
+ /// Send a screen_frame for focus analysis and wait for the focus_result response.
+ func analyzeFocus(imageBase64: String, appName: String, windowTitle: String) async throws -> ScreenAnalysis {
+ guard isConnected else { throw ServiceError.notConnected }
+ let frameId = UUID().uuidString
+ let jsonString = try buildScreenFrameJSON(frameId: frameId, analyzeTypes: ["focus"], imageBase64: imageBase64, appName: appName, windowTitle: windowTitle)
+
+ return try await withCheckedThrowingContinuation { continuation in
+ requestLock.lock()
+ pendingFocusRequests[frameId] = continuation
+ requestLock.unlock()
+ sendAndTimeout(jsonString: jsonString, frameId: frameId, timeout: requestTimeout,
+ remove: { self.pendingFocusRequests.removeValue(forKey: $0) })
+ }
+ }
+
+ /// Send a screen_frame for task extraction.
+ func extractTasks(imageBase64: String, appName: String, windowTitle: String) async throws -> TasksExtractedResult {
+ guard isConnected else { throw ServiceError.notConnected }
+ let frameId = UUID().uuidString
+ let jsonString = try buildScreenFrameJSON(frameId: frameId, analyzeTypes: ["tasks"], imageBase64: imageBase64, appName: appName, windowTitle: windowTitle)
+
+ return try await withCheckedThrowingContinuation { continuation in
+ requestLock.lock()
+ pendingTasksRequests[frameId] = continuation
+ requestLock.unlock()
+ sendAndTimeout(jsonString: jsonString, frameId: frameId, timeout: requestTimeout,
+ remove: { self.pendingTasksRequests.removeValue(forKey: $0) })
+ }
+ }
+
+ /// Send a screen_frame for memory extraction.
+ func extractMemories(imageBase64: String, appName: String, windowTitle: String) async throws -> MemoriesExtractedResult {
+ guard isConnected else { throw ServiceError.notConnected }
+ let frameId = UUID().uuidString
+ let jsonString = try buildScreenFrameJSON(frameId: frameId, analyzeTypes: ["memories"], imageBase64: imageBase64, appName: appName, windowTitle: windowTitle)
+
+ return try await withCheckedThrowingContinuation { continuation in
+ requestLock.lock()
+ pendingMemoriesRequests[frameId] = continuation
+ requestLock.unlock()
+ sendAndTimeout(jsonString: jsonString, frameId: frameId, timeout: requestTimeout,
+ remove: { self.pendingMemoriesRequests.removeValue(forKey: $0) })
+ }
+ }
+
+ /// Send a screen_frame for advice generation.
+ func generateAdvice(imageBase64: String, appName: String, windowTitle: String) async throws -> AdviceExtractedResult {
+ guard isConnected else { throw ServiceError.notConnected }
+ let frameId = UUID().uuidString
+ let jsonString = try buildScreenFrameJSON(frameId: frameId, analyzeTypes: ["advice"], imageBase64: imageBase64, appName: appName, windowTitle: windowTitle)
+
+ return try await withCheckedThrowingContinuation { continuation in
+ requestLock.lock()
+ pendingAdviceRequests[frameId] = continuation
+ requestLock.unlock()
+ sendAndTimeout(jsonString: jsonString, frameId: frameId, timeout: requestTimeout,
+ remove: { self.pendingAdviceRequests.removeValue(forKey: $0) })
+ }
+ }
+
+ // MARK: - Text-Only Handlers
+
+ /// Send transcript text for live note generation.
+ func generateLiveNote(text: String, sessionContext: String = "") async throws -> String {
+ guard isConnected else { throw ServiceError.notConnected }
+ let jsonString = try buildJSON(["type": "live_notes_text", "text": text, "session_context": sessionContext])
+
+ return try await withCheckedThrowingContinuation { continuation in
+ requestLock.lock()
+ pendingLiveNote = continuation
+ requestLock.unlock()
+ sendAndTimeoutSingle(jsonString: jsonString, timeout: textRequestTimeout,
+ remove: { let c = self.pendingLiveNote; self.pendingLiveNote = nil; return c })
+ }
+ }
+
+ /// Request profile generation (server fetches user data from Firestore).
+ func requestProfile() async throws -> String {
+ guard isConnected else { throw ServiceError.notConnected }
+ let jsonString = try buildJSON(["type": "profile_request"])
+
+ return try await withCheckedThrowingContinuation { continuation in
+ requestLock.lock()
+ pendingProfile = continuation
+ requestLock.unlock()
+ sendAndTimeoutSingle(jsonString: jsonString, timeout: textRequestTimeout,
+ remove: { let c = self.pendingProfile; self.pendingProfile = nil; return c })
+ }
+ }
+
+ /// Request task reranking (server fetches tasks from Firestore).
+ func rerankTasks() async throws -> RerankExtractedResult {
+ guard isConnected else { throw ServiceError.notConnected }
+ let jsonString = try buildJSON(["type": "task_rerank"])
+
+ return try await withCheckedThrowingContinuation { continuation in
+ requestLock.lock()
+ pendingRerank = continuation
+ requestLock.unlock()
+ sendAndTimeoutSingle(jsonString: jsonString, timeout: textRequestTimeout,
+ remove: { let c = self.pendingRerank; self.pendingRerank = nil; return c })
+ }
+ }
+
+ /// Request task deduplication (server fetches tasks from Firestore).
+ func deduplicateTasks() async throws -> DedupExtractedResult {
+ guard isConnected else { throw ServiceError.notConnected }
+ let jsonString = try buildJSON(["type": "task_dedup"])
+
+ return try await withCheckedThrowingContinuation { continuation in
+ requestLock.lock()
+ pendingDedup = continuation
+ requestLock.unlock()
+ sendAndTimeoutSingle(jsonString: jsonString, timeout: textRequestTimeout,
+ remove: { let c = self.pendingDedup; self.pendingDedup = nil; return c })
+ }
+ }
+
+ // MARK: - Send Helpers
+
+ private func buildScreenFrameJSON(frameId: String, analyzeTypes: [String], imageBase64: String, appName: String, windowTitle: String) throws -> String {
+ try buildJSON([
+ "type": "screen_frame",
+ "frame_id": frameId,
+ "image_b64": imageBase64,
+ "app_name": appName,
+ "window_title": windowTitle,
+ "analyze": analyzeTypes,
+ ])
+ }
+
+ private func buildJSON(_ dict: [String: Any]) throws -> String {
+ let data = try JSONSerialization.data(withJSONObject: dict)
+ guard let str = String(data: data, encoding: .utf8) else {
+ throw ServiceError.serverError("Failed to encode message")
+ }
+ return str
+ }
+
+ /// Send JSON and set up timeout for frame_id-keyed continuations.
+ private func sendAndTimeout(jsonString: String, frameId: String, timeout: TimeInterval,
+ remove: @escaping (String) -> CheckedContinuation?) {
+ webSocketTask?.send(.string(jsonString)) { [weak self] error in
+ if let error = error {
+ self?.requestLock.lock()
+ let cont = remove(frameId)
+ self?.requestLock.unlock()
+ cont?.resume(throwing: error)
+ }
+ }
+
+ Task { [weak self] in
+ try? await Task.sleep(nanoseconds: UInt64(timeout * 1_000_000_000))
+ self?.requestLock.lock()
+ let cont = remove(frameId)
+ self?.requestLock.unlock()
+ cont?.resume(throwing: ServiceError.timeout)
+ }
+ }
+
+ /// Send JSON and set up timeout for single-slot continuations.
+ private func sendAndTimeoutSingle(jsonString: String, timeout: TimeInterval,
+ remove: @escaping () -> CheckedContinuation?) {
+ webSocketTask?.send(.string(jsonString)) { [weak self] error in
+ if let error = error {
+ self?.requestLock.lock()
+ let cont = remove()
+ self?.requestLock.unlock()
+ cont?.resume(throwing: error)
+ }
+ }
+
+ Task { [weak self] in
+ try? await Task.sleep(nanoseconds: UInt64(timeout * 1_000_000_000))
+ self?.requestLock.lock()
+ let cont = remove()
+ self?.requestLock.unlock()
+ cont?.resume(throwing: ServiceError.timeout)
+ }
+ }
+
+ // MARK: - Connection Internals
+
+ private func startConnect() {
+ guard let baseURL = Self.getBaseURL() else {
+ log("BackendProactiveService: OMI_API_URL not set")
+ return
+ }
+
+ Task {
+ do {
+ let idToken = try await AuthService.shared.getIdToken()
+ await connectWithToken(baseURL: baseURL, token: idToken)
+ } catch {
+ logError("BackendProactiveService: Failed to get ID token", error: error)
+ handleDisconnection()
+ }
+ }
+ }
+
+ private func connectWithToken(baseURL: String, token: String) async {
+ let wsURL = baseURL
+ .replacingOccurrences(of: "https://", with: "wss://")
+ .replacingOccurrences(of: "http://", with: "ws://")
+ let base = wsURL.hasSuffix("/") ? wsURL : wsURL + "/"
+
+ var components = URLComponents(string: "\(base)v4/listen")!
+ components.queryItems = [
+ URLQueryItem(name: "source", value: "desktop"),
+ URLQueryItem(name: "sample_rate", value: "16000"),
+ URLQueryItem(name: "codec", value: "pcm16"),
+ URLQueryItem(name: "channels", value: "1"),
+ URLQueryItem(name: "language", value: "en"),
+ ]
+
+ guard let url = components.url else {
+ log("BackendProactiveService: Invalid URL")
+ return
+ }
+
+ log("BackendProactiveService: Connecting to \(url.absoluteString)")
+
+ var request = URLRequest(url: url)
+ request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
+ request.timeoutInterval = 30
+
+ let configuration = URLSessionConfiguration.default
+ configuration.timeoutIntervalForResource = 0
+ urlSession = URLSession(configuration: configuration)
+ webSocketTask = urlSession?.webSocketTask(with: request)
+ webSocketTask?.resume()
+
+ receiveMessage()
+
+ DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) { [weak self] in
+ guard let self = self, self.webSocketTask?.state == .running else {
+ self?.handleDisconnection()
+ return
+ }
+ self.isConnected = true
+ self.reconnectAttempts = 0
+ self.startKeepalive()
+ log("BackendProactiveService: Connected")
+ }
+ }
+
+ private func startKeepalive() {
+ keepaliveTask?.cancel()
+ keepaliveTask = Task { [weak self] in
+ while !Task.isCancelled {
+ try? await Task.sleep(nanoseconds: UInt64((self?.keepaliveInterval ?? 30.0) * 1_000_000_000))
+ guard !Task.isCancelled, let self = self, self.isConnected else { break }
+ self.sendKeepalive()
+ }
+ }
+ }
+
+ private func sendKeepalive() {
+ guard isConnected, let ws = webSocketTask else { return }
+ ws.send(.string("{\"type\": \"KeepAlive\"}")) { [weak self] error in
+ if let error = error {
+ logError("BackendProactiveService: Keepalive error", error: error)
+ self?.handleDisconnection()
+ }
+ }
+ }
+
+ private func handleDisconnection() {
+ guard isConnected || shouldReconnect else { return }
+
+ isConnected = false
+ keepaliveTask?.cancel()
+ keepaliveTask = nil
+ webSocketTask?.cancel(with: .goingAway, reason: nil)
+ webSocketTask = nil
+ urlSession?.invalidateAndCancel()
+ urlSession = nil
+
+ cancelAllPending(error: ServiceError.notConnected)
+
+ if shouldReconnect && reconnectAttempts < maxReconnectAttempts {
+ reconnectAttempts += 1
+ let delay = min(pow(2.0, Double(reconnectAttempts)), 32.0)
+ log("BackendProactiveService: Reconnecting in \(delay)s (attempt \(reconnectAttempts))")
+
+ reconnectTask = Task {
+ try? await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000))
+ guard !Task.isCancelled, self.shouldReconnect else { return }
+ self.startConnect()
+ }
+ } else if reconnectAttempts >= maxReconnectAttempts {
+ log("BackendProactiveService: Max reconnect attempts reached")
+ }
+ }
+
+ // MARK: - Message Handling
+
+ private func receiveMessage() {
+ webSocketTask?.receive { [weak self] result in
+ guard let self = self else { return }
+
+ switch result {
+ case .success(let message):
+ self.handleMessage(message)
+ self.receiveMessage()
+ case .failure(let error):
+ guard self.isConnected else { return }
+ logError("BackendProactiveService: Receive error", error: error)
+ self.handleDisconnection()
+ }
+ }
+ }
+
+ private func handleMessage(_ message: URLSessionWebSocketTask.Message) {
+ let text: String
+ switch message {
+ case .string(let s):
+ text = s
+ case .data(let data):
+ guard let s = String(data: data, encoding: .utf8) else { return }
+ text = s
+ @unknown default:
+ return
+ }
+
+ if text == "ping" { return }
+
+ guard let data = text.data(using: .utf8),
+ let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
+ let type = json["type"] as? String else {
+ return
+ }
+
+ switch type {
+ case "focus_result":
+ handleFocusResult(json)
+ case "tasks_extracted":
+ handleTasksExtracted(json)
+ case "memories_extracted":
+ handleMemoriesExtracted(json)
+ case "advice_extracted":
+ handleAdviceExtracted(json)
+ case "live_note":
+ handleLiveNote(json)
+ case "profile_updated":
+ handleProfileUpdated(json)
+ case "rerank_complete":
+ handleRerankComplete(json)
+ case "dedup_complete":
+ handleDedupComplete(json)
+ default:
+ break
+ }
+ }
+
+ // MARK: - Response Handlers
+
+ private func handleFocusResult(_ json: [String: Any]) {
+ guard let frameId = json["frame_id"] as? String else { return }
+ let analysis = ScreenAnalysis(
+ status: FocusStatus(rawValue: json["status"] as? String ?? "focused") ?? .focused,
+ appOrSite: json["app_or_site"] as? String ?? "",
+ description: json["description"] as? String ?? "",
+ message: json["message"] as? String
+ )
+ requestLock.lock()
+ let cont = pendingFocusRequests.removeValue(forKey: frameId)
+ requestLock.unlock()
+ cont?.resume(returning: analysis)
+ }
+
+ private func handleTasksExtracted(_ json: [String: Any]) {
+ guard let frameId = json["frame_id"] as? String else { return }
+ let tasks = (json["tasks"] as? [[String: Any]]) ?? []
+ let result = TasksExtractedResult(frameId: frameId, tasks: tasks)
+ requestLock.lock()
+ let cont = pendingTasksRequests.removeValue(forKey: frameId)
+ requestLock.unlock()
+ cont?.resume(returning: result)
+ }
+
+ private func handleMemoriesExtracted(_ json: [String: Any]) {
+ guard let frameId = json["frame_id"] as? String else { return }
+ let memories = (json["memories"] as? [[String: Any]]) ?? []
+ let result = MemoriesExtractedResult(frameId: frameId, memories: memories)
+ requestLock.lock()
+ let cont = pendingMemoriesRequests.removeValue(forKey: frameId)
+ requestLock.unlock()
+ cont?.resume(returning: result)
+ }
+
+ private func handleAdviceExtracted(_ json: [String: Any]) {
+ guard let frameId = json["frame_id"] as? String else { return }
+ let result = AdviceExtractedResult(frameId: frameId, advice: json["advice"])
+ requestLock.lock()
+ let cont = pendingAdviceRequests.removeValue(forKey: frameId)
+ requestLock.unlock()
+ cont?.resume(returning: result)
+ }
+
+ private func handleLiveNote(_ json: [String: Any]) {
+ let text = json["text"] as? String ?? ""
+ requestLock.lock()
+ let cont = pendingLiveNote
+ pendingLiveNote = nil
+ requestLock.unlock()
+ cont?.resume(returning: text)
+ }
+
+ private func handleProfileUpdated(_ json: [String: Any]) {
+ let profileText = json["profile_text"] as? String ?? ""
+ requestLock.lock()
+ let cont = pendingProfile
+ pendingProfile = nil
+ requestLock.unlock()
+ cont?.resume(returning: profileText)
+ }
+
+ private func handleRerankComplete(_ json: [String: Any]) {
+ let updatedTasks = (json["updated_tasks"] as? [[String: Any]]) ?? []
+ let result = RerankExtractedResult(updatedTasks: updatedTasks)
+ requestLock.lock()
+ let cont = pendingRerank
+ pendingRerank = nil
+ requestLock.unlock()
+ cont?.resume(returning: result)
+ }
+
+ private func handleDedupComplete(_ json: [String: Any]) {
+ let deletedIds = (json["deleted_ids"] as? [String]) ?? []
+ let reason = json["reason"] as? String ?? ""
+ let result = DedupExtractedResult(deletedIds: deletedIds, reason: reason)
+ requestLock.lock()
+ let cont = pendingDedup
+ pendingDedup = nil
+ requestLock.unlock()
+ cont?.resume(returning: result)
+ }
+
+ // MARK: - Helpers
+
+ private func cancelAllPending(error: Error) {
+ requestLock.lock()
+ let focus = pendingFocusRequests; pendingFocusRequests.removeAll()
+ let tasks = pendingTasksRequests; pendingTasksRequests.removeAll()
+ let memories = pendingMemoriesRequests; pendingMemoriesRequests.removeAll()
+ let advice = pendingAdviceRequests; pendingAdviceRequests.removeAll()
+ let liveNote = pendingLiveNote; pendingLiveNote = nil
+ let profile = pendingProfile; pendingProfile = nil
+ let rerank = pendingRerank; pendingRerank = nil
+ let dedup = pendingDedup; pendingDedup = nil
+ requestLock.unlock()
+
+ for (_, c) in focus { c.resume(throwing: error) }
+ for (_, c) in tasks { c.resume(throwing: error) }
+ for (_, c) in memories { c.resume(throwing: error) }
+ for (_, c) in advice { c.resume(throwing: error) }
+ liveNote?.resume(throwing: error)
+ profile?.resume(throwing: error)
+ rerank?.resume(throwing: error)
+ dedup?.resume(throwing: error)
+ }
+
+ private static func getBaseURL() -> String? {
+ if let cString = getenv("OMI_API_URL"), let url = String(validatingUTF8: cString), !url.isEmpty {
+ return url
+ }
+ if let envURL = ProcessInfo.processInfo.environment["OMI_API_URL"], !envURL.isEmpty {
+ return envURL
+ }
+ return nil
+ }
+}
+
+// MARK: - Result Types
+
+/// Tasks extracted from a screen_frame analysis.
+struct TasksExtractedResult {
+ let frameId: String
+ let tasks: [[String: Any]] // Raw task dicts from backend
+}
+
+/// Memories extracted from a screen_frame analysis.
+struct MemoriesExtractedResult {
+ let frameId: String
+ let memories: [[String: Any]] // Raw memory dicts from backend
+}
+
+/// Advice extracted from a screen_frame analysis.
+struct AdviceExtractedResult {
+ let frameId: String
+ let advice: Any? // Raw advice from backend (dict or null)
+}
+
+/// Task reranking result.
+struct RerankExtractedResult {
+ let updatedTasks: [[String: Any]] // [{id, new_position}, ...]
+}
+
+/// Task deduplication result.
+struct DedupExtractedResult {
+ let deletedIds: [String]
+ let reason: String
+}
diff --git a/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift b/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift
index bd00065057..0a5c31a0b4 100644
--- a/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift
+++ b/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift
@@ -14,6 +14,7 @@ public class ProactiveAssistantsPlugin: NSObject {
private var screenCaptureService: ScreenCaptureService?
private var windowMonitor: WindowMonitor?
+ private var backendProactiveService: BackendProactiveService?
private var focusAssistant: FocusAssistant?
/// Public read-only accessor for memory diagnostics
@@ -319,61 +320,64 @@ public class ProactiveAssistantsPlugin: NSObject {
// Initialize services
screenCaptureService = ScreenCaptureService()
- do {
- focusAssistant = try FocusAssistant(
- onAlert: { [weak self] message in
- self?.sendEvent(type: "alert", data: ["message": message])
- },
- onStatusChange: { [weak self] status in
- Task { @MainActor in
- self?.lastStatus = status
- self?.sendEvent(type: "statusChange", data: ["status": status.rawValue])
- }
- },
- onRefocus: {
- Task { @MainActor in
- OverlayService.shared.showGlowAroundActiveWindow(colorMode: .focused)
- }
- },
- onDistraction: {
- Task { @MainActor in
- OverlayService.shared.showGlowAroundActiveWindow(colorMode: .distracted)
- }
+ // Start backend proactive AI WebSocket (Phase 2 — server-side LLM)
+ let proactiveService = BackendProactiveService()
+ proactiveService.connect()
+ backendProactiveService = proactiveService
+
+ focusAssistant = FocusAssistant(
+ backendService: proactiveService,
+ onAlert: { [weak self] message in
+ self?.sendEvent(type: "alert", data: ["message": message])
+ },
+ onStatusChange: { [weak self] status in
+ Task { @MainActor in
+ self?.lastStatus = status
+ self?.sendEvent(type: "statusChange", data: ["status": status.rawValue])
+ }
+ },
+ onRefocus: {
+ Task { @MainActor in
+ OverlayService.shared.showGlowAroundActiveWindow(colorMode: .focused)
+ }
+ },
+ onDistraction: {
+ Task { @MainActor in
+ OverlayService.shared.showGlowAroundActiveWindow(colorMode: .distracted)
}
- )
-
- if let focus = focusAssistant {
- AssistantCoordinator.shared.register(focus)
}
+ )
- taskAssistant = try TaskAssistant()
+ if let focus = focusAssistant {
+ AssistantCoordinator.shared.register(focus)
+ }
- if let task = taskAssistant {
- AssistantCoordinator.shared.register(task)
- }
+ taskAssistant = TaskAssistant(backendService: proactiveService)
- Task { await TaskDeduplicationService.shared.start() }
- Task { await TaskPrioritizationService.shared.start() }
- Task { await TaskPromotionService.shared.start() }
+ if let task = taskAssistant {
+ AssistantCoordinator.shared.register(task)
+ }
- adviceAssistant = try AdviceAssistant()
+ // Configure text-only services with backend service
+ Task { await TaskDeduplicationService.shared.configure(backendService: proactiveService) }
+ Task { await TaskPrioritizationService.shared.configure(backendService: proactiveService) }
+ Task { await AIUserProfileService.shared.configure(backendService: proactiveService) }
+ Task { await LiveNotesMonitor.shared.configure(backendService: proactiveService) }
- if let advice = adviceAssistant {
- AssistantCoordinator.shared.register(advice)
- }
+ Task { await TaskDeduplicationService.shared.start() }
+ Task { await TaskPrioritizationService.shared.start() }
+ Task { await TaskPromotionService.shared.start() }
- memoryAssistant = try MemoryAssistant()
+ adviceAssistant = AdviceAssistant(backendService: proactiveService)
- if let memory = memoryAssistant {
- AssistantCoordinator.shared.register(memory)
- }
+ if let advice = adviceAssistant {
+ AssistantCoordinator.shared.register(advice)
+ }
- } catch {
- log("ProactiveAssistantsPlugin: Failed to initialize assistants: \(error.localizedDescription)")
- logError("ProactiveAssistantsPlugin: Assistant initialization failed", error: error)
- isStartingMonitoring = false
- completion(false, error.localizedDescription)
- return
+ memoryAssistant = MemoryAssistant(backendService: proactiveService)
+
+ if let memory = memoryAssistant {
+ AssistantCoordinator.shared.register(memory)
}
// Get initial app state
@@ -459,6 +463,8 @@ public class ProactiveAssistantsPlugin: NSObject {
}
}
+ backendProactiveService?.disconnect()
+ backendProactiveService = nil
focusAssistant = nil
taskAssistant = nil
adviceAssistant = nil
diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Services/AIUserProfileService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Services/AIUserProfileService.swift
index ec36f82ece..111615bf75 100644
--- a/desktop/Desktop/Sources/ProactiveAssistants/Services/AIUserProfileService.swift
+++ b/desktop/Desktop/Sources/ProactiveAssistants/Services/AIUserProfileService.swift
@@ -34,7 +34,7 @@ extension AIUserProfileRecord: TableDocumented {
actor AIUserProfileService {
static let shared = AIUserProfileService()
- private let model = "gemini-pro-latest"
+ private var backendService: BackendProactiveService?
private let maxProfileLength = 10000
/// Whether profile generation is currently in progress
@@ -48,6 +48,11 @@ actor AIUserProfileService {
_dbQueue = nil
}
+ /// Set the backend service for Phase 2 server-side profile generation.
+ func configure(backendService: BackendProactiveService) {
+ self.backendService = backendService
+ }
+
// MARK: - Database Access
private func ensureDB() async throws -> DatabasePool {
@@ -211,314 +216,42 @@ actor AIUserProfileService {
}) ?? []
}
- /// Generate a new AI user profile from all available data sources
+ /// Generate a new AI user profile via backend WebSocket.
+ /// The backend fetches all user data from Firestore and generates the profile server-side.
func generateProfile() async throws -> AIUserProfileRecord {
guard !isGenerating else {
throw ProfileError.alreadyGenerating
}
+ guard let service = backendService else {
+ throw ProfileError.databaseNotAvailable
+ }
isGenerating = true
defer { isGenerating = false }
- log("AIUserProfileService: Starting profile generation")
-
- // 1. Fetch all data sources in parallel
- let (memories, tasks, goals, conversations, messages) = await fetchDataSources()
-
- // 2. Count total data items
- let dataSourcesUsed = memories.count + tasks.count + goals.count + conversations.count + messages.count
- log("AIUserProfileService: Fetched \(dataSourcesUsed) data items (memories=\(memories.count), tasks=\(tasks.count), goals=\(goals.count), convos=\(conversations.count), messages=\(messages.count))")
-
- guard dataSourcesUsed > 0 else {
- throw ProfileError.insufficientData
- }
-
- // 3. Build prompt
- let prompt = buildPrompt(memories: memories, tasks: tasks, goals: goals, conversations: conversations, messages: messages)
-
- // 4. Call Gemini
- let gemini = try GeminiClient(model: model)
- let systemPrompt = """
- You are generating a structured user profile that will be injected as context into AI pipelines \
- (task extraction, goal extraction, memory extraction) that analyze the user's screen and audio activity.
-
- OUTPUT FORMAT:
- - A flat list of factual statements, one per line, prefixed with "- "
- - Each statement must be a concrete fact directly supported by the provided data
- - No prose, no paragraphs, no headers, no markdown formatting
- - No adjectives like "passionate", "dedicated", "impressive"
- - Write in third person ("User works at...", not "You work at...")
-
- WHAT TO INCLUDE (only if clearly supported by the data):
- - Full name, role, company, industry
- - Current projects and what tools/apps they use for each
- - Key people they interact with (names, roles, relationship)
- - Active goals and their progress
- - Recurring meetings, deadlines, routines
- - Communication platforms they use (Slack, email, iMessage, etc.)
- - Technical stack, programming languages, frameworks
- - Topics they frequently discuss or research
- - Pending tasks and commitments to others
- - Time zone, work schedule patterns
-
- CRITICAL RULES:
- - ONLY include facts that are directly evidenced in the provided data
- - If a category has no supporting data, skip it entirely — do not guess or infer
- - Do NOT hallucinate names, roles, companies, or relationships not present in the data
- - Do NOT add personality descriptions or subjective assessments
- - When uncertain, omit rather than speculate
- - NEVER fabricate email addresses, phone numbers, URLs, or contact information
- - The provided data contains NO email addresses — do not invent any
- - If you cannot find a piece of information verbatim in the data, do not include it
-
- The output MUST be under 2000 characters total.
- """
-
- let stageOneText = try await gemini.sendTextRequest(prompt: prompt, systemPrompt: systemPrompt)
- log("AIUserProfileService: Stage 1 complete (\(stageOneText.count) chars)")
-
- // 5. Stage 2 — Consolidate with past profiles for holistic view
- let pastProfiles = await getAllProfiles(limit: 5)
- let finalText: String
- if pastProfiles.isEmpty {
- finalText = stageOneText
- } else {
- let consolidationPrompt = buildConsolidationPrompt(
- newProfile: stageOneText,
- pastProfiles: pastProfiles
- )
- let consolidationSystemPrompt = """
- You are merging a newly generated user profile with historical profiles to create \
- one holistic, up-to-date user profile. This profile is injected as context into AI pipelines \
- (task extraction, goal extraction, memory extraction) that analyze the user's screen and audio activity.
-
- OUTPUT FORMAT:
- - A flat list of factual statements, one per line, prefixed with "- "
- - Each statement must be a concrete fact
- - No prose, no paragraphs, no headers, no markdown formatting
- - No adjectives or subjective assessments
- - Write in third person
-
- MERGE RULES:
- - The NEW profile reflects today's data and takes priority for current state
- - Past profiles provide historical context — retain facts that are still relevant
- - If a fact from the past contradicts the new profile, use the new one
- - Remove outdated information (completed tasks, past deadlines, old routines)
- - Keep stable facts (name, role, company, key relationships, tech stack)
- - Accumulate knowledge: if past profiles mention people, projects, or patterns \
- not in today's data, keep them if they seem ongoing
- - Do NOT hallucinate — only include facts present in the provided profiles
- - Do NOT add commentary about changes or evolution over time
-
- The output MUST be under 2000 characters total.
- """
- finalText = try await gemini.sendTextRequest(
- prompt: consolidationPrompt,
- systemPrompt: consolidationSystemPrompt
- )
- log("AIUserProfileService: Stage 2 consolidation complete (\(finalText.count) chars)")
- }
+ log("AIUserProfileService: Requesting server-side profile generation")
- // 6. Truncate if needed
- let truncated = String(finalText.prefix(maxProfileLength))
+ let profileText = try await service.requestProfile()
+ let truncated = String(profileText.prefix(maxProfileLength))
let generatedAt = Date()
- // 6. Save to database
+ log("AIUserProfileService: Received profile from backend (\(truncated.count) chars)")
+
+ // Save to local database
let db = try await ensureDB()
let record = AIUserProfileRecord(
profileText: truncated,
- dataSourcesUsed: dataSourcesUsed,
- backendSynced: false,
+ dataSourcesUsed: 0,
+ backendSynced: true, // Backend already has it
generatedAt: generatedAt
)
try await db.write { database in
try record.insert(database)
}
- // 7. Sync to backend (fire-and-forget)
- let recordId = record.id
- Task {
- do {
- try await APIClient.shared.syncAIUserProfile(
- profileText: truncated,
- generatedAt: generatedAt,
- dataSourcesUsed: dataSourcesUsed
- )
- // Mark as synced
- if let id = recordId, let db = try? await self.ensureDB() {
- _ = try? await db.write { database in
- try database.execute(
- sql: "UPDATE ai_user_profiles SET backendSynced = 1 WHERE id = ?",
- arguments: [id]
- )
- }
- }
- log("AIUserProfileService: Synced profile to backend")
- } catch {
- log("AIUserProfileService: Failed to sync profile to backend: \(error.localizedDescription)")
- }
- }
-
- log("AIUserProfileService: Profile generated successfully (\(truncated.count) chars, \(dataSourcesUsed) data items)")
+ log("AIUserProfileService: Profile saved to local DB")
return record
}
- // MARK: - Data Fetching
-
- private func fetchDataSources() async -> (
- memories: [String],
- tasks: [String],
- goals: [String],
- conversations: [String],
- messages: [String]
- ) {
- async let memoriesTask = fetchMemories()
- async let tasksTask = fetchTasks()
- async let goalsTask = fetchGoals()
- async let conversationsTask = fetchConversations()
- async let messagesTask = fetchMessages()
-
- let memories = await memoriesTask
- let tasks = await tasksTask
- let goals = await goalsTask
- let conversations = await conversationsTask
- let messages = await messagesTask
-
- return (memories, tasks, goals, conversations, messages)
- }
-
- private func fetchMemories() async -> [String] {
- do {
- let memories = try await APIClient.shared.getMemories(limit: 100)
- return memories.map { "[\($0.category.rawValue)] \($0.content)" }
- } catch {
- log("AIUserProfileService: Failed to fetch memories: \(error.localizedDescription)")
- return []
- }
- }
-
- private func fetchTasks() async -> [String] {
- do {
- let response = try await APIClient.shared.getActionItems(limit: 50)
- return response.items.map { item in
- let status = item.completed ? "done" : "todo"
- let priority = item.priority ?? "medium"
- return "[\(status)/\(priority)] \(item.description)"
- }
- } catch {
- log("AIUserProfileService: Failed to fetch tasks: \(error.localizedDescription)")
- return []
- }
- }
-
- private func fetchGoals() async -> [String] {
- do {
- let goals = try await APIClient.shared.getGoals()
- return goals.filter { $0.isActive }.map { goal in
- let progress = goal.targetValue > 0 ? Int((goal.currentValue / goal.targetValue) * 100) : 0
- return "\(goal.title) (\(progress)% complete)"
- }
- } catch {
- log("AIUserProfileService: Failed to fetch goals: \(error.localizedDescription)")
- return []
- }
- }
-
- private func fetchConversations() async -> [String] {
- do {
- let sevenDaysAgo = Calendar.current.date(byAdding: .day, value: -7, to: Date())
- let conversations = try await APIClient.shared.getConversations(
- limit: 20,
- startDate: sevenDaysAgo
- )
- return conversations.compactMap { convo in
- let title = convo.structured.title
- let summary = convo.structured.overview
- guard !title.isEmpty else { return nil }
- return "\(title): \(summary)"
- }
- } catch {
- log("AIUserProfileService: Failed to fetch conversations: \(error.localizedDescription)")
- return []
- }
- }
-
- private func fetchMessages() async -> [String] {
- do {
- let messages = try await APIClient.shared.getMessages(limit: 30)
- return messages.map { "[\($0.sender)] \($0.text)" }
- } catch {
- log("AIUserProfileService: Failed to fetch messages: \(error.localizedDescription)")
- return []
- }
- }
-
- // MARK: - Prompt Building
-
- private func buildPrompt(
- memories: [String],
- tasks: [String],
- goals: [String],
- conversations: [String],
- messages: [String]
- ) -> String {
- var sections: [String] = []
-
- if !memories.isEmpty {
- sections.append("## Memories about the user\n\(memories.joined(separator: "\n"))")
- }
-
- if !tasks.isEmpty {
- sections.append("## Recent tasks\n\(tasks.joined(separator: "\n"))")
- }
-
- if !goals.isEmpty {
- sections.append("## Active goals\n\(goals.joined(separator: "\n"))")
- }
-
- if !conversations.isEmpty {
- sections.append("## Recent conversations (past 7 days)\n\(conversations.joined(separator: "\n"))")
- }
-
- if !messages.isEmpty {
- sections.append("## Recent AI chat messages\n\(messages.joined(separator: "\n"))")
- }
-
- return """
- Generate a factual user profile from the following data. \
- Output a flat list of concrete facts (one per line, prefixed with "- "). \
- This profile will be used as context for AI pipelines that analyze the user's screen and audio activity \
- to extract tasks, goals, and memories. Focus on facts that help identify who is who, what projects are active, \
- and what the user's current priorities are. Under 2000 characters.
-
- \(sections.joined(separator: "\n\n"))
- """
- }
-
- private func buildConsolidationPrompt(
- newProfile: String,
- pastProfiles: [AIUserProfileRecord]
- ) -> String {
- let dateFormatter = DateFormatter()
- dateFormatter.dateStyle = .medium
- dateFormatter.timeStyle = .none
-
- var pastSection = ""
- for profile in pastProfiles {
- let dateStr = dateFormatter.string(from: profile.generatedAt)
- pastSection += "--- Profile from \(dateStr) ---\n\(profile.profileText)\n\n"
- }
-
- return """
- Merge the following into one holistic user profile. Under 2000 characters.
-
- === NEW PROFILE (generated today from latest data) ===
- \(newProfile)
-
- === PAST PROFILES (oldest to newest, up to 5) ===
- \(pastSection)
- """
- }
-
// MARK: - Errors
enum ProfileError: LocalizedError {
diff --git a/desktop/Desktop/Sources/ProactiveAssistants/UI/AdviceTestRunnerWindow.swift b/desktop/Desktop/Sources/ProactiveAssistants/UI/AdviceTestRunnerWindow.swift
index 925c73f312..a87b4582b5 100644
--- a/desktop/Desktop/Sources/ProactiveAssistants/UI/AdviceTestRunnerWindow.swift
+++ b/desktop/Desktop/Sources/ProactiveAssistants/UI/AdviceTestRunnerWindow.swift
@@ -441,17 +441,9 @@ struct AdviceTestRunnerView: View {
adviceAssistant = existing
log("AdviceTestRunner: Using existing AdviceAssistant from coordinator")
} else {
- do {
- adviceAssistant = try AdviceAssistant()
- log("AdviceTestRunner: Created fresh AdviceAssistant instance")
- } catch {
- log("AdviceTestRunner: ERROR - Failed to create AdviceAssistant: \(error)")
- await MainActor.run {
- statusMessage = "Failed to create Advice Assistant: \(error.localizedDescription)"
- isRunning = false
- }
- return
- }
+ let service = BackendProactiveService(); service.connect()
+ adviceAssistant = AdviceAssistant(backendService: service)
+ log("AdviceTestRunner: Created fresh AdviceAssistant instance")
}
// Get excluded apps
@@ -647,12 +639,8 @@ enum AdviceTestRunner {
if let existing = coordAssistant as? AdviceAssistant {
adviceAssistant = existing
} else {
- do {
- adviceAssistant = try AdviceAssistant()
- } catch {
- log("AdviceTestCLI: ERROR — Failed to create AdviceAssistant: \(error)")
- return
- }
+ let service = BackendProactiveService(); service.connect()
+ adviceAssistant = AdviceAssistant(backendService: service)
}
// Get excluded apps
diff --git a/desktop/Desktop/Sources/ProactiveAssistants/UI/FocusTestRunnerWindow.swift b/desktop/Desktop/Sources/ProactiveAssistants/UI/FocusTestRunnerWindow.swift
index ca5c21034a..0ae19fd68e 100644
--- a/desktop/Desktop/Sources/ProactiveAssistants/UI/FocusTestRunnerWindow.swift
+++ b/desktop/Desktop/Sources/ProactiveAssistants/UI/FocusTestRunnerWindow.swift
@@ -637,12 +637,9 @@ enum FocusTestRunner {
if let existing = coordAssistant as? FocusAssistant {
focusAssistant = existing
} else {
- do {
- focusAssistant = try FocusAssistant()
- } catch {
- log("FocusTestCLI: ERROR — Failed to create FocusAssistant: \(error)")
- return
- }
+ let service = BackendProactiveService()
+ service.connect()
+ focusAssistant = FocusAssistant(backendService: service)
}
// Get excluded apps
diff --git a/desktop/Desktop/Sources/Stores/TasksStore.swift b/desktop/Desktop/Sources/Stores/TasksStore.swift
index d9fba16493..b5252c600f 100644
--- a/desktop/Desktop/Sources/Stores/TasksStore.swift
+++ b/desktop/Desktop/Sources/Stores/TasksStore.swift
@@ -448,8 +448,6 @@ class TasksStore: ObservableObject {
// Then retry pushing any locally-created tasks that failed to sync
Task {
await performFullSyncIfNeeded()
- await migrateAITasksToStagedIfNeeded()
- await migrateConversationItemsToStagedIfNeeded()
await retryUnsyncedItems()
}
// Backfill relevance scores for unscored tasks (independent of full sync)
@@ -808,69 +806,6 @@ class TasksStore: ObservableObject {
}
}
- /// In-memory guard to prevent duplicate migration calls within the same app session
- private static var isMigrating = false
-
- /// One-time migration: tell backend to move excess AI tasks to staged_tasks subcollection.
- /// The SQLite migration handles local data; this handles Firestore.
- /// Sets the flag optimistically before the request to avoid retry loops on timeout.
- private func migrateAITasksToStagedIfNeeded() async {
- let userId = UserDefaults.standard.string(forKey: "auth_userId") ?? "unknown"
- let migrationKey = "stagedTasksMigrationCompleted_v4_\(userId)"
-
- guard !UserDefaults.standard.bool(forKey: migrationKey) else {
- log("TasksStore: Staged tasks migration already completed for user \(userId)")
- return
- }
-
- // In-memory guard: loadTasks() can be called from multiple pages
- guard !Self.isMigrating else {
- log("TasksStore: Staged tasks migration already in progress, skipping")
- return
- }
- Self.isMigrating = true
-
- // Set flag optimistically — the migration is idempotent and safe to skip on re-run.
- // This prevents infinite retry loops when the backend succeeds but the client times out.
- UserDefaults.standard.set(true, forKey: migrationKey)
-
- log("TasksStore: Starting staged tasks backend migration for user \(userId)")
-
- do {
- try await APIClient.shared.migrateStagedTasks()
- log("TasksStore: Staged tasks backend migration completed")
- } catch {
- log("TasksStore: Staged tasks backend migration fired (may complete in background): \(error.localizedDescription)")
- }
- Self.isMigrating = false
- }
-
- /// One-time migration of conversation-extracted action items (no source field) to staged_tasks.
- /// These were created by the old save_action_items path that bypassed the staging pipeline.
- private func migrateConversationItemsToStagedIfNeeded() async {
- let userId = UserDefaults.standard.string(forKey: "auth_userId") ?? "unknown"
- let migrationKey = "conversationItemsMigrationCompleted_v4_\(userId)"
-
- guard !UserDefaults.standard.bool(forKey: migrationKey) else { return }
-
- UserDefaults.standard.set(true, forKey: migrationKey)
- log("TasksStore: Starting conversation items migration for user \(userId)")
-
- do {
- try await APIClient.shared.migrateConversationItemsToStaged()
- log("TasksStore: Conversation items migration completed, resetting full sync to clean up local SQLite")
-
- // Reset full sync flag so it re-runs and marks migrated items as staged locally
- let syncKey = "tasksFullSyncCompleted_v9_\(userId)"
- UserDefaults.standard.set(false, forKey: syncKey)
-
- // Run full sync now to clean up local SQLite
- await performFullSyncIfNeeded()
- } catch {
- log("TasksStore: Conversation items migration fired (may complete in background): \(error.localizedDescription)")
- }
- }
-
/// Retry syncing locally-created tasks that failed to push to the backend.
/// These are records with backendSynced=false and no backendId — the API call
/// failed during extraction and there was no retry mechanism.
diff --git a/desktop/Desktop/Sources/TranscriptionRetryService.swift b/desktop/Desktop/Sources/TranscriptionRetryService.swift
index ead4ee2ee5..913c6fd945 100644
--- a/desktop/Desktop/Sources/TranscriptionRetryService.swift
+++ b/desktop/Desktop/Sources/TranscriptionRetryService.swift
@@ -266,9 +266,7 @@ class TranscriptionRetryService {
startedAt: session.startedAt,
finishedAt: session.finishedAt ?? Date(),
source: source,
- language: session.language,
- timezone: session.timezone,
- inputDeviceName: session.inputDeviceName
+ language: session.language
)
log("TranscriptionRetryService: Session \(sessionId) uploaded successfully (backendId: \(response.id))")
diff --git a/desktop/dev.sh b/desktop/dev.sh
index 44fd56a73e..a9739facd7 100755
--- a/desktop/dev.sh
+++ b/desktop/dev.sh
@@ -85,8 +85,12 @@ cp Desktop/Info.plist "$APP_BUNDLE/Contents/Info.plist"
/usr/libexec/PlistBuddy -c "Set :CFBundleDisplayName $APP_NAME" "$APP_BUNDLE/Contents/Info.plist"
/usr/libexec/PlistBuddy -c "Set :CFBundleURLTypes:0:CFBundleURLSchemes:0 omi-computer-dev" "$APP_BUNDLE/Contents/Info.plist"
-# Copy GoogleService-Info.plist for Firebase
-cp Desktop/Sources/GoogleService-Info.plist "$APP_BUNDLE/Contents/Resources/"
+# Copy GoogleService-Info.plist for Firebase (dev version for com.omi.desktop-dev)
+if [ -f "Desktop/Sources/GoogleService-Info-Dev.plist" ]; then
+ cp -f Desktop/Sources/GoogleService-Info-Dev.plist "$APP_BUNDLE/Contents/Resources/GoogleService-Info.plist"
+else
+ cp -f Desktop/Sources/GoogleService-Info.plist "$APP_BUNDLE/Contents/Resources/"
+fi
# Copy resource bundle (contains app assets like herologo.png, omi-with-rope-no-padding.webp, etc.)
SWIFT_BUILD_DIR="Desktop/.build/debug"
diff --git a/desktop/reset-and-run.sh b/desktop/reset-and-run.sh
index 1590b862fd..c52c28a10d 100755
--- a/desktop/reset-and-run.sh
+++ b/desktop/reset-and-run.sh
@@ -381,8 +381,12 @@ cp Desktop/Info.plist "$APP_BUNDLE/Contents/Info.plist"
/usr/libexec/PlistBuddy -c "Set :CFBundleDisplayName $APP_NAME" "$APP_BUNDLE/Contents/Info.plist"
/usr/libexec/PlistBuddy -c "Set :CFBundleURLTypes:0:CFBundleURLSchemes:0 omi-computer-dev" "$APP_BUNDLE/Contents/Info.plist"
-# Copy GoogleService-Info.plist for Firebase
-cp Desktop/Sources/GoogleService-Info.plist "$APP_BUNDLE/Contents/Resources/"
+# Copy GoogleService-Info.plist for Firebase (dev version for com.omi.desktop-dev)
+if [ -f "Desktop/Sources/GoogleService-Info-Dev.plist" ]; then
+ cp -f Desktop/Sources/GoogleService-Info-Dev.plist "$APP_BUNDLE/Contents/Resources/GoogleService-Info.plist"
+else
+ cp -f Desktop/Sources/GoogleService-Info.plist "$APP_BUNDLE/Contents/Resources/"
+fi
# Copy .env.app (app runtime secrets only) and add API URL
if [ -f ".env.app" ]; then