Skip to content

Commit e78d07c

Browse files
author
dori
committed
fix security issue and reformat
1 parent d864dfc commit e78d07c

File tree

11 files changed

+217
-177
lines changed

11 files changed

+217
-177
lines changed

src/mcp_as_a_judge/config.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class Config(BaseModel):
2020
database: DatabaseConfig = Field(default_factory=DatabaseConfig)
2121
enable_llm_fallback: bool = Field(
2222
default=True,
23-
description="Whether to enable LLM fallback when MCP sampling is not available"
23+
description="Whether to enable LLM fallback when MCP sampling is not available",
2424
)
2525

2626

@@ -69,11 +69,18 @@ def load_config(config_path: str | None = None) -> Config:
6969
try:
7070
config_data["database"]["max_context_records"] = int(max_context)
7171
except ValueError:
72-
print(f"Warning: Invalid value for MCP_JUDGE_MAX_CONTEXT_RECORDS: {max_context}")
72+
print(
73+
f"Warning: Invalid value for MCP_JUDGE_MAX_CONTEXT_RECORDS: {max_context}"
74+
)
7375

7476
llm_fallback = os.getenv("MCP_JUDGE_ENABLE_LLM_FALLBACK")
7577
if llm_fallback:
76-
config_data["enable_llm_fallback"] = llm_fallback.lower() in ("true", "1", "yes", "on")
78+
config_data["enable_llm_fallback"] = llm_fallback.lower() in (
79+
"true",
80+
"1",
81+
"yes",
82+
"on",
83+
)
7784

7885
return Config(**config_data)
7986

@@ -86,15 +93,11 @@ def create_default_config_file(config_path: str = "config.json") -> None:
8693
config_path: Path where to create the config file
8794
"""
8895
default_config = {
89-
"database": {
90-
"provider": "in_memory",
91-
"url": "",
92-
"max_context_records": 10
93-
},
94-
"enable_llm_fallback": True
96+
"database": {"provider": "in_memory", "url": "", "max_context_records": 10},
97+
"enable_llm_fallback": True,
9598
}
9699

97-
with open(config_path, 'w') as f:
100+
with open(config_path, "w") as f:
98101
json.dump(default_config, f, indent=2)
99102

100103
print(f"Created default configuration file: {config_path}")

src/mcp_as_a_judge/db/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@
1414
"ConversationRecord",
1515
"DatabaseFactory",
1616
"SQLiteProvider",
17-
"create_database_provider"
17+
"create_database_provider",
1818
]

src/mcp_as_a_judge/db/conversation_history_service.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
class ConversationHistoryService:
2020
"""Service for managing conversation history in judge tools."""
2121

22-
def __init__(self, config: Config, db_provider: ConversationHistoryDB | None = None):
22+
def __init__(
23+
self, config: Config, db_provider: ConversationHistoryDB | None = None
24+
):
2325
"""
2426
Initialize the conversation history service.
2527
@@ -30,7 +32,9 @@ def __init__(self, config: Config, db_provider: ConversationHistoryDB | None = N
3032
self.config = config
3133
self.db = db_provider or create_database_provider(config)
3234

33-
async def load_context_for_enrichment(self, session_id: str) -> list[ConversationRecord]:
35+
async def load_context_for_enrichment(
36+
self, session_id: str
37+
) -> list[ConversationRecord]:
3438
"""
3539
Load recent conversation records for LLM context enrichment.
3640
@@ -45,18 +49,14 @@ async def load_context_for_enrichment(self, session_id: str) -> list[Conversatio
4549
# Load recent conversations for this session
4650
recent_records = await self.db.get_session_conversations(
4751
session_id=session_id,
48-
limit=self.config.database.context_enrichment_count # load last X records
52+
limit=self.config.database.context_enrichment_count, # load last X records
4953
)
5054

5155
logger.info(f"📚 Retrieved {len(recent_records)} conversation records from DB")
5256
return recent_records
5357

5458
async def save_tool_interaction(
55-
self,
56-
session_id: str,
57-
tool_name: str,
58-
tool_input: str,
59-
tool_output: str
59+
self, session_id: str, tool_name: str, tool_input: str, tool_output: str
6060
) -> str:
6161
"""
6262
Save a tool interaction as a conversation record.
@@ -70,13 +70,15 @@ async def save_tool_interaction(
7070
Returns:
7171
ID of the created conversation record
7272
"""
73-
logger.info(f"💾 Saving tool interaction to SQLite DB for session: {session_id}, tool: {tool_name}")
73+
logger.info(
74+
f"💾 Saving tool interaction to SQLite DB for session: {session_id}, tool: {tool_name}"
75+
)
7476

7577
record_id = await self.db.save_conversation(
7678
session_id=session_id,
7779
source=tool_name,
7880
input_data=tool_input,
79-
output=tool_output
81+
output=tool_output,
8082
)
8183

8284
logger.info(f"✅ Saved conversation record with ID: {record_id}")
@@ -96,19 +98,27 @@ def format_context_for_llm(self, context_records: list[ConversationRecord]) -> s
9698
logger.info("📝 No conversation history to format for LLM context")
9799
return "No previous conversation history available."
98100

99-
logger.info(f"📝 Formatting {len(context_records)} conversation records for LLM context enrichment")
101+
logger.info(
102+
f"📝 Formatting {len(context_records)} conversation records for LLM context enrichment"
103+
)
100104

101105
context_lines = ["## Previous Conversation History"]
102-
context_lines.append("Here are the recent interactions in this session for context:")
106+
context_lines.append(
107+
"Here are the recent interactions in this session for context:"
108+
)
103109
context_lines.append("")
104110

105111
# Format records (most recent first)
106112
for i, record in enumerate(context_records, 1):
107-
context_lines.append(f"### {i}. {record.source} ({record.timestamp.strftime('%Y-%m-%d %H:%M:%S')})")
113+
context_lines.append(
114+
f"### {i}. {record.source} ({record.timestamp.strftime('%Y-%m-%d %H:%M:%S')})"
115+
)
108116
context_lines.append(f"**Input:** {record.input}")
109117
context_lines.append(f"**Output:** {record.output}")
110118
context_lines.append("")
111-
logger.info(f" Formatted record {i}: {record.source} from {record.timestamp}")
119+
logger.info(
120+
f" Formatted record {i}: {record.source} from {record.timestamp}"
121+
)
112122

113123
context_lines.append("---")
114124
context_lines.append("Use this context to make more informed decisions.")
@@ -141,18 +151,19 @@ async def get_session_summary(self, session_id: str) -> dict:
141151
"session_id": session_id,
142152
"total_interactions": len(all_records),
143153
"tool_usage": tool_counts,
144-
"latest_interaction": all_records[0].timestamp.isoformat() if all_records else None,
154+
"latest_interaction": all_records[0].timestamp.isoformat()
155+
if all_records
156+
else None,
145157
"context_enrichment_count": self.config.database.context_enrichment_count,
146-
"max_context_records": self.config.database.max_context_records
158+
"max_context_records": self.config.database.max_context_records,
147159
}
148160

149161

150162
# Convenience functions for easy integration with existing tools
151163

164+
152165
async def enrich_with_context(
153-
service: ConversationHistoryService,
154-
session_id: str,
155-
base_prompt: str
166+
service: ConversationHistoryService, session_id: str, base_prompt: str
156167
) -> str:
157168
"""
158169
Enrich a base prompt with conversation history context.
@@ -165,16 +176,17 @@ async def enrich_with_context(
165176
Returns:
166177
Enriched prompt with conversation history context
167178
"""
168-
logger.info(f"🔄 Starting context enrichment for session {session_id}, base_prompt: {base_prompt}")
179+
logger.info(
180+
f"🔄 Starting context enrichment for session {session_id}, base_prompt: {base_prompt}"
181+
)
169182

170183
context_records = await service.load_context_for_enrichment(session_id)
171184
context_text = service.format_context_for_llm(context_records)
172185

173186
enriched_prompt = f"{context_text}\n## Current Request\n{base_prompt}"
174187

175-
logger.info(f"🎯 Context enrichment completed for session {session_id}, enriched_prompt: {enriched_prompt}")
188+
logger.info(
189+
f"🎯 Context enrichment completed for session {session_id}, enriched_prompt: {enriched_prompt}"
190+
)
176191

177192
return enriched_prompt
178-
179-
180-

src/mcp_as_a_judge/db/factory.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class DatabaseFactory:
1717

1818
_providers: ClassVar[dict[str, type[ConversationHistoryDB]]] = {
1919
"in_memory": SQLiteProvider, # SQLite in-memory (:memory: or empty URL)
20-
"sqlite": SQLiteProvider, # SQLite file-based storage
20+
"sqlite": SQLiteProvider, # SQLite file-based storage
2121
# Future providers can be added here:
2222
# "postgresql": PostgreSQLProvider,
2323
# "mysql": MySQLProvider,
@@ -56,14 +56,14 @@ def create_provider(cls, config: Config) -> ConversationHistoryDB:
5656
return provider_class(
5757
max_context_records=config.database.max_context_records,
5858
retention_days=config.database.record_retention_days,
59-
url=config.database.url
59+
url=config.database.url,
6060
)
6161
else:
6262
# For future network database providers (PostgreSQL, MySQL, etc.)
6363
return provider_class(
6464
url=config.database.url,
6565
max_context_records=config.database.max_context_records,
66-
retention_days=config.database.record_retention_days
66+
retention_days=config.database.record_retention_days,
6767
)
6868

6969
@classmethod
@@ -73,7 +73,9 @@ def get_available_providers(cls) -> list[str]:
7373

7474
# Not in use - option to register additional providers
7575
@classmethod
76-
def register_provider(cls, name: str, provider_class: type[ConversationHistoryDB]) -> None:
76+
def register_provider(
77+
cls, name: str, provider_class: type[ConversationHistoryDB]
78+
) -> None:
7779
"""
7880
Register a new database provider.
7981
@@ -84,7 +86,6 @@ def register_provider(cls, name: str, provider_class: type[ConversationHistoryDB
8486
cls._providers[name] = provider_class
8587

8688

87-
8889
# Convenience function
8990
def create_database_provider(config: Config) -> ConversationHistoryDB:
9091
"""

src/mcp_as_a_judge/db/interface.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,18 @@ class ConversationRecord(BaseModel):
1919
source: str # tool name
2020
input: str # tool input query
2121
output: str # tool output string
22-
timestamp: datetime = Field(default_factory=datetime.utcnow) # when the record was created
22+
timestamp: datetime = Field(
23+
default_factory=datetime.utcnow
24+
) # when the record was created
2325

2426

2527
class ConversationHistoryDB(ABC):
2628
"""Abstract interface for conversation history database operations."""
2729

2830
@abstractmethod
29-
def __init__(self, max_context_records: int = 20, retention_days: int = 1, url: str = "") -> None:
31+
def __init__(
32+
self, max_context_records: int = 20, retention_days: int = 1, url: str = ""
33+
) -> None:
3034
"""
3135
Initialize the database provider.
3236
@@ -39,11 +43,7 @@ def __init__(self, max_context_records: int = 20, retention_days: int = 1, url:
3943

4044
@abstractmethod
4145
async def save_conversation(
42-
self,
43-
session_id: str,
44-
source: str,
45-
input_data: str,
46-
output: str
46+
self, session_id: str, source: str, input_data: str, output: str
4747
) -> str:
4848
"""
4949
Save a conversation record to the database.
@@ -61,9 +61,7 @@ async def save_conversation(
6161

6262
@abstractmethod
6363
async def get_session_conversations(
64-
self,
65-
session_id: str,
66-
limit: int | None = None
64+
self, session_id: str, limit: int | None = None
6765
) -> list[ConversationRecord]:
6866
"""
6967
Retrieve all conversation records for a session.

0 commit comments

Comments
 (0)