Skip to content

Commit 6326b2a

Browse files
author
dori
committed
adding memory layer
1 parent 0b7639b commit 6326b2a

17 files changed

+2162
-125
lines changed

config.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"database": {
3+
"provider": "in_memory",
4+
"url": "",
5+
"max_context_records": 20,
6+
"context_enrichment_count": 5
7+
},
8+
"enable_llm_fallback": true
9+
}

judge_mcp_flow.md

Lines changed: 607 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ dependencies = [
3232
"pydantic>=2.0.0",
3333
"jinja2>=3.1.0",
3434
"litellm>=1.0.0",
35+
"aiosqlite>=0.19.0",
3536
]
3637

3738
[project.urls]

src/mcp_as_a_judge/config.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
"""
2+
Configuration management for MCP as a Judge.
3+
4+
This module handles loading and managing configuration from various sources
5+
including config files, environment variables, and defaults.
6+
"""
7+
8+
import json
9+
import os
10+
from pathlib import Path
11+
from typing import Optional
12+
13+
from pydantic import BaseModel, Field
14+
15+
from .models import DatabaseConfig
16+
17+
18+
class Config(BaseModel):
    """Main configuration model for the application."""

    # Database settings (provider, URL, context limits); defaults come from
    # DatabaseConfig's own field defaults when no config file/env override exists.
    database: DatabaseConfig = Field(default_factory=DatabaseConfig)
    # When True, fall back to a direct LLM call if MCP sampling is unavailable.
    enable_llm_fallback: bool = Field(
        default=True,
        description="Whether to enable LLM fallback when MCP sampling is not available"
    )
26+
27+
28+
def load_config(config_path: Optional[str] = None) -> Config:
    """
    Load configuration from file and environment variables.

    Values are resolved in increasing priority: model defaults, then the JSON
    config file, then environment variables (MCP_JUDGE_DB_PROVIDER,
    MCP_JUDGE_DB_URL, MCP_JUDGE_MAX_CONTEXT_RECORDS,
    MCP_JUDGE_ENABLE_LLM_FALLBACK).

    Args:
        config_path: Path to config file. If None, looks for config.json
            in the current directory.

    Returns:
        Config object with loaded settings
    """
    config_data: dict = {}

    # Try to load from the config file; a missing file is not an error and
    # simply leaves the defaults in place.
    if config_path is None:
        config_path = "config.json"

    config_file = Path(config_path)
    if config_file.exists():
        try:
            with open(config_file, 'r') as f:
                config_data = json.load(f)
        # OSError is the modern spelling (IOError is a legacy alias).
        except (json.JSONDecodeError, OSError) as e:
            # Best-effort: warn and continue with defaults rather than failing startup.
            print(f"Warning: Could not load config file {config_path}: {e}")

    # Environment-variable overrides. setdefault replaces the repeated
    # `if "database" not in config_data:` boilerplate of the naive version.
    db_provider = os.getenv("MCP_JUDGE_DB_PROVIDER")
    if db_provider:
        config_data.setdefault("database", {})["provider"] = db_provider

    db_url = os.getenv("MCP_JUDGE_DB_URL")
    if db_url:
        config_data.setdefault("database", {})["url"] = db_url

    max_context = os.getenv("MCP_JUDGE_MAX_CONTEXT_RECORDS")
    if max_context:
        try:
            config_data.setdefault("database", {})["max_context_records"] = int(max_context)
        except ValueError:
            print(f"Warning: Invalid value for MCP_JUDGE_MAX_CONTEXT_RECORDS: {max_context}")

    llm_fallback = os.getenv("MCP_JUDGE_ENABLE_LLM_FALLBACK")
    if llm_fallback:
        # Accept common truthy spellings; any other value is treated as False.
        config_data["enable_llm_fallback"] = llm_fallback.lower() in ("true", "1", "yes", "on")

    return Config(**config_data)
80+
81+
82+
def create_default_config_file(config_path: str = "config.json") -> None:
    """
    Write a default configuration file to *config_path*.

    Args:
        config_path: Path where to create the config file
    """
    default_config = {
        "database": {
            "provider": "in_memory",
            "url": "",
            "max_context_records": 10,
        },
        "enable_llm_fallback": True,
    }

    # Serialize once and write in a single call.
    Path(config_path).write_text(json.dumps(default_config, indent=2))

    print(f"Created default configuration file: {config_path}")
103+
104+
def get_database_provider_from_url(url: str) -> str:
    """
    Determine database provider from URL.

    Args:
        url: Database connection URL

    Returns:
        Provider name: 'sqlite', 'postgresql', or 'in_memory'
    """
    # An empty/absent URL means the in-memory provider.
    if not url:
        return "in_memory"

    normalized = url.lower()

    # A sqlite:// scheme or a bare *.db file path selects SQLite.
    if normalized.startswith("sqlite://") or normalized.endswith(".db"):
        return "sqlite"

    # Both common PostgreSQL scheme spellings are accepted.
    if normalized.startswith(("postgresql://", "postgres://")):
        return "postgresql"

    # Unknown scheme: fall back to the in-memory provider.
    return "in_memory"
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
"""
2+
Conversation History Service for MCP Judge Tools.
3+
4+
This service handles:
5+
1. Loading historical context for LLM enrichment
6+
2. Saving tool interactions as conversation records
7+
3. Managing session-based conversation history
8+
"""
9+
10+
from typing import List, Optional
11+
12+
from .config import Config
13+
from .db import ConversationHistoryDB, ConversationRecord, create_database_provider
14+
15+
16+
class ConversationHistoryService:
    """Service for managing conversation history in judge tools."""

    def __init__(self, config: Config, db_provider: Optional[ConversationHistoryDB] = None):
        """
        Initialize the conversation history service.

        Args:
            config: Application configuration
            db_provider: Optional database provider (will create one if not provided)
        """
        self.config = config
        # Build a provider from config unless the caller injected one.
        self.db = db_provider or create_database_provider(config)

    async def load_context_for_enrichment(self, session_id: str) -> tuple[List[ConversationRecord], List[str]]:
        """
        Load recent conversation records for LLM context enrichment.

        Args:
            session_id: Session identifier

        Returns:
            Tuple of (conversation_records, conversation_ids)
            - conversation_records: Full records for LLM context
            - conversation_ids: Just the IDs for saving in new record's context field
        """
        limit = self.config.database.context_enrichment_count

        # Fetch the most recent conversations for this session.
        records = await self.db.get_session_conversations(
            session_id=session_id,
            limit=limit,
        )

        # The IDs alone are kept so a new record can reference its context.
        return records, [record.id for record in records]

    async def save_tool_interaction(
        self,
        session_id: str,
        tool_name: str,
        tool_input: str,
        tool_output: str,
        context_ids: List[str]
    ) -> str:
        """
        Save a tool interaction as a conversation record.

        Args:
            session_id: Session identifier from AI agent
            tool_name: Name of the judge tool (e.g., 'judge_coding_plan')
            tool_input: Input that was passed to the tool
            tool_output: Output/result from the tool
            context_ids: IDs of conversation records that were used for context enrichment

        Returns:
            ID of the created conversation record
        """
        # Delegate persistence to the configured database provider.
        return await self.db.save_conversation(
            session_id=session_id,
            source=tool_name,
            input_data=tool_input,
            context=context_ids,
            output=tool_output,
        )

    def format_context_for_llm(self, context_records: List[ConversationRecord]) -> str:
        """
        Format conversation history for LLM context enrichment.

        Args:
            context_records: Recent conversation records

        Returns:
            Formatted context string for LLM
        """
        if not context_records:
            return "No previous conversation history available."

        lines = [
            "## Previous Conversation History",
            "Here are the recent interactions in this session for context:",
            "",
        ]

        # One markdown section per record, in the order received.
        for index, record in enumerate(context_records, 1):
            stamp = record.timestamp.strftime('%Y-%m-%d %H:%M:%S')
            lines += [
                f"### {index}. {record.source} ({stamp})",
                f"**Input:** {record.input}",
                f"**Output:** {record.output}",
                "",
            ]

        lines += [
            "---",
            "Use this context to make more informed decisions.",
            "",
        ]

        return "\n".join(lines)

    async def get_session_summary(self, session_id: str) -> dict:
        """
        Get a summary of the session's conversation history.

        Args:
            session_id: Session identifier

        Returns:
            Dictionary with session statistics
        """
        records = await self.db.get_session_conversations(session_id)

        # Tally how many times each tool appears in the history.
        usage: dict = {}
        for record in records:
            usage[record.source] = usage.get(record.source, 0) + 1

        return {
            "session_id": session_id,
            "total_interactions": len(records),
            "tool_usage": usage,
            "latest_interaction": records[0].timestamp.isoformat() if records else None,
            "context_enrichment_count": self.config.database.context_enrichment_count,
            "max_context_records": self.config.database.max_context_records,
        }
141+
142+
143+
# Convenience functions for easy integration with existing tools
144+
145+
async def enrich_with_context(
    service: ConversationHistoryService,
    session_id: str,
    base_prompt: str
) -> tuple[str, List[str]]:
    """
    Enrich a base prompt with conversation history context.

    Args:
        service: ConversationHistoryService instance
        session_id: Session identifier
        base_prompt: Original prompt to enrich

    Returns:
        Tuple of (enriched_prompt, context_ids)
    """
    records, ids = await service.load_context_for_enrichment(session_id)
    history_block = service.format_context_for_llm(records)

    # Prepend the formatted history ahead of the caller's request.
    return f"{history_block}\n## Current Request\n{base_prompt}", ids
167+
168+
169+
async def save_tool_result(
    service: ConversationHistoryService,
    session_id: str,
    tool_name: str,
    original_input: str,
    tool_result: str,
    context_ids: List[str]
) -> str:
    """
    Save a tool's result to conversation history.

    Args:
        service: ConversationHistoryService instance
        session_id: Session identifier
        tool_name: Name of the tool
        original_input: Original input to the tool
        tool_result: Result from the tool
        context_ids: Context IDs that were used for enrichment

    Returns:
        ID of the saved conversation record
    """
    # Thin convenience wrapper around the service's own save method.
    record_id = await service.save_tool_interaction(
        session_id=session_id,
        tool_name=tool_name,
        tool_input=original_input,
        tool_output=tool_result,
        context_ids=context_ids,
    )
    return record_id

src/mcp_as_a_judge/db/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""
2+
Database abstraction layer for MCP as a Judge.
3+
4+
This module provides database interfaces and providers for storing
5+
conversation history and tool interactions.
6+
"""
7+
8+
from .factory import DatabaseFactory, create_database_provider
9+
from .interface import ConversationHistoryDB, ConversationRecord
10+
from .providers import InMemoryProvider
11+
12+
__all__ = [
13+
"ConversationHistoryDB",
14+
"ConversationRecord",
15+
"InMemoryProvider",
16+
"DatabaseFactory",
17+
"create_database_provider"
18+
]

0 commit comments

Comments
 (0)