
Commit f65f622

feat: refactor ai code

Authored and committed by dori
1 parent 475da21

8 files changed (+107, -156 lines)

src/mcp_as_a_judge/constants.py

Lines changed: 2 additions & 3 deletions

@@ -15,6 +15,5 @@
 DATABASE_URL = "sqlite://:memory:"
 MAX_SESSION_RECORDS = 20  # Maximum records to keep per session (FIFO)
 MAX_TOTAL_SESSIONS = 50  # Maximum total sessions to keep (LRU cleanup)
-MAX_CONTEXT_TOKENS = (
-    50000  # Maximum tokens for conversation history context (1 token ≈ 4 characters)
-)
+MAX_CONTEXT_TOKENS = 50000  # Maximum tokens for conversation history context (1 token ≈ 4 characters)
+
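The trailing comment documents the sizing heuristic this commit relies on everywhere: token budgets are estimated at roughly 4 characters per token rather than computed with a real tokenizer. A minimal sketch of such an estimator, assuming simple integer division (the project's actual calculate_tokens() in token_utils.py may round differently):

# Hedged sketch of the "1 token ≈ 4 characters" estimate; rounding details are assumed.
def estimate_tokens(text: str) -> int:
    if not text:
        return 0
    return max(1, len(text) // 4)

# 200,000 characters ≈ 50,000 tokens, exactly the MAX_CONTEXT_TOKENS budget.
assert estimate_tokens("x" * 200_000) == 50_000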

src/mcp_as_a_judge/db/conversation_history_service.py

Lines changed: 24 additions & 28 deletions

@@ -14,7 +14,7 @@
 )
 from mcp_as_a_judge.db.db_config import Config
 from mcp_as_a_judge.logging_config import get_logger
-from mcp_as_a_judge.utils.token_utils import filter_records_by_token_limit
+from mcp_as_a_judge.db.token_utils import filter_records_by_token_limit
 
 # Set up logger
 logger = get_logger(__name__)

@@ -36,18 +36,17 @@ def __init__(
         self.config = config
         self.db = db_provider or create_database_provider(config)
 
-    async def load_context_for_enrichment(
-        self, session_id: str
-    ) -> list[ConversationRecord]:
+    async def load_context_for_enrichment(self, session_id: str, current_prompt: str = "") -> list[ConversationRecord]:
         """
         Load recent conversation records for LLM context enrichment.
 
         Two-level filtering approach:
         1. Database already enforces storage limits (record count + token limits)
-        2. Load-time filtering ensures history + current fits within LLM context limits
+        2. Load-time filtering ensures history + current prompt fits within LLM context limits
 
         Args:
             session_id: Session identifier
+            current_prompt: Current prompt that will be sent to LLM (for token calculation)
 
         Returns:
             List of conversation records for LLM context (filtered for LLM limits)

@@ -62,18 +61,23 @@ async def load_context_for_enrichment(
 
         # Apply LLM context filtering: ensure history + current prompt will fit within token limit
         # This filters the list without modifying the database (only token limit matters for LLM)
-        filtered_records = filter_records_by_token_limit(recent_records)
+        filtered_records = filter_records_by_token_limit(recent_records, current_prompt=current_prompt)
 
         logger.info(
             f"✅ Returning {len(filtered_records)} conversation records for LLM context"
         )
         return filtered_records
 
-    async def save_tool_interaction(
+    async def save_tool_interaction_and_cleanup(
         self, session_id: str, tool_name: str, tool_input: str, tool_output: str
     ) -> str:
         """
-        Save a tool interaction as a conversation record.
+        Save a tool interaction as a conversation record and perform automatic cleanup in the provider layer.
+
+        After saving, the database provider automatically performs cleanup to enforce limits:
+        - Removes old records if session exceeds MAX_SESSION_RECORDS (20)
+        - Removes old records if session exceeds MAX_CONTEXT_TOKENS (50,000)
+        - Removes least recently used sessions if total sessions exceed MAX_TOTAL_SESSIONS (50)
 
         Args:
             session_id: Session identifier from AI agent

@@ -98,31 +102,23 @@ async def save_tool_interaction(
         logger.info(f"✅ Saved conversation record with ID: {record_id}")
         return record_id
 
-    async def get_conversation_history(
-        self, session_id: str
-    ) -> list[ConversationRecord]:
+    async def save_tool_interaction(
+        self, session_id: str, tool_name: str, tool_input: str, tool_output: str
+    ) -> str:
         """
-        Get conversation history for a session to be injected into user prompts.
-
-        Args:
-            session_id: Session identifier
+        Save a tool interaction as a conversation record.
 
-        Returns:
-            List of conversation records for the session (most recent first)
+        DEPRECATED: Use save_tool_interaction_and_cleanup() instead.
+        This method is kept for backward compatibility.
         """
-        logger.info(f"🔄 Loading conversation history for session {session_id}")
-
-        context_records = await self.load_context_for_enrichment(session_id)
-
-        logger.info(
-            f"📝 Retrieved {len(context_records)} conversation records for session {session_id}"
+        logger.warning(
+            "save_tool_interaction() is deprecated. Use save_tool_interaction_and_cleanup() instead."
+        )
+        return await self.save_tool_interaction_and_cleanup(
+            session_id, tool_name, tool_input, tool_output
         )
 
-        return context_records
-
-    def format_conversation_history_as_json_array(
-        self, conversation_history: list[ConversationRecord]
-    ) -> list[dict]:
+    def format_conversation_history_as_json_array(self, conversation_history: list[ConversationRecord]) -> list[dict]:
         """
         Convert conversation history list to JSON array for prompt injection.
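Taken together, the refactor splits the service API into an explicit write path (save_tool_interaction_and_cleanup) and read path (load_context_for_enrichment), keeping the old write method as a deprecation shim. A minimal usage sketch, assuming the class is named ConversationHistoryService and that a default Config() can be constructed — neither detail is visible in this diff:

# Hedged sketch: class name, constructor arguments, and payloads are assumptions.
import asyncio

from mcp_as_a_judge.db.conversation_history_service import ConversationHistoryService
from mcp_as_a_judge.db.db_config import Config

async def main() -> None:
    service = ConversationHistoryService(config=Config())  # assumed constructor

    # Write path: the provider enforces MAX_SESSION_RECORDS / MAX_CONTEXT_TOKENS /
    # MAX_TOTAL_SESSIONS immediately after the save.
    record_id = await service.save_tool_interaction_and_cleanup(
        session_id="session-123",
        tool_name="judge_code_change",
        tool_input='{"diff": "..."}',
        tool_output='{"verdict": "approved"}',
    )

    # Read path: pass the prompt so history + prompt fit within MAX_CONTEXT_TOKENS.
    records = await service.load_context_for_enrichment(
        "session-123", current_prompt="Please review the next change."
    )
    print(record_id, len(records))

asyncio.run(main())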

src/mcp_as_a_judge/db/providers/sqlite_provider.py

Lines changed: 22 additions & 31 deletions

@@ -15,7 +15,7 @@
 from mcp_as_a_judge.db.cleanup_service import ConversationCleanupService
 from mcp_as_a_judge.db.interface import ConversationHistoryDB, ConversationRecord
 from mcp_as_a_judge.logging_config import get_logger
-from mcp_as_a_judge.utils.token_utils import calculate_record_tokens
+from mcp_as_a_judge.db.token_utils import calculate_record_tokens
 
 # Set up logger
 logger = get_logger(__name__)

@@ -101,12 +101,15 @@ def _cleanup_old_messages(self, session_id: str) -> int:
         Two-step process:
         1. If record count > max_records, remove oldest record
         2. If total tokens > max_tokens, remove oldest records until within limit
+
+        Optimization: Single DB query with ORDER BY, then in-memory list operations.
+        Eliminates 2 extra database queries compared to the naive implementation.
         """
         with Session(self.engine) as session:
-            # Get current record count
+            # Get current records ordered by timestamp DESC (newest first for token calculation)
             count_stmt = select(ConversationRecord).where(
                 ConversationRecord.session_id == session_id
-            )
+            ).order_by(desc(ConversationRecord.timestamp))
             current_records = session.exec(count_stmt).all()
             current_count = len(current_records)

@@ -121,37 +124,25 @@ def _cleanup_old_messages(self, session_id: str) -> int:
             if current_count > self._max_session_records:
                 logger.info(" 📊 Record limit exceeded, removing 1 oldest record")
 
-                # Get the oldest record to remove (since we add one by one, only need to remove one)
-                oldest_stmt = (
-                    select(ConversationRecord)
-                    .where(ConversationRecord.session_id == session_id)
-                    .order_by(asc(ConversationRecord.timestamp))
-                    .limit(1)
+                # Take the last record (oldest) since list is sorted by timestamp DESC (newest first)
+                oldest_record = current_records[-1]
+
+                logger.info(
+                    f" 🗑️ Removing oldest record: {oldest_record.source} | {oldest_record.tokens} tokens | {oldest_record.timestamp}"
                 )
-                oldest_record = session.exec(oldest_stmt).first()
+                session.delete(oldest_record)
+                removed_count += 1
+                session.commit()
+                logger.info(" ✅ Removed 1 record due to record limit")
 
-                if oldest_record:
-                    logger.info(
-                        f" 🗑️ Removing oldest record: {oldest_record.source} | {oldest_record.tokens} tokens | {oldest_record.timestamp}"
-                    )
-                    session.delete(oldest_record)
-                    removed_count += 1
-                    session.commit()
-                    logger.info(" ✅ Removed 1 record due to record limit")
+                # Update our in-memory list to reflect the deletion
+                current_records.remove(oldest_record)
 
-            # STEP 2: Handle token limit (check remaining records after step 1)
-            remaining_stmt = (
-                select(ConversationRecord)
-                .where(ConversationRecord.session_id == session_id)
-                .order_by(
-                    desc(ConversationRecord.timestamp)
-                )  # Newest first for token calculation
-            )
-            remaining_records = session.exec(remaining_stmt).all()
-            current_tokens = sum(record.tokens for record in remaining_records)
+            # STEP 2: Handle token limit (list is already sorted newest first - perfect for token calculation)
+            current_tokens = sum(record.tokens for record in current_records)
 
             logger.info(
-                f" 🔢 {len(remaining_records)} records, {current_tokens} tokens "
+                f" 🔢 {len(current_records)} records, {current_tokens} tokens "
                 f"(max: {MAX_CONTEXT_TOKENS})"
             )

@@ -164,15 +155,15 @@ def _cleanup_old_messages(self, session_id: str) -> int:
             records_to_keep = []
             running_tokens = 0
 
-            for record in remaining_records:  # Already ordered newest first
+            for record in current_records:  # Already ordered newest first
                 if running_tokens + record.tokens <= MAX_CONTEXT_TOKENS:
                     records_to_keep.append(record)
                     running_tokens += record.tokens
                 else:
                     break
 
             # Remove records that didn't make the cut
-            records_to_remove_for_tokens = remaining_records[len(records_to_keep) :]
+            records_to_remove_for_tokens = current_records[len(records_to_keep) :]
 
             if records_to_remove_for_tokens:
                 logger.info(
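STEP 2 is a greedy newest-first scan: accumulate record tokens until the MAX_CONTEXT_TOKENS budget would overflow, then drop every older record. The same logic in isolation (Record here is an illustrative stand-in for ConversationRecord, used only for its tokens field):

# Standalone illustration of the STEP 2 trimming logic; Record is a stand-in type.
from dataclasses import dataclass

@dataclass
class Record:
    tokens: int

def split_by_token_budget(newest_first: list[Record], max_tokens: int) -> tuple[list[Record], list[Record]]:
    keep: list[Record] = []
    running = 0
    for record in newest_first:  # newest first, matching the ORDER BY ... DESC query
        if running + record.tokens <= max_tokens:
            keep.append(record)
            running += record.tokens
        else:
            break  # first overflow: everything from here on is older, so drop it all
    return keep, newest_first[len(keep):]

keep, drop = split_by_token_budget([Record(30_000), Record(15_000), Record(10_000)], 50_000)
assert [r.tokens for r in keep] == [30_000, 15_000] and [r.tokens for r in drop] == [10_000]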

src/mcp_as_a_judge/utils/token_utils.py renamed to src/mcp_as_a_judge/db/token_utils.py

Lines changed: 19 additions & 21 deletions

@@ -7,6 +7,8 @@
 
 from mcp_as_a_judge.constants import MAX_CONTEXT_TOKENS
 
+from mcp_as_a_judge.db.interface import ConversationRecord
+
 
 def calculate_tokens(text: str) -> int:
     """

@@ -31,20 +33,18 @@ def calculate_tokens(text: str) -> int:
 
 def calculate_record_tokens(input_text: str, output_text: str) -> int:
     """
-    Calculate total token count for a conversation record.
+    Calculate total token count for input and output text.
 
     Combines the token counts of input and output text.
 
     Args:
-        input_text: Tool input text
-        output_text: Tool output text
+        input_text: Input text string
+        output_text: Output text string
 
     Returns:
         Combined token count for both input and output
     """
-    input_tokens = calculate_tokens(input_text)
-    output_tokens = calculate_tokens(output_text)
-    return input_tokens + output_tokens
+    return calculate_tokens(input_text) + calculate_tokens(output_text)
 
 
 def calculate_total_tokens(records: list) -> int:

@@ -61,7 +61,7 @@ def calculate_total_tokens(records: list) -> int:
 
 
 def filter_records_by_token_limit(
-    records: list, max_tokens: int | None = None, max_records: int | None = None
+    records: list, current_prompt: str = ""
 ) -> list:
     """
     Filter conversation records to stay within token and record limits.

@@ -71,36 +71,34 @@ def filter_records_by_token_limit(
 
     Args:
         records: List of ConversationRecord objects (assumed to be in reverse chronological order)
-        max_tokens: Maximum allowed token count (defaults to MAX_CONTEXT_TOKENS from constants)
         max_records: Maximum number of records to keep (optional)
+        current_prompt: Current prompt that will be sent to LLM (for token calculation)
 
     Returns:
         Filtered list of records that fit within the limits
     """
     if not records:
         return []
 
-    # Use default token limit if not specified
-    if max_tokens is None:
-        max_tokens = MAX_CONTEXT_TOKENS
+    # Calculate current prompt tokens
+    current_prompt_tokens = calculate_record_tokens(current_prompt, "") if current_prompt else 0
 
-    # Apply record count limit first if specified
-    if max_records is not None and len(records) > max_records:
-        records = records[:max_records]
+    # Calculate total tokens including current prompt
+    history_tokens = calculate_total_tokens(records)
+    total_tokens = history_tokens + current_prompt_tokens
 
-    # If total tokens are within limit, return all records
-    total_tokens = calculate_total_tokens(records)
-    if total_tokens <= max_tokens:
+    # If total tokens (history + current prompt) are within limit, return all records
+    if total_tokens <= MAX_CONTEXT_TOKENS:
        return records
 
     # Remove oldest records (from the end since records are in reverse chronological order)
-    # until we're within the token limit
+    # until history + current prompt fit within the token limit
     filtered_records = records.copy()
-    current_tokens = total_tokens
+    current_history_tokens = history_tokens
 
-    while current_tokens > max_tokens and len(filtered_records) > 1:
+    while (current_history_tokens + current_prompt_tokens) > MAX_CONTEXT_TOKENS and len(filtered_records) > 1:
         # Remove the oldest record (last in the list)
         removed_record = filtered_records.pop()
-        current_tokens -= getattr(removed_record, "tokens", 0)
+        current_history_tokens -= getattr(removed_record, "tokens", 0)
 
     return filtered_records
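A worked example of the new signature. The records below are stand-ins exposing only the tokens attribute (all the function reads, via getattr), and the expected outcome assumes the documented 1 token ≈ 4 characters estimate inside calculate_tokens():

# Hedged example: stand-in records; prompt math assumes 4 characters per token.
from types import SimpleNamespace

from mcp_as_a_judge.db.token_utils import filter_records_by_token_limit

history = [  # reverse chronological order (newest first), as the function assumes
    SimpleNamespace(tokens=20_000),
    SimpleNamespace(tokens=20_000),
    SimpleNamespace(tokens=15_000),
]

# ~10,000 prompt tokens + 55,000 history tokens exceed MAX_CONTEXT_TOKENS (50,000),
# so the oldest record (15,000 tokens) is dropped: 40,000 + 10,000 <= 50,000.
kept = filter_records_by_token_limit(history, current_prompt="x" * 40_000)
print([r.tokens for r in kept])  # expected: [20000, 20000]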
