1+ """
2+ 数据库范围检查器
3+ 用于检测和限制SQL查询中的跨数据库访问
4+ """
5+ import re
6+ import logging
7+ from typing import Set , Optional , List , Tuple
8+ from enum import Enum
9+
10+ logger = logging .getLogger ("mysql_server" )
11+
12+ class DatabaseAccessLevel (Enum ):
13+ """数据库访问级别"""
14+ STRICT = "strict" # 严格模式:只能访问指定数据库
15+ RESTRICTED = "restricted" # 限制模式:允许访问指定数据库和系统库
16+ PERMISSIVE = "permissive" # 宽松模式:允许访问所有数据库(默认)
17+
18+ class DatabaseScopeViolation (Exception ):
19+ """数据库范围违规异常"""
20+ pass
21+
22+ class DatabaseScopeChecker :
23+ """数据库范围检查器"""
24+
25+ # 系统数据库列表
26+ SYSTEM_DATABASES = {
27+ 'information_schema' ,
28+ 'mysql' ,
29+ 'performance_schema' ,
30+ 'sys'
31+ }
32+
33+ # 跨数据库查询模式
34+ CROSS_DB_PATTERNS = [
35+ # database.table 格式
36+ r'\b([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b' ,
37+ # SHOW TABLES FROM database
38+ r'\bSHOW\s+(?:FULL\s+)?TABLES\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*)\b' ,
39+ # USE database
40+ r'\bUSE\s+([a-zA-Z_][a-zA-Z0-9_]*)\b' ,
41+ # SELECT ... FROM database.table
42+ r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b' ,
43+ # JOIN database.table
44+ r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b' ,
45+ # INSERT INTO database.table
46+ r'\bINTO\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b' ,
47+ # UPDATE database.table
48+ r'\bUPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b' ,
49+ # DELETE FROM database.table
50+ r'\bDELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\.\s*([a-zA-Z_][a-zA-Z0-9_]*)\b' ,
51+ ]
52+
53+ def __init__ (self , allowed_database : Optional [str ] = None ,
54+ access_level : DatabaseAccessLevel = DatabaseAccessLevel .PERMISSIVE ):
55+ """
56+ 初始化数据库范围检查器
57+
58+ Args:
59+ allowed_database: 允许访问的数据库名称
60+ access_level: 访问级别
61+ """
62+ self .allowed_database = allowed_database
63+ self .access_level = access_level
64+ self .is_enabled = allowed_database is not None and access_level != DatabaseAccessLevel .PERMISSIVE
65+
66+ logger .debug (f"数据库范围检查器初始化: 允许数据库={ allowed_database } , 访问级别={ access_level .value } , 启用={ self .is_enabled } " )
67+
68+ def check_query (self , sql_query : str ) -> Tuple [bool , List [str ]]:
69+ """
70+ 检查SQL查询是否违反数据库范围限制
71+
72+ Args:
73+ sql_query: SQL查询语句
74+
75+ Returns:
76+ (是否允许, 违规详情列表)
77+ """
78+ if not self .is_enabled :
79+ return True , []
80+
81+ violations = []
82+
83+ # 提取查询中涉及的数据库
84+ referenced_databases = self ._extract_databases (sql_query )
85+
86+ for db_name in referenced_databases :
87+ if not self ._is_database_allowed (db_name ):
88+ violations .append (f"不允许访问数据库: { db_name } " )
89+
90+ # 检查特殊查询类型
91+ special_violations = self ._check_special_queries (sql_query )
92+ violations .extend (special_violations )
93+
94+ is_allowed = len (violations ) == 0
95+
96+ if violations :
97+ logger .warning (f"数据库范围检查失败: { violations } " )
98+
99+ return is_allowed , violations
100+
101+ def _extract_databases (self , sql_query : str ) -> Set [str ]:
102+ """提取SQL查询中涉及的数据库名称"""
103+ databases = set ()
104+
105+ # 标准化SQL(转换为大写,去除多余空格)
106+ normalized_sql = re .sub (r'\s+' , ' ' , sql_query .upper ().strip ())
107+
108+ for pattern in self .CROSS_DB_PATTERNS :
109+ matches = re .finditer (pattern , normalized_sql , re .IGNORECASE )
110+ for match in matches :
111+ # 第一个捕获组通常是数据库名
112+ if match .groups ():
113+ db_name = match .group (1 ).lower ()
114+ # 过滤掉非数据库名的匹配(如函数名等)
115+ if self ._is_valid_database_name (db_name ):
116+ databases .add (db_name )
117+
118+ return databases
119+
120+ def _is_valid_database_name (self , name : str ) -> bool :
121+ """检查是否是有效的数据库名称"""
122+ # 数据库名称规则:字母、数字、下划线,不能以数字开头
123+ return bool (re .match (r'^[a-zA-Z_][a-zA-Z0-9_]*$' , name ))
124+
125+ def _is_database_allowed (self , db_name : str ) -> bool :
126+ """检查数据库是否被允许访问"""
127+ db_name_lower = db_name .lower ()
128+
129+ # 检查是否是允许的主数据库
130+ if self .allowed_database and db_name_lower == self .allowed_database .lower ():
131+ return True
132+
133+ # 根据访问级别决定是否允许系统数据库
134+ if self .access_level == DatabaseAccessLevel .RESTRICTED :
135+ if db_name_lower in self .SYSTEM_DATABASES :
136+ return True
137+
138+ return False
139+
140+ def _check_special_queries (self , sql_query : str ) -> List [str ]:
141+ """检查特殊类型的查询"""
142+ violations = []
143+ normalized_sql = sql_query .upper ().strip ()
144+
145+ # 检查SHOW DATABASES查询
146+ if re .search (r'\bSHOW\s+DATABASES\b' , normalized_sql ):
147+ if self .access_level == DatabaseAccessLevel .STRICT :
148+ violations .append ("严格模式下不允许执行 SHOW DATABASES" )
149+
150+ # 检查USE语句
151+ if re .search (r'\bUSE\s+' , normalized_sql ):
152+ violations .append ("不允许使用 USE 语句切换数据库" )
153+
154+ # 检查系统表访问
155+ system_table_patterns = [
156+ r'\bmysql\.user\b' ,
157+ r'\bmysql\.db\b' ,
158+ r'\binformation_schema\.' ,
159+ r'\bperformance_schema\.' ,
160+ r'\bsys\.'
161+ ]
162+
163+ for pattern in system_table_patterns :
164+ if re .search (pattern , normalized_sql , re .IGNORECASE ):
165+ if self .access_level == DatabaseAccessLevel .STRICT :
166+ violations .append (f"严格模式下不允许访问系统表" )
167+ break
168+
169+ return violations
170+
171+ def get_allowed_databases (self ) -> Set [str ]:
172+ """获取允许访问的数据库列表"""
173+ allowed = set ()
174+
175+ if self .allowed_database :
176+ allowed .add (self .allowed_database .lower ())
177+
178+ if self .access_level == DatabaseAccessLevel .RESTRICTED :
179+ allowed .update (self .SYSTEM_DATABASES )
180+
181+ return allowed
182+
183+ def is_cross_database_query (self , sql_query : str ) -> bool :
184+ """检查是否是跨数据库查询"""
185+ referenced_dbs = self ._extract_databases (sql_query )
186+ return len (referenced_dbs ) > 0
187+
188+ # 便捷函数
189+ def create_database_checker (allowed_database : Optional [str ] = None ,
190+ access_level : str = "permissive" ) -> DatabaseScopeChecker :
191+ """
192+ 创建数据库范围检查器的便捷函数
193+
194+ Args:
195+ allowed_database: 允许访问的数据库名称
196+ access_level: 访问级别字符串 (strict/restricted/permissive)
197+
198+ Returns:
199+ DatabaseScopeChecker实例
200+ """
201+ try :
202+ level = DatabaseAccessLevel (access_level .lower ())
203+ except ValueError :
204+ logger .warning (f"无效的访问级别: { access_level } ,使用默认的 permissive" )
205+ level = DatabaseAccessLevel .PERMISSIVE
206+
207+ return DatabaseScopeChecker (allowed_database , level )
0 commit comments