Skip to content

Commit f84a9d5

Browse files
committed
fix: comprehensive security and code quality improvements
Security Fixes: - Fixed SQL injection vulnerabilities in index_advisor.py (4 locations) * Converted analyze_workload query to use parameterized queries * Fixed get_existing_indexes to use parameterized queries * Fixed get_index_health duplicate/unused index queries - Fixed SQL injection vulnerabilities in hypopg_service.py (5 locations) * Used psycopg.sql.Identifier for safe table/column name escaping * Added access method validation (whitelist-based) * Converted all HypoPG function calls to parameterized queries - Added whitelist validation for ORDER BY clauses in tools_performance.py
1 parent b9eab34 commit f84a9d5

File tree

7 files changed

+226
-64
lines changed

7 files changed

+226
-64
lines changed

src/pgtuner_mcp/__init__.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,31 @@
44
A Model Context Protocol (MCP) server for AI-powered PostgreSQL performance tuning.
55
"""
66

7+
from importlib.metadata import version, PackageNotFoundError
8+
79
from .server import main
810
from .__main__ import run
911

10-
__version__ = "0.1.0"
12+
13+
def _get_version() -> str:
14+
"""Get version from package metadata or fallback to pyproject.toml."""
15+
try:
16+
return version("pgtuner_mcp")
17+
except PackageNotFoundError:
18+
# Fallback: read from pyproject.toml if package is not installed
19+
try:
20+
from pathlib import Path
21+
import tomllib
22+
23+
pyproject_path = Path(__file__).parent.parent.parent / "pyproject.toml"
24+
if pyproject_path.exists():
25+
with open(pyproject_path, "rb") as f:
26+
data = tomllib.load(f)
27+
return data.get("project", {}).get("version", "0.0.0")
28+
except Exception:
29+
pass
30+
return "0.0.0"
31+
32+
33+
__version__ = _get_version()
1134
__all__ = ["main", "run"]

src/pgtuner_mcp/services/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22

33
from .hypopg_service import HypoPGService
44
from .index_advisor import IndexAdvisor
5-
from .sql_driver import DbConnPool, RowResult, SqlDriver
5+
from .sql_driver import DbConnPool, SqlDriver
66
from .user_filter import UserFilter, get_user_filter
77

88
__all__ = [
99
"DbConnPool",
1010
"SqlDriver",
11-
"RowResult",
1211
"HypoPGService",
1312
"IndexAdvisor",
1413
"UserFilter",

src/pgtuner_mcp/services/hypopg_service.py

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from dataclasses import dataclass
1212
from typing import Any
1313

14+
from psycopg import sql
15+
1416
from .sql_driver import (
1517
SqlDriver,
1618
check_extension_available,
@@ -162,22 +164,43 @@ async def create_index(
162164
"""
163165
await self.ensure_available()
164166

165-
# Build the CREATE INDEX statement
166-
qualified_table = f"{schema}.{table}" if schema else table
167-
columns_str = ", ".join(columns)
168-
169-
create_stmt = f"CREATE INDEX ON {qualified_table} USING {using} ({columns_str})"
167+
# Build the CREATE INDEX statement using safe SQL composition
168+
# Use sql.Identifier for proper escaping of table and column names
169+
if schema:
170+
table_ident = sql.Identifier(schema, table)
171+
else:
172+
table_ident = sql.Identifier(table)
173+
174+
# Validate and whitelist the index access method
175+
valid_access_methods = {"btree", "hash", "brin", "bloom", "gist", "gin", "spgist"}
176+
if using.lower() not in valid_access_methods:
177+
raise ValueError(f"Invalid index access method: {using}")
178+
179+
columns_ident = sql.SQL(", ").join(sql.Identifier(col) for col in columns)
180+
181+
# Build the base CREATE INDEX statement
182+
create_stmt = sql.SQL("CREATE INDEX ON {} USING {} ({})").format(
183+
table_ident,
184+
sql.SQL(using), # validated above
185+
columns_ident
186+
)
170187

171188
if include:
172-
include_str = ", ".join(include)
173-
create_stmt += f" INCLUDE ({include_str})"
189+
include_ident = sql.SQL(", ").join(sql.Identifier(col) for col in include)
190+
create_stmt = sql.Composed([create_stmt, sql.SQL(" INCLUDE ("), include_ident, sql.SQL(")")])
174191

175192
if where:
176-
create_stmt += f" WHERE {where}"
193+
# WHERE clause is user-provided SQL expression - use as-is but document the risk
194+
# Note: The WHERE clause is intentionally passed through as the user's filter expression
195+
create_stmt = sql.Composed([create_stmt, sql.SQL(" WHERE "), sql.SQL(where)])
177196

178-
# Create the hypothetical index
197+
# Convert to string for hypopg_create_index (which expects a SQL statement as text)
198+
create_stmt_str = create_stmt.as_string()
199+
200+
# Create the hypothetical index using parameterized query
179201
result = await self.driver.execute_query(
180-
f"SELECT * FROM hypopg_create_index($${create_stmt}$$)"
202+
"SELECT * FROM hypopg_create_index(%s)",
203+
[create_stmt_str]
181204
)
182205

183206
if not result:
@@ -208,11 +231,24 @@ async def create_index_from_sql(self, create_index_sql: str) -> HypotheticalInde
208231
209232
Returns:
210233
HypotheticalIndex with the created index info
234+
235+
Note:
236+
The create_index_sql parameter is expected to be a valid CREATE INDEX
237+
statement. This method uses parameterized queries to pass the statement
238+
to hypopg_create_index(), which only processes CREATE INDEX statements
239+
and ignores any other SQL.
211240
"""
212241
await self.ensure_available()
213242

243+
# Validate that it looks like a CREATE INDEX statement
244+
normalized = create_index_sql.strip().upper()
245+
if not normalized.startswith("CREATE") or "INDEX" not in normalized:
246+
raise ValueError("Invalid CREATE INDEX statement")
247+
248+
# Use parameterized query - hypopg_create_index only processes CREATE INDEX
214249
result = await self.driver.execute_query(
215-
f"SELECT * FROM hypopg_create_index($${create_index_sql}$$)"
250+
"SELECT * FROM hypopg_create_index(%s)",
251+
[create_index_sql]
216252
)
217253

218254
if not result:
@@ -275,7 +311,8 @@ async def get_index_definition(self, indexrelid: int) -> str | None:
275311
"""
276312
try:
277313
result = await self.driver.execute_query(
278-
f"SELECT hypopg_get_indexdef({indexrelid}) as indexdef"
314+
"SELECT hypopg_get_indexdef(%s) as indexdef",
315+
[indexrelid]
279316
)
280317
if result:
281318
return result[0].get("indexdef")
@@ -295,7 +332,8 @@ async def get_index_size(self, indexrelid: int) -> int | None:
295332
"""
296333
try:
297334
result = await self.driver.execute_query(
298-
f"SELECT hypopg_relation_size({indexrelid}) as size"
335+
"SELECT hypopg_relation_size(%s) as size",
336+
[indexrelid]
299337
)
300338
if result:
301339
return result[0].get("size")
@@ -317,7 +355,8 @@ async def drop_index(self, indexrelid: int) -> bool:
317355

318356
try:
319357
await self.driver.execute_query(
320-
f"SELECT hypopg_drop_index({indexrelid})"
358+
"SELECT hypopg_drop_index(%s)",
359+
[indexrelid]
321360
)
322361
logger.info(f"Dropped hypothetical index: {indexrelid}")
323362
return True
@@ -356,7 +395,8 @@ async def hide_index(self, indexrelid: int) -> bool:
356395

357396
try:
358397
result = await self.driver.execute_query(
359-
f"SELECT hypopg_hide_index({indexrelid})"
398+
"SELECT hypopg_hide_index(%s)",
399+
[indexrelid]
360400
)
361401
if result:
362402
return result[0].get("hypopg_hide_index", False)
@@ -378,7 +418,8 @@ async def unhide_index(self, indexrelid: int) -> bool:
378418

379419
try:
380420
result = await self.driver.execute_query(
381-
f"SELECT hypopg_unhide_index({indexrelid})"
421+
"SELECT hypopg_unhide_index(%s)",
422+
[indexrelid]
382423
)
383424
if result:
384425
return result[0].get("hypopg_unhide_index", False)

0 commit comments

Comments
 (0)