Skip to content

Commit f0936a8

Browse files
committed
memory usage optimization
1 parent d65753e commit f0936a8

File tree

7 files changed

+180
-59
lines changed

7 files changed

+180
-59
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,7 @@ htmlcov/
6262
*.sqlite3
6363

6464
# Vector database
65-
chroma_db/
65+
chroma_db/
66+
67+
# Cursor
68+
.cursor/

app.py

Lines changed: 133 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
"""
77

88
import logging
9+
import os
910

1011
import chainlit as cl
12+
import psutil # For memory tracking
1113
from langchain_core.runnables.config import RunnableConfig
1214
from limits import parse
1315
from limits.storage import MemoryStorage
@@ -32,6 +34,38 @@
3234
# logging.getLogger("openai").setLevel(logging.WARNING)
3335
logger = logging.getLogger(__name__)
3436

37+
# --- Memory Management Constants ---
38+
# Maximum number of message pairs (user+assistant) to keep in memory
39+
MAX_HISTORY_LENGTH = settings.MAX_HISTORY_LENGTH
40+
41+
42+
def trim_message_history(session_id: str) -> None:
    """Trim a session's message history once it grows past the cap.

    Keeps only the most recent MAX_HISTORY_LENGTH exchanges (each exchange
    is a user + assistant message pair) so long conversations do not
    accumulate unbounded memory.

    Args:
        session_id: The ID of the current session
    """
    try:
        messages = cl.user_session.get("message_history", [])
        limit = MAX_HISTORY_LENGTH * 2  # two messages per exchange
        if len(messages) <= limit:
            return
        # Drop everything except the most recent messages.
        messages = messages[-limit:]
        cl.user_session.set("message_history", messages)
        logger.info(
            f"Trimmed message history for session {session_id} "
            f"to {len(messages)} messages"
        )
    except Exception as e:
        # Trimming is best-effort; never let it break message handling.
        logger.warning(f"Failed to trim message history: {e}")
67+
68+
3569
# --- Global Initialization ---
3670
# Declare placeholders for global objects
3771
prompt_manager = None
@@ -88,6 +122,57 @@ def get_session_id():
88122
# Remove the slowapi limiter instance for messages
89123
# message_limiter = Limiter(key_func=get_session_id) # REMOVED
90124

125+
126+
# --- Helper Functions for Message Processing ---
async def check_initialization() -> bool:
    """Report whether global initialization succeeded.

    Sends an error message to the user and returns False when the app
    failed to initialize; returns True otherwise.
    """
    if INITIALIZATION_SUCCESSFUL:
        return True
    await cl.ErrorMessage(content="Application not initialized.").send()
    return False
133+
134+
135+
async def get_translation_service():
    """Fetch the TranslationService stored in the user session.

    Returns:
        The service instance, or None (after notifying the user) when the
        session has no translation service registered.
    """
    translation_service = cl.user_session.get("translation_service")
    if translation_service:
        return translation_service
    logger.error("TranslationService not found in user session.")
    await cl.ErrorMessage(
        content="Error: Translation service unavailable. "
        "Please restart the chat."
    ).send()
    return None
146+
147+
148+
async def perform_translation(service, message_content, config):
    """Run the translation, wrapping it in a progress step outside debug mode.

    In debug mode the Langchain callback handler renders its own steps, so
    no extra cl.Step is opened around the call.
    """
    if not settings.DEBUG:
        # Show a simple progress indicator while translating
        # (config carries an empty callbacks list in this branch).
        async with cl.Step(name="Translating..."):
            return await service.translate_text(message_content, config=config)
    # Debug mode: the callback handler manages step visibility itself.
    return await service.translate_text(message_content, config=config)
158+
159+
160+
async def log_memory_usage(session_id):
    """Emit the current process resident-set size (in MB) for monitoring."""
    try:
        rss_bytes = psutil.Process(os.getpid()).memory_info().rss
        rss_mb = rss_bytes / 1024 / 1024
        logger.info(f"Memory usage: {rss_mb:.2f} MB for session {session_id}")
        # Warn as we approach the 512MB deployment limit.
        if rss_mb > 400:
            logger.warning(f"High memory usage detected: {rss_mb:.2f} MB")
    except Exception as e:
        # Monitoring must never break message handling.
        logger.error(f"Failed to log memory usage: {e}")
174+
175+
91176
# --- Chainlit Event Handlers ---
92177

93178

@@ -119,77 +204,63 @@ async def start():
119204
# @message_limiter.limit("5/minute") # REMOVED Decorator
120205
async def on_message(message: cl.Message):
121206
"""Handle incoming text messages and provide translations."""
122-
# --- MANUAL Rate Limit Check (using 'limits' library directly) --- <<< CORRECTED
207+
# --- Rate Limit Check ---
123208
session_id = get_session_id()
124-
# Use the limits strategy's hit() method. It returns False if the limit is exceeded.
125209
if not message_limit_strategy.hit(message_rate_limit, session_id):
126210
# Limit exceeded
127211
logger.warning(f"Rate limit exceeded for session {session_id}")
128212
await cl.ErrorMessage(
129213
content="Rate limit exceeded (5 messages per minute). Please wait a moment."
130214
).send()
131-
return # Stop processing this message
132-
# --- End of Rate Limit Check ---
215+
return
133216

134-
# Proceed with message handling only if the rate limit check passed
135-
try:
136-
# REMOVED await message_limiter.hit("5/minute", get_session_id())
217+
# --- Memory Management ---
218+
trim_message_history(session_id)
137219

138-
if not INITIALIZATION_SUCCESSFUL:
139-
await cl.ErrorMessage(content="Application not initialized.").send()
140-
return
220+
# Track this message in history
221+
history = cl.user_session.get("message_history", [])
222+
history.append({"role": "user", "content": message.content})
223+
cl.user_session.set("message_history", history)
141224

142-
service = cl.user_session.get("translation_service")
225+
try:
226+
# Basic validations
227+
if not await check_initialization():
228+
return
143229

230+
service = await get_translation_service()
144231
if not service:
145-
logger.error("TranslationService not found in user session.")
146-
await cl.ErrorMessage(
147-
content="Error: Translation service unavailable. "
148-
"Please restart the chat."
149-
).send()
150232
return
151233

152234
if not message.content:
153235
logger.warning("Received empty message.")
154-
return # Ignore empty messages
236+
return
155237

156-
# Conditionally add the callback handler for step visibility
238+
# Setup for translation
157239
callbacks = []
158240
if settings.DEBUG:
159241
callbacks.append(cl.LangchainCallbackHandler())
160-
logger.info(
161-
"Debug enabled: Adding LangchainCallbackHandler for step visibility."
162-
)
242+
logger.info("Debug enabled: Adding LangchainCallbackHandler.")
163243

164244
config = RunnableConfig(callbacks=callbacks)
165245

166-
# Use the service to translate, passing the config (with or without callbacks)
167-
if settings.DEBUG:
168-
# When debugging, let the callback handler manage steps
169-
translation_result = await service.translate_text(
170-
message.content, config=config
171-
)
172-
else:
173-
# When not debugging, show a simple progress step
174-
async with cl.Step(name="Translating..."):
175-
# Config will have empty callbacks list here
176-
translation_result = await service.translate_text(
177-
message.content, config=config
178-
)
179-
# Optionally set step output
180-
# (might be redundant if result is sent immediately after)
181-
# step.output = translation_result
182-
183-
# Send the final translation result
246+
# Perform translation
247+
translation_result = await perform_translation(service, message.content, config)
248+
249+
# Send result
184250
await cl.Message(content=f"Translation: {translation_result}").send()
185251

252+
# Update history
253+
history = cl.user_session.get("message_history", [])
254+
history.append(
255+
{"role": "assistant", "content": f"Translation: {translation_result}"}
256+
)
257+
cl.user_session.set("message_history", history)
258+
186259
except TranslationError as e:
187260
logger.error(
188261
f"Translation failed for '{message.content[:50]}...': {e}", exc_info=False
189-
) # exc_info=False to avoid redundant stack trace from service layer
190-
await cl.ErrorMessage(
191-
content=f"Sorry, translation failed: {e}"
192-
).send() # Show specific error if safe
262+
)
263+
await cl.ErrorMessage(content=f"Sorry, translation failed: {e}").send()
193264
except AppError as e:
194265
logger.error(
195266
f"Service error during translation for '{message.content[:50]}...': {e}",
@@ -198,20 +269,30 @@ async def on_message(message: cl.Message):
198269
await cl.ErrorMessage(
199270
content="Sorry, an application error occurred during translation."
200271
).send()
201-
except Exception as e: # Catch other potential exceptions from the core logic
202-
# This generic catch might now be redundant if specific errors are handled
203-
# but kept for safety, ensuring RateLimitExceeded is handled first.
272+
except Exception as e:
204273
logger.error(
205274
f"Unexpected error during translation for '{message.content[:50]}...': {e}",
206275
exc_info=True,
207276
)
208277
await cl.ErrorMessage(
209-
content=(
210-
"Sorry, an unexpected error occurred during translation. "
211-
"Please try again."
212-
)
278+
content="Sorry, an unexpected error occurred during translation."
213279
).send()
280+
finally:
281+
# Log memory usage
282+
await log_memory_usage(session_id)
283+
284+
285+
@cl.on_chat_end
async def on_chat_end():
    """Release per-session state when a chat session ends."""
    try:
        sid = cl.context.session.id  # captured for log correlation
        logger.info(f"Cleaning up resources for ending session {sid}")
        # Clearing the session store frees accumulated history and services.
        cl.user_session.clear()
        logger.info(f"Successfully cleaned up resources for session {sid}")
    except Exception as e:
        # Cleanup failures should be visible but must not propagate.
        logger.error(f"Error during session cleanup: {e}", exc_info=True)

config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ class Settings(BaseSettings):
3838
# 4. More user-friendly for development with automatic persistence
3939
CHROMA_PERSIST_DIRECTORY: str = "chroma_db"
4040

41+
# --- Memory Management Configuration ---
42+
# Maximum number of documents to retrieve for context (smaller = less memory)
43+
MAX_RETRIEVAL_DOCS: int = 3
44+
# Number of documents to process in a batch during vector store creation
45+
VECTORSTORE_BATCH_SIZE: int = 50
46+
# Maximum history length (in message pairs) for chat sessions
47+
MAX_HISTORY_LENGTH: int = 15
48+
4149
# --- Prompt Configuration ---
4250
PROMPTS_DIR: str = "prompts"
4351
SYSTEM_PROMPT_FILE: str = "system.md"

core/data_loader.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,16 @@ def _create_vector_store(documents: List[Document], api_key: str) -> VectorStore
171171
# Ensure directory exists
172172
os.makedirs(settings.CHROMA_PERSIST_DIRECTORY, exist_ok=True)
173173

174+
# Add batch_size parameter to control memory usage during indexing
174175
vector_store = Chroma.from_documents(
175176
documents=documents,
176177
embedding=embedding_model,
177178
persist_directory=settings.CHROMA_PERSIST_DIRECTORY,
179+
collection_metadata={
180+
"hnsw:space": "cosine"
181+
}, # More efficient distance calculation
182+
# Process documents in smaller batches to reduce peak memory usage
183+
batch_size=settings.VECTORSTORE_BATCH_SIZE,
178184
)
179185
# Persist to disk
180186
vector_store.persist()

core/translator.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def __init__(
6767
streaming=True, # Enable streaming by default if needed later
6868
)
6969

70-
self.retriever = self._create_retriever()
70+
# Create retriever with configured number of documents to limit memory usage
71+
self.retriever = self._create_retriever(k=settings.MAX_RETRIEVAL_DOCS)
7172
self.chain = self._build_rag_chain()
7273
logger.info("ArgentinianTranslator initialized successfully.")
7374

@@ -123,7 +124,18 @@ def replace_match(match):
123124
def _create_retriever(self, k: int = 3):
    """Create an MMR retriever from the vector store.

    Args:
        k: Number of documents to return per query.

    Returns:
        A retriever configured for maximal-marginal-relevance search,
        which de-duplicates results and keeps retrieved context small.
    """
    logger.debug(f"Creating retriever with k={k}")
    # FIX: search_type must be passed to as_retriever() itself, not nested
    # inside search_kwargs — nested there it is never read, so MMR was
    # silently not activated and plain similarity search ran instead.
    return self.vector_store.as_retriever(
        search_type="mmr",
        search_kwargs={
            "k": k,
            # Candidate pool fetched before MMR re-ranking; kept small
            # to limit memory usage.
            "fetch_k": k * 3,
            # LangChain convention: 0 = maximum diversity, 1 = minimum
            # diversity (the original comment had this inverted).
            "lambda_mult": 0.8,
        },
    )
127139

128140
def _format_retrieved_docs(self, docs: List[Document]) -> str:
129141
"""Formats retrieved documents into a string for the prompt context."""

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,3 +1262,5 @@ yarl==1.18.3 \
12621262
zipp==3.21.0 \
12631263
--hash=sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4 \
12641264
--hash=sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931
1265+
memory-profiler==0.61.0 \
1266+
--hash=sha256:97c82e7e66a05ad5e1f2d0dfd23eae374cb1ab8aca87d9e0c27f03ab74fcef3d

services/translation_service.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,15 @@
3232
# --- Language Detection Prompt Template (Keep it minimal) ---
3333
# Use a single triple-quoted string for clarity and correct parsing
3434
# Prompt for LLM-based language detection. Built with str.format rather
# than an f-string so that UNKNOWN_LANG_CODE is baked in up front while the
# doubled braces in '{{user_input}}' collapse to '{user_input}', leaving a
# live placeholder for ChatPromptTemplate to fill at invocation time.
# NOTE(review): the trailing backslashes join the instruction sentences
# onto one line; the escaped triple quotes delimit the user text verbatim.
LANG_DETECT_PROMPT_TEMPLATE = ChatPromptTemplate.from_template(
    """Identify the primary language of the following text. \
Respond with ONLY the two-letter ISO 639-1 language code (e.g., 'en', 'es', 'fr'). \
If you are unsure, the text is nonsensical, gibberish, or not a real language, \
respond with '{unknown_lang_code}'. \
Text:
\"\"\"
{{user_input}}
\"\"\"
Language code:""".format(unknown_lang_code=UNKNOWN_LANG_CODE)
)
4545

4646
# Seed langdetect for consistent results
@@ -113,8 +113,17 @@ async def _detect_language_llm(self, text: str) -> str:
113113
"""Detect language using the LLM chain."""
114114
logger.debug(f"Using LLM for language detection for: '{text[:50]}...'")
115115
try:
116+
# Use a shortened version of the text for language detection to save tokens
117+
# For very short texts, use the whole text
118+
if len(text) > 100:
119+
detection_text = text[:100] # Only use the first 100 characters
120+
else:
121+
detection_text = text
122+
116123
# Use ainvoke for the async context
117-
result = await self._lang_detect_chain.ainvoke({"user_input": text})
124+
result = await self._lang_detect_chain.ainvoke(
125+
{"user_input": detection_text}
126+
)
118127
# Clean up potential whitespace and normalize case
119128
detected_lang = result.strip().lower()
120129
logger.debug(f"LLM detection result: {detected_lang}")

0 commit comments

Comments
 (0)