Skip to content

Commit 6ee5c90

Browse files
committed
fix (memory): onboarding memory fixes
1 parent 3e3736c commit 6ee5c90

File tree

5 files changed

+185
-49
lines changed

5 files changed

+185
-49
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import { NextResponse } from "next/server"
2+
import { withAuth } from "@lib/api-utils"
3+
4+
// Self-hosted deployments reach the app server through the internal URL;
// every other environment uses the public one.
const appServerUrl =
    process.env.NEXT_PUBLIC_ENVIRONMENT === "selfhost"
        ? process.env.INTERNAL_APP_SERVER_URL
        : process.env.NEXT_PUBLIC_APP_SERVER_URL

/**
 * POST handler that forwards the request to the app server's
 * `/testing/reprocess-onboarding` endpoint, passing along the `authHeader`
 * supplied by `withAuth`.
 *
 * On success the upstream JSON body is returned unchanged. On an upstream
 * error the response carries `{ detail }` together with the upstream HTTP
 * status, so a 404 or 403 from the app server is not flattened into a 500.
 * Network failures and unexpected exceptions still yield a 500.
 */
export const POST = withAuth(async function POST(request, { authHeader }) {
    try {
        const response = await fetch(
            `${appServerUrl}/testing/reprocess-onboarding`,
            {
                method: "POST",
                headers: { "Content-Type": "application/json", ...authHeader }
            }
        )

        if (!response.ok) {
            // An error body is not guaranteed to be JSON (for example an
            // HTML page from a proxy), so parse defensively and keep the
            // generic message when no `detail` can be read.
            let detail = "Failed to trigger reprocessing."
            try {
                const errorBody = await response.json()
                if (errorBody && errorBody.detail) {
                    detail = errorBody.detail
                }
            } catch (parseError) {
                // Fall through with the generic message.
            }
            console.error(
                "API Error in /testing/reprocess-onboarding:",
                response.status,
                detail
            )
            return NextResponse.json({ detail }, { status: response.status })
        }

        const data = await response.json()
        return NextResponse.json(data)
    } catch (error) {
        console.error("API Error in /testing/reprocess-onboarding:", error)
        return NextResponse.json({ detail: error.message }, { status: 500 })
    }
})

src/client/app/settings/page.js

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ import {
1515
IconKeyboard,
1616
IconFlask,
1717
IconMapPin,
18-
IconBrandWhatsapp
18+
IconBrandWhatsapp,
19+
IconRefresh
1920
} from "@tabler/icons-react"
2021
import { useState, useEffect, useCallback } from "react"
2122
import { Tooltip } from "react-tooltip"
@@ -371,6 +372,7 @@ const TestingTools = () => {
371372
'{\n "subject": "Project Alpha Kick-off",\n "body": "Hi team, let\'s schedule a meeting for next Tuesday to discuss the Project Alpha kick-off. John, please prepare the presentation."\n}'
372373
)
373374
const [isSubmitting, setIsSubmitting] = useState(false)
375+
const [isReprocessing, setIsReprocessing] = useState(false)
374376

375377
const handleSubmit = async (e) => {
376378
e.preventDefault()
@@ -420,6 +422,29 @@ const TestingTools = () => {
420422
)
421423
}
422424
}
425+
426+
// Asks the backend to re-queue the saved onboarding answers for memory
// processing. A single loading toast is updated in place with the outcome,
// and the button's busy flag is cleared whether the request succeeds or not.
const handleReprocessOnboarding = async () => {
    setIsReprocessing(true)
    const toastId = toast.loading(
        "Queueing onboarding data for memory reprocessing..."
    )
    try {
        const res = await fetch("/api/testing/reprocess-onboarding", {
            method: "POST"
        })
        const payload = await res.json()
        if (res.ok) {
            toast.success(payload.message, { id: toastId })
            return
        }
        // Non-2xx: surface the server-provided detail when there is one.
        throw new Error(payload.detail || "Failed to trigger reprocessing.")
    } catch (err) {
        toast.error(`Error: ${err.message}`, { id: toastId })
    } finally {
        setIsReprocessing(false)
    }
}
423448
const [isTriggeringScheduler, setIsTriggeringScheduler] = useState(false)
424449
const [isTriggeringPoller, setIsTriggeringPoller] = useState(false)
425450

@@ -597,6 +622,32 @@ const TestingTools = () => {
597622
</div>
598623
</form>
599624
</div>
625+
{/* Reprocess Onboarding Data */}
626+
<div className="bg-neutral-900/50 p-6 rounded-2xl border border-neutral-800 mt-6">
627+
<h3 className="font-semibold text-lg text-white mb-2">
628+
Reprocess Onboarding Data
629+
</h3>
630+
<p className="text-gray-400 text-sm mb-4">
631+
Manually trigger the Celery worker to process your saved
632+
onboarding answers and add them to your long-term memory.
633+
This is useful for testing memory functions without
634+
re-onboarding.
635+
</p>
636+
<div className="flex justify-end">
637+
<button
638+
onClick={handleReprocessOnboarding}
639+
disabled={isReprocessing}
640+
className="flex items-center py-2 px-4 rounded-md bg-purple-600 hover:bg-purple-500 text-white font-medium transition-colors disabled:opacity-50"
641+
>
642+
{isReprocessing ? (
643+
<IconLoader className="w-5 h-5 animate-spin" />
644+
) : (
645+
<IconRefresh className="w-5 h-5" />
646+
)}{" "}
647+
Run Reprocessing
648+
</button>
649+
</div>
650+
</div>
600651
{/* WhatsApp Test Tools */}
601652
<div className="bg-neutral-900/50 p-6 rounded-2xl border border-neutral-800 mt-6">
602653
<h3 className="font-semibold text-lg text-white mb-2">

src/server/main/testing/routes.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import List
55

66
from fastapi import APIRouter, Depends, HTTPException, status, Body
7+
from fastapi.responses import JSONResponse
78

89
from main.config import ENVIRONMENT
910
from main.dependencies import auth_helper
@@ -13,9 +14,11 @@
1314
from workers.tasks import (cud_memory_task, run_due_tasks,
1415
schedule_trigger_polling)
1516

16-
from .models import ContextInjectionRequest, WhatsAppTestRequest, TestNotificationRequest
17+
from .models import WhatsAppTestRequest, TestNotificationRequest
18+
1719

1820
logger = logging.getLogger(__name__)
21+
from main.dependencies import mongo_manager
1922
router = APIRouter(
2023
prefix="/testing",
2124
tags=["Testing Utilities"]
@@ -159,4 +162,53 @@ async def verify_whatsapp_number(
159162
logger.error(f"Error verifying WhatsApp number for user {user_id}: {e}", exc_info=True)
160163
if isinstance(e, HTTPException):
161164
raise e
162-
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
165+
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
166+
@router.post("/reprocess-onboarding", summary="Manually re-process onboarding data into memories")
async def reprocess_onboarding_data(user_id: str = Depends(auth_helper.get_current_user_id)):
    """
    Queue the authenticated user's saved onboarding answers for memory processing.

    Every recognised, non-empty onboarding answer is rendered into a
    one-sentence fact and handed to ``cud_memory_task`` with
    ``source="onboarding_reprocess"``. Environment gating is delegated to
    ``_check_allowed_environments`` (defined elsewhere in this module).

    Args:
        user_id: ID of the calling user, resolved by the auth dependency.

    Returns:
        A ``JSONResponse`` whose ``message`` states how many facts were queued.

    Raises:
        HTTPException: 404 when the profile has no ``userData`` or no
            ``onboardingAnswers``; 500 for any other failure.
    """
    _check_allowed_environments(
        ["dev-local", "selfhost"],
        "This endpoint is only available in development or self-host environments."
    )
    try:
        user_profile = await mongo_manager.get_user_profile(user_id)
        if not user_profile or "userData" not in user_profile:
            raise HTTPException(status_code=404, detail="User profile not found.")

        onboarding_data = user_profile["userData"].get("onboardingAnswers")
        if not onboarding_data:
            raise HTTPException(status_code=404, detail="No onboarding data found for this user.")

        fact_templates = {
            "user-name": "The user's name is {}.",
            "location": "The user's location is around '{}'.",  # Simplified for string location
            "timezone": "The user's timezone is {}.",
            "professional-context": "Professionally, the user has shared: {}",
            "personal-context": "Personally, the user is interested in: {}",
        }

        onboarding_facts = []
        for key, value in onboarding_data.items():
            if not value or key not in fact_templates:
                continue

            if key == "location" and isinstance(value, dict):
                latitude = value.get("latitude")
                longitude = value.get("longitude")
                # A coordinate pair is only meaningful with both components;
                # skip it rather than store a fact containing "None".
                if latitude is None or longitude is None:
                    continue
                onboarding_facts.append(
                    f"The user's location is at latitude {latitude}, longitude {longitude}."
                )
            elif isinstance(value, str) and value.strip():
                onboarding_facts.append(fact_templates[key].format(value))

        for fact in onboarding_facts:
            cud_memory_task.delay(user_id, fact, source="onboarding_reprocess")

        return JSONResponse(content={"message": f"Successfully queued {len(onboarding_facts)} facts from onboarding data for memory processing."})

    except HTTPException:
        # The 404s raised above are deliberate responses, not failures: let
        # them through unchanged instead of logging and rewrapping them as 500.
        raise
    except Exception as e:
        logger.error(f"Error reprocessing onboarding data for user {user_id}: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e

src/server/mcp_hub/memory/db.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
# Dictionary to store connection pools, keyed by the event loop they belong to.
3232
_pools: Dict[asyncio.AbstractEventLoop, asyncpg.Pool] = {}
33+
_db_setup_lock = asyncio.Lock()
3334

3435
async def get_db_pool() -> asyncpg.Pool:
3536
"""Initializes and returns a singleton PostgreSQL connection pool for the current event loop."""
@@ -40,21 +41,28 @@ async def get_db_pool() -> asyncpg.Pool:
4041
pool = _pools.get(loop)
4142

4243
# If the pool doesn't exist for this loop, or it's closing, create a new one.
43-
if pool is None or pool.is_closing():
44-
if not all([POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_HOST, POSTGRES_PORT, POSTGRES_DB]):
45-
raise ValueError("PostgreSQL connection details are not configured in the environment.")
46-
47-
logger.info(f"Initializing PostgreSQL connection pool for db: {POSTGRES_DB} on event loop {id(loop)}.")
48-
pool = await asyncpg.create_pool(
49-
user=POSTGRES_USER,
50-
password=POSTGRES_PASSWORD,
51-
database=POSTGRES_DB,
52-
host=POSTGRES_HOST,
53-
port=POSTGRES_PORT,
54-
)
55-
_pools[loop] = pool # Store the new pool in the dictionary
56-
else:
57-
logger.debug(f"Returning existing PostgreSQL connection pool for event loop {id(loop)}.")
44+
async with _db_setup_lock:
45+
if pool is None or pool.is_closing():
46+
if not all([POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_HOST, POSTGRES_PORT, POSTGRES_DB]):
47+
raise ValueError("PostgreSQL connection details are not configured in the environment.")
48+
49+
logger.info(f"Initializing PostgreSQL connection pool for db: {POSTGRES_DB} on event loop {id(loop)}.")
50+
try:
51+
pool = await asyncpg.create_pool(
52+
user=POSTGRES_USER,
53+
password=POSTGRES_PASSWORD,
54+
database=POSTGRES_DB,
55+
host=POSTGRES_HOST,
56+
port=POSTGRES_PORT,
57+
)
58+
await setup_database(pool) # Run setup right after creating the pool
59+
_pools[loop] = pool # Store the new pool in the dictionary
60+
logger.info("PostgreSQL connection pool and schema initialized successfully.")
61+
except Exception as e:
62+
logger.error(f"Failed to create PostgreSQL connection pool: {e}", exc_info=True)
63+
raise
64+
else:
65+
logger.debug(f"Returning existing PostgreSQL connection pool for event loop {id(loop)}.")
5866
return pool
5967

6068
async def close_db_pool_for_loop(loop: asyncio.AbstractEventLoop):
@@ -82,9 +90,8 @@ async def close_db_pool():
8290
else:
8391
logger.debug("No PostgreSQL connection pools to close.")
8492

85-
async def setup_database():
93+
async def setup_database(pool: asyncpg.Pool):
8694
"""Ensures all necessary tables, indexes, and static data are created in the database."""
87-
pool = await get_db_pool()
8895
async with pool.acquire() as connection:
8996
logger.info("Acquired DB connection for database setup.")
9097
async with connection.transaction():

src/server/workers/tasks.py

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -122,55 +122,53 @@ def run_async(coro):
122122
loop.close()
123123
asyncio.set_event_loop(None)
124124

125-
@celery_app.task(name="cud_memory_task")
126-
def cud_memory_task(user_id: str, information: str, source: Optional[str] = None):
127-
"""
128-
Celery task wrapper for the CUD (Create, Update, Delete) memory operation.
129-
This runs the core memory management logic asynchronously.
130-
"""
131-
logger.info(f"Celery worker received cud_memory_task for user_id: {user_id}")
125+
async def async_cud_memory_task(user_id: str, information: str, source: Optional[str] = None):
    """
    Run the CUD (Create, Update, Delete) memory operation for one user.

    Steps, in order:
      1. Load the user's profile and resolve the name to pass to
         ``cud_memory`` (falls back to ``user_id``).
      2. Unless the plan's ``memories_total`` limit is infinite, count the
         user's rows in ``facts``; if the limit is reached, notify the user
         and stop.
      3. Initialise the embedding model and agents, then await ``cud_memory``.

    Any exception raised in steps 1-2 is logged and the CUD operation still
    runs, using whatever name was resolved before the failure.

    Args:
        user_id: ID of the user whose memories are being modified.
        information: Text handed to ``cud_memory``.
        source: Optional origin tag, passed through to ``cud_memory``.
    """
    db_manager = MongoManager()
    username = user_id  # Default fallback
    try:
        user_profile = await db_manager.get_user_profile(user_id)
        user_data = (user_profile or {}).get("userData") or {}

        # --- Fetch user's name before calling cud_memory ---
        # Resolved before the limit check so a failure there (for example
        # the facts database being unreachable) does not discard a name that
        # was already loaded. `or user_id` also covers a present-but-empty
        # name. The name comes from personalInfo, which is set during
        # onboarding and can be updated in settings.
        username = (user_data.get("personalInfo") or {}).get("name") or user_id

        # --- Enforce Memory Limit ---
        # A missing or null plan is treated as "free".
        plan = user_data.get("plan") or "free"
        limit = PLAN_LIMITS[plan].get("memories_total", 0)

        if limit != float('inf'):
            from mcp_hub.memory import db as memory_db
            pool = await memory_db.get_db_pool()
            async with pool.acquire() as conn:
                current_count = await conn.fetchval("SELECT COUNT(*) FROM facts WHERE user_id = $1", user_id)
            if current_count >= limit:
                logger.warning(f"User {user_id} on '{plan}' plan reached memory limit of {limit}. CUD operation aborted.")
                await notify_user(user_id, f"You've reached your memory limit of {limit} facts. Please upgrade to Pro for unlimited memories.")
                return

    except Exception as e:
        logger.error(f"Error during pre-CUD setup for user {user_id}: {e}", exc_info=True)
        # We can still proceed with the CUD operation, just using the fallback name.
    finally:
        await db_manager.close()

    # Initialize models required for the CUD operation
    initialize_embedding_model()
    initialize_agents()
    # Pass the fetched username to the cud_memory function
    await cud_memory(user_id, information, source, username)
161+
162+
@celery_app.task(name="cud_memory_task")
def cud_memory_task(user_id: str, information: str, source: Optional[str] = None):
    """
    Celery entry point for the CUD (Create, Update, Delete) memory operation.

    Logs receipt of the task, then drives ``async_cud_memory_task`` to
    completion through ``run_async``. Nothing is returned.
    """
    logger.info(f"Celery worker received cud_memory_task for user_id: {user_id}")
    # The whole asynchronous workflow is built as one coroutine and handed
    # to a single run_async call, so every await inside it runs under that
    # one invocation.
    pending_cud = async_cud_memory_task(user_id, information, source)
    run_async(pending_cud)
174172

175173
@celery_app.task(name="orchestrate_swarm_task")
176174
def orchestrate_swarm_task(task_id: str, user_id: str):

0 commit comments

Comments
 (0)