44"""
55
66import json
7+ import logging
8+ import os
79import sqlite3
810from contextlib import contextmanager
911from datetime import datetime
1012from pathlib import Path
1113from typing import Any
1214from uuid import UUID , uuid4
1315
14- # 元数据库文件路径
15- METADATA_DB_PATH = Path (__file__ ).parent .parent .parent .parent / "data" / "metadata.db"
16+ logger = logging .getLogger (__name__ )
17+
18+ # 允许更新的字段白名单(防止 SQL 注入)
19+ ALLOWED_UPDATE_FIELDS = {
20+ "name" ,
21+ "layout_data" ,
22+ "visible_tables" ,
23+ "is_default" ,
24+ "zoom" ,
25+ "viewport_x" ,
26+ "viewport_y" ,
27+ }
28+
29+ # 元数据库文件路径(支持环境变量覆盖)
30+ METADATA_DB_PATH = Path (
31+ os .getenv (
32+ "METADATA_DB_PATH" , str (Path (__file__ ).parent .parent .parent .parent / "data" / "metadata.db" )
33+ )
34+ )
1635
1736
1837def _ensure_db_dir ():
@@ -34,6 +53,10 @@ def get_metadata_db():
3453 try :
3554 yield conn
3655 conn .commit ()
56+ except Exception as e :
57+ conn .rollback ()
58+ logger .error (f"Database error: { e } " )
59+ raise
3760 finally :
3861 conn .close ()
3962
@@ -68,6 +91,18 @@ def init_metadata_db():
6891 ON schema_layouts(user_id, connection_id)
6992 """ )
7093
94+ # 为默认布局查询创建索引
95+ cursor .execute ("""
96+ CREATE INDEX IF NOT EXISTS idx_layouts_default
97+ ON schema_layouts(user_id, connection_id, is_default)
98+ """ )
99+
100+ # 为单个布局查询创建索引
101+ cursor .execute ("""
102+ CREATE INDEX IF NOT EXISTS idx_layouts_id_user
103+ ON schema_layouts(id, user_id)
104+ """ )
105+
71106 conn .commit ()
72107
73108
@@ -93,6 +128,26 @@ def list_layouts(user_id: UUID, connection_id: UUID) -> list[dict]:
93128 )
94129 return cursor .fetchall ()
95130
131+ @staticmethod
132+ def _parse_json_fields (row : dict ) -> dict :
133+ """安全解析 JSON 字段"""
134+ if not row :
135+ return row
136+ try :
137+ row ["layout_data" ] = json .loads (row ["layout_data" ] or "{}" )
138+ except (json .JSONDecodeError , TypeError ) as e :
139+ logger .warning (f"Failed to parse layout_data: { e } " )
140+ row ["layout_data" ] = {}
141+ try :
142+ row ["visible_tables" ] = (
143+ json .loads (row ["visible_tables" ]) if row ["visible_tables" ] else None
144+ )
145+ except (json .JSONDecodeError , TypeError ) as e :
146+ logger .warning (f"Failed to parse visible_tables: { e } " )
147+ row ["visible_tables" ] = None
148+ row ["is_default" ] = bool (row .get ("is_default" , 0 ))
149+ return row
150+
96151 @staticmethod
97152 def get_layout (layout_id : UUID , user_id : UUID ) -> dict | None :
98153 """获取单个布局"""
@@ -106,14 +161,7 @@ def get_layout(layout_id: UUID, user_id: UUID) -> dict | None:
106161 (str (layout_id ), str (user_id )),
107162 )
108163 row = cursor .fetchone ()
109- if row :
110- # 解析 JSON 字段
111- row ["layout_data" ] = json .loads (row ["layout_data" ] or "{}" )
112- row ["visible_tables" ] = (
113- json .loads (row ["visible_tables" ]) if row ["visible_tables" ] else None
114- )
115- row ["is_default" ] = bool (row ["is_default" ])
116- return row
164+ return LayoutRepository ._parse_json_fields (row ) if row else None
117165
118166 @staticmethod
119167 def get_default_layout (user_id : UUID , connection_id : UUID ) -> dict | None :
@@ -128,13 +176,7 @@ def get_default_layout(user_id: UUID, connection_id: UUID) -> dict | None:
128176 (str (user_id ), str (connection_id )),
129177 )
130178 row = cursor .fetchone ()
131- if row :
132- row ["layout_data" ] = json .loads (row ["layout_data" ] or "{}" )
133- row ["visible_tables" ] = (
134- json .loads (row ["visible_tables" ]) if row ["visible_tables" ] else None
135- )
136- row ["is_default" ] = bool (row ["is_default" ])
137- return row
179+ return LayoutRepository ._parse_json_fields (row ) if row else None
138180
139181 @staticmethod
140182 def create_layout (
@@ -190,12 +232,16 @@ def update_layout(
190232 connection_id : UUID | None = None ,
191233 ** kwargs ,
192234 ) -> dict | None :
193- """更新布局"""
194- # 构建更新字段
235+ """更新布局(使用白名单防止 SQL 注入) """
236+ # 构建更新字段(只允许白名单中的字段)
195237 updates = []
196238 values = []
197239
198240 for key , value in kwargs .items ():
241+ # 安全检查:只允许白名单中的字段
242+ if key not in ALLOWED_UPDATE_FIELDS :
243+ logger .warning (f"Attempted to update disallowed field: { key } " )
244+ continue
199245 if value is not None :
200246 if key == "layout_data" :
201247 updates .append ("layout_data = ?" )
@@ -206,7 +252,8 @@ def update_layout(
206252 elif key == "is_default" :
207253 updates .append ("is_default = ?" )
208254 values .append (1 if value else 0 )
209- else :
255+ elif key in ("name" , "zoom" , "viewport_x" , "viewport_y" ):
256+ # 这些字段直接使用参数化查询
210257 updates .append (f"{ key } = ?" )
211258 values .append (value )
212259
0 commit comments