1111from dataclasses import dataclass
1212from typing import Any
1313
14+ from psycopg import sql
15+
1416from .sql_driver import (
1517 SqlDriver ,
1618 check_extension_available ,
@@ -162,22 +164,43 @@ async def create_index(
162164 """
163165 await self .ensure_available ()
164166
165- # Build the CREATE INDEX statement
166- qualified_table = f"{ schema } .{ table } " if schema else table
167- columns_str = ", " .join (columns )
168-
169- create_stmt = f"CREATE INDEX ON { qualified_table } USING { using } ({ columns_str } )"
167+ # Build the CREATE INDEX statement using safe SQL composition
168+ # Use sql.Identifier for proper escaping of table and column names
169+ if schema :
170+ table_ident = sql .Identifier (schema , table )
171+ else :
172+ table_ident = sql .Identifier (table )
173+
174+ # Validate and whitelist the index access method
175+ valid_access_methods = {"btree" , "hash" , "brin" , "bloom" , "gist" , "gin" , "spgist" }
176+ if using .lower () not in valid_access_methods :
177+ raise ValueError (f"Invalid index access method: { using } " )
178+
179+ columns_ident = sql .SQL (", " ).join (sql .Identifier (col ) for col in columns )
180+
181+ # Build the base CREATE INDEX statement
182+ create_stmt = sql .SQL ("CREATE INDEX ON {} USING {} ({})" ).format (
183+ table_ident ,
184+ sql .SQL (using ), # validated above
185+ columns_ident
186+ )
170187
171188 if include :
172- include_str = ", " .join (include )
173- create_stmt += f " INCLUDE ({ include_str } )"
189+ include_ident = sql . SQL ( ", " ) .join (sql . Identifier ( col ) for col in include )
190+ create_stmt = sql . Composed ([ create_stmt , sql . SQL ( " INCLUDE (" ), include_ident , sql . SQL ( ")" )])
174191
175192 if where :
176- create_stmt += f" WHERE { where } "
193+ # The WHERE clause is a user-provided SQL expression; it is embedded verbatim via sql.SQL (no escaping).
194+ # Note: This pass-through is intentional, since a filter predicate cannot be quoted as an identifier or literal.
195+ create_stmt = sql .Composed ([create_stmt , sql .SQL (" WHERE " ), sql .SQL (where )])
177196
178- # Create the hypothetical index
197+ # Convert to string for hypopg_create_index (which expects a SQL statement as text)
198+ create_stmt_str = create_stmt .as_string ()
199+
200+ # Create the hypothetical index using parameterized query
179201 result = await self .driver .execute_query (
180- f"SELECT * FROM hypopg_create_index($${ create_stmt } $$)"
202+ "SELECT * FROM hypopg_create_index(%s)" ,
203+ [create_stmt_str ]
181204 )
182205
183206 if not result :
@@ -208,11 +231,24 @@ async def create_index_from_sql(self, create_index_sql: str) -> HypotheticalInde
208231
209232 Returns:
210233 HypotheticalIndex with the created index info
234+
235+ Note:
236+ The create_index_sql parameter is expected to be a valid CREATE INDEX
237+ statement. This method uses parameterized queries to pass the statement
238+ to hypopg_create_index(), which only processes CREATE INDEX statements
239+ and ignores any other SQL.
211240 """
212241 await self .ensure_available ()
213242
243+ # Validate that it looks like a CREATE INDEX statement
244+ normalized = create_index_sql .strip ().upper ()
245+ if not normalized .startswith ("CREATE" ) or "INDEX" not in normalized :
246+ raise ValueError ("Invalid CREATE INDEX statement" )
247+
248+ # Use parameterized query - hypopg_create_index only processes CREATE INDEX
214249 result = await self .driver .execute_query (
215- f"SELECT * FROM hypopg_create_index($${ create_index_sql } $$)"
250+ "SELECT * FROM hypopg_create_index(%s)" ,
251+ [create_index_sql ]
216252 )
217253
218254 if not result :
@@ -275,7 +311,8 @@ async def get_index_definition(self, indexrelid: int) -> str | None:
275311 """
276312 try :
277313 result = await self .driver .execute_query (
278- f"SELECT hypopg_get_indexdef({ indexrelid } ) as indexdef"
314+ "SELECT hypopg_get_indexdef(%s) as indexdef" ,
315+ [indexrelid ]
279316 )
280317 if result :
281318 return result [0 ].get ("indexdef" )
@@ -295,7 +332,8 @@ async def get_index_size(self, indexrelid: int) -> int | None:
295332 """
296333 try :
297334 result = await self .driver .execute_query (
298- f"SELECT hypopg_relation_size({ indexrelid } ) as size"
335+ "SELECT hypopg_relation_size(%s) as size" ,
336+ [indexrelid ]
299337 )
300338 if result :
301339 return result [0 ].get ("size" )
@@ -317,7 +355,8 @@ async def drop_index(self, indexrelid: int) -> bool:
317355
318356 try :
319357 await self .driver .execute_query (
320- f"SELECT hypopg_drop_index({ indexrelid } )"
358+ "SELECT hypopg_drop_index(%s)" ,
359+ [indexrelid ]
321360 )
322361 logger .info (f"Dropped hypothetical index: { indexrelid } " )
323362 return True
@@ -356,7 +395,8 @@ async def hide_index(self, indexrelid: int) -> bool:
356395
357396 try :
358397 result = await self .driver .execute_query (
359- f"SELECT hypopg_hide_index({ indexrelid } )"
398+ "SELECT hypopg_hide_index(%s)" ,
399+ [indexrelid ]
360400 )
361401 if result :
362402 return result [0 ].get ("hypopg_hide_index" , False )
@@ -378,7 +418,8 @@ async def unhide_index(self, indexrelid: int) -> bool:
378418
379419 try :
380420 result = await self .driver .execute_query (
381- f"SELECT hypopg_unhide_index({ indexrelid } )"
421+ "SELECT hypopg_unhide_index(%s)" ,
422+ [indexrelid ]
382423 )
383424 if result :
384425 return result [0 ].get ("hypopg_unhide_index" , False )
0 commit comments