Skip to content

Commit 799972f

Browse files
committed
refactor: harden input validation in persistence and command parsing
1 parent a6094f5 commit 799972f

File tree

13 files changed

+233
-9
lines changed

13 files changed

+233
-9
lines changed

src/praisonai-agents/praisonaiagents/storage/backends.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,9 @@ def __init__(
188188
auto_create: Create table if it doesn't exist
189189
"""
190190
self.db_path = os.path.expanduser(db_path) if db_path else str(get_storage_path())
191+
import re as _re
192+
if not isinstance(table_name, str) or not _re.match(r'^[a-zA-Z0-9_]+$', table_name):
193+
raise ValueError("table_name must contain only alphanumeric characters and underscores")
191194
self.table_name = table_name
192195
self._local = threading.local()
193196

src/praisonai/praisonai/cli/features/mcp.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@
2424
"pipx",
2525
}
2626

27+
# Per-executable argument flags that enable arbitrary inline code execution
28+
# and must be rejected to prevent command-injection via MCP command strings.
29+
_INLINE_EXEC_ARGS = {
30+
"python": {"-c", "--command"},
31+
"python3": {"-c", "--command"},
32+
"node": {"-e", "--eval", "-p", "--print"},
33+
"deno": {"-e", "--eval", "eval"},
34+
"bun": {"-e", "--eval", "eval"},
35+
}
36+
2737

2838
class MCPHandler(FlagHandler):
2939
"""
@@ -90,6 +100,21 @@ def parse_mcp_command(self, command: str, env_vars: str = None) -> Tuple[str, Li
90100
f"Command '{cmd}' is not in the allowed MCP executables list. "
91101
f"Allowed: {', '.join(sorted(ALLOWED_MCP_COMMANDS - {c for c in ALLOWED_MCP_COMMANDS if '.' in c}))}"
92102
)
103+
104+
# Reject inline-eval flags that allow arbitrary code execution for
105+
# interpreters (python -c, node -e, deno eval, bun -e, ...).
106+
base_key = basename.lower()
107+
for suffix in (".exe", ".cmd"):
108+
if base_key.endswith(suffix):
109+
base_key = base_key[: -len(suffix)]
110+
forbidden = _INLINE_EXEC_ARGS.get(base_key)
111+
if forbidden:
112+
for arg in args:
113+
if arg in forbidden:
114+
raise ValueError(
115+
f"Argument '{arg}' is not allowed for '{basename}' "
116+
"(inline code execution is blocked in MCP commands)."
117+
)
93118

94119
# Parse environment variables
95120
env = {}

src/praisonai/praisonai/persistence/conversation/async_mysql.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import time
1212
from typing import List, Optional
1313

14-
from .base import ConversationStore, ConversationSession, ConversationMessage
14+
from .base import ConversationStore, ConversationSession, ConversationMessage, validate_identifier
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -62,6 +62,7 @@ def __init__(
6262
self.database = database
6363
self.user = user
6464
self.password = password
65+
validate_identifier(table_prefix, "table_prefix")
6566
self.table_prefix = table_prefix
6667
self.pool_size = pool_size
6768
self._pool = None

src/praisonai/praisonai/persistence/conversation/async_postgres.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import time
1313
from typing import List, Optional
1414

15-
from .base import ConversationStore, ConversationSession, ConversationMessage
15+
from .base import ConversationStore, ConversationSession, ConversationMessage, validate_identifier
1616

1717
logger = logging.getLogger(__name__)
1818

@@ -60,6 +60,7 @@ def __init__(
6060
self.database = database
6161
self.user = user
6262
self.password = password
63+
validate_identifier(table_prefix, "table_prefix")
6364
self.table_prefix = table_prefix
6465
self.pool_size = pool_size
6566
self._pool = None

src/praisonai/praisonai/persistence/conversation/async_sqlite.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import time
1212
from typing import List, Optional
1313

14-
from .base import ConversationStore, ConversationSession, ConversationMessage
14+
from .base import ConversationStore, ConversationSession, ConversationMessage, validate_identifier
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -40,6 +40,7 @@ def __init__(
4040
table_prefix: Prefix for table names
4141
"""
4242
self.path = path
43+
validate_identifier(table_prefix, "table_prefix")
4344
self.table_prefix = table_prefix
4445
self._conn = None
4546
self._initialized = False

src/praisonai/praisonai/persistence/conversation/base.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,27 @@
77
from abc import ABC, abstractmethod
88
from dataclasses import dataclass, field
99
from typing import Any, Dict, List, Optional
10+
import re
1011
import time
1112
import uuid
1213

1314

15+
_IDENTIFIER_RE = re.compile(r'^[a-zA-Z0-9_]*$')
16+
17+
18+
def validate_identifier(value: str, name: str = "identifier") -> str:
19+
"""Validate that a string is safe for use as a SQL identifier.
20+
21+
Only allows alphanumeric characters and underscores to prevent SQL
22+
injection in table/schema names that are interpolated into DDL/DML.
23+
"""
24+
if not isinstance(value, str) or not _IDENTIFIER_RE.match(value):
25+
raise ValueError(
26+
f"{name} must contain only alphanumeric characters and underscores"
27+
)
28+
return value
29+
30+
1431
@dataclass
1532
class ConversationMessage:
1633
"""A single message in a conversation."""

src/praisonai/praisonai/persistence/conversation/mysql.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import logging
1010
from typing import Any, List, Optional
1111

12-
from .base import ConversationStore, ConversationSession, ConversationMessage
12+
from .base import ConversationStore, ConversationSession, ConversationMessage, validate_identifier
1313

1414
logger = logging.getLogger(__name__)
1515

@@ -62,6 +62,7 @@ def __init__(
6262
)
6363

6464
self._mysql = mysql.connector
65+
validate_identifier(table_prefix, "table_prefix")
6566
self.table_prefix = table_prefix
6667
self.sessions_table = f"{table_prefix}sessions"
6768
self.messages_table = f"{table_prefix}messages"

src/praisonai/praisonai/persistence/conversation/postgres.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from typing import Any, Callable, Dict, List, Optional
1212
from urllib.parse import urlparse, parse_qs, urlencode, urlunparse
1313

14-
from .base import ConversationStore, ConversationSession, ConversationMessage
14+
from .base import ConversationStore, ConversationSession, ConversationMessage, validate_identifier
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -85,6 +85,8 @@ def __init__(
8585
self._psycopg2 = psycopg2
8686
self._RealDictCursor = RealDictCursor
8787

88+
validate_identifier(schema, "schema")
89+
validate_identifier(table_prefix, "table_prefix")
8890
self.schema = schema
8991
self.table_prefix = table_prefix
9092
self.sessions_table = f"{schema}.{table_prefix}sessions"

src/praisonai/praisonai/persistence/conversation/singlestore.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import logging
1010
from typing import Any, List, Optional
1111

12-
from .base import ConversationStore, ConversationSession, ConversationMessage
12+
from .base import ConversationStore, ConversationSession, ConversationMessage, validate_identifier
1313

1414
logger = logging.getLogger(__name__)
1515

@@ -48,6 +48,7 @@ def __init__(
4848
)
4949

5050
self._s2 = s2
51+
validate_identifier(table_prefix, "table_prefix")
5152
self.table_prefix = table_prefix
5253
self.sessions_table = f"{table_prefix}sessions"
5354
self.messages_table = f"{table_prefix}messages"

src/praisonai/praisonai/persistence/conversation/supabase.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import time
1010
from typing import List, Optional
1111

12-
from .base import ConversationStore, ConversationSession, ConversationMessage
12+
from .base import ConversationStore, ConversationSession, ConversationMessage, validate_identifier
1313

1414
logger = logging.getLogger(__name__)
1515

@@ -49,6 +49,7 @@ def __init__(
4949
max_retries: Max retries for paused project wake-up
5050
retry_delay: Base delay between retries in seconds
5151
"""
52+
validate_identifier(table_prefix, "table_prefix")
5253
try:
5354
from supabase import create_client, Client
5455
except ImportError:

0 commit comments

Comments
 (0)