88from sqlit .shared .ui .protocols import ResultsMixinHost
99from sqlit .shared .ui .widgets import SqlitDataTable
1010
11+ MIN_TIMER_DELAY_S = 0.001
12+
1113
1214class ResultsMixin :
1315 """Mixin providing results handling functionality."""
@@ -19,6 +21,122 @@ class ResultsMixin:
1921 _tooltip_showing : bool = False
2022 _tooltip_timer : Any | None = None
2123
24+ def _schedule_results_timer (self : ResultsMixinHost , delay_s : float , callback : Any ) -> Any | None :
25+ set_timer = getattr (self , "set_timer" , None )
26+ if callable (set_timer ):
27+ return set_timer (delay_s , callback )
28+ call_later = getattr (self , "call_later" , None )
29+ if callable (call_later ):
30+ try :
31+ call_later (callback )
32+ return None
33+ except Exception :
34+ pass
35+ try :
36+ callback ()
37+ except Exception :
38+ pass
39+ return None
40+
41+ def _apply_result_table_columns (
42+ self : ResultsMixinHost ,
43+ table_info : dict [str , Any ],
44+ token : int ,
45+ columns : list [Any ],
46+ ) -> None :
47+ if table_info .get ("_columns_token" ) != token :
48+ return
49+ table_info ["columns" ] = columns
50+
51+ def _prime_result_table_columns (self : ResultsMixinHost , table_info : dict [str , Any ] | None ) -> None :
52+ if not table_info :
53+ return
54+ if table_info .get ("columns" ):
55+ return
56+ name = table_info .get ("name" )
57+ if not name :
58+ return
59+ database = table_info .get ("database" )
60+ schema = table_info .get ("schema" )
61+ token = int (table_info .get ("_columns_token" , 0 )) + 1
62+ table_info ["_columns_token" ] = token
63+
64+ async def work_async () -> None :
65+ import asyncio
66+
67+ columns : list [Any ] = []
68+ try :
69+ runtime = getattr (self .services , "runtime" , None )
70+ use_worker = bool (getattr (runtime , "process_worker" , False )) and not bool (
71+ getattr (getattr (runtime , "mock" , None ), "enabled" , False )
72+ )
73+ client = None
74+ if use_worker and hasattr (self , "_get_process_worker_client_async" ):
75+ client = await self ._get_process_worker_client_async () # type: ignore[attr-defined]
76+
77+ if client is not None and hasattr (client , "list_columns" ) and self .current_config is not None :
78+ outcome = await asyncio .to_thread (
79+ client .list_columns ,
80+ config = self .current_config ,
81+ database = database ,
82+ schema = schema ,
83+ name = name ,
84+ )
85+ if getattr (outcome , "cancelled" , False ):
86+ return
87+ error = getattr (outcome , "error" , None )
88+ if error :
89+ raise RuntimeError (error )
90+ columns = outcome .columns or []
91+ else :
92+ schema_service = getattr (self , "_get_schema_service" , None )
93+ if callable (schema_service ):
94+ service = self ._get_schema_service ()
95+ if service :
96+ columns = await asyncio .to_thread (
97+ service .list_columns ,
98+ database ,
99+ schema ,
100+ name ,
101+ )
102+ except Exception :
103+ columns = []
104+
105+ self ._schedule_results_timer (
106+ MIN_TIMER_DELAY_S ,
107+ lambda : self ._apply_result_table_columns (table_info , token , columns ),
108+ )
109+
110+ self .run_worker (work_async (), name = f"prime-result-columns-{ name } " , exclusive = False )
111+
112+ def _normalize_column_name (self : ResultsMixinHost , name : str ) -> str :
113+ trimmed = name .strip ()
114+ if len (trimmed ) >= 2 :
115+ if trimmed [0 ] == trimmed [- 1 ] and trimmed [0 ] in ("\" " , "`" ):
116+ trimmed = trimmed [1 :- 1 ]
117+ elif trimmed [0 ] == "[" and trimmed [- 1 ] == "]" :
118+ trimmed = trimmed [1 :- 1 ]
119+ if "." in trimmed and not any (q in trimmed for q in ("\" " , "`" , "[" )):
120+ trimmed = trimmed .split ("." )[- 1 ]
121+ return trimmed .lower ()
122+
123+ def _get_active_results_table_info (
124+ self : ResultsMixinHost ,
125+ table : SqlitDataTable | None ,
126+ stacked : bool ,
127+ ) -> dict [str , Any ] | None :
128+ if not table :
129+ return None
130+ if stacked :
131+ section = self ._find_results_section (table )
132+ table_info = getattr (section , "result_table_info" , None )
133+ if table_info :
134+ return table_info
135+ table_info = getattr (table , "result_table_info" , None )
136+ if table_info :
137+ return table_info
138+ return getattr (self , "_last_query_table" , None )
139+
22140 def _copy_text (self : ResultsMixinHost , text : str ) -> bool :
23141 """Copy text to clipboard if possible, otherwise store internally."""
24142 self ._internal_clipboard = text
@@ -610,21 +728,20 @@ def sql_value(v: object) -> str:
610728 # Get table name and primary key columns
611729 table_name = "<table>"
612730 pk_column_names : set [str ] = set ()
613-
614- if hasattr (self , "_last_query_table" ) and self ._last_query_table :
615- table_info = self ._last_query_table
616- table_name = table_info ["name" ]
731+ table_info = self ._get_active_results_table_info (table , _stacked )
732+ if table_info :
733+ table_name = table_info .get ("name" , table_name )
617734 # Get PK columns from column info
618735 for col in table_info .get ("columns" , []):
619736 if col .is_primary_key :
620- pk_column_names .add (col .name )
737+ pk_column_names .add (self . _normalize_column_name ( col .name ) )
621738
622739 # Build WHERE clause - prefer PK columns, fall back to all columns
623740 where_parts = []
624741 for i , col in enumerate (columns ):
625742 if i < len (row_values ):
626743 # If we have PK info, only use PK columns; otherwise use all columns
627- if pk_column_names and col not in pk_column_names :
744+ if pk_column_names and self . _normalize_column_name ( col ) not in pk_column_names :
628745 continue
629746 val = row_values [i ]
630747 if val is None :
@@ -685,9 +802,10 @@ def action_edit_cell(self: ResultsMixinHost) -> None:
685802 column_name = columns [cursor_col ]
686803
687804 # Check if this column is a primary key - don't allow editing PKs
688- if hasattr (self , "_last_query_table" ) and self ._last_query_table :
689- for col in self ._last_query_table .get ("columns" , []):
690- if col .name == column_name and col .is_primary_key :
805+ table_info = self ._get_active_results_table_info (table , _stacked )
806+ if table_info :
807+ for col in table_info .get ("columns" , []):
808+ if col .is_primary_key and self ._normalize_column_name (col .name ) == self ._normalize_column_name (column_name ):
691809 self .notify ("Cannot edit primary key column" , severity = "warning" )
692810 return
693811
@@ -705,21 +823,19 @@ def sql_value(v: object) -> str:
705823 # Get table name and primary key columns
706824 table_name = "<table>"
707825 pk_column_names : set [str ] = set ()
708-
709- if hasattr (self , "_last_query_table" ) and self ._last_query_table :
710- table_info = self ._last_query_table
711- table_name = table_info ["name" ]
826+ if table_info :
827+ table_name = table_info .get ("name" , table_name )
712828 # Get PK columns from column info
713829 for col in table_info .get ("columns" , []):
714830 if col .is_primary_key :
715- pk_column_names .add (col .name )
831+ pk_column_names .add (self . _normalize_column_name ( col .name ) )
716832
717833 # Build WHERE clause - prefer PK columns, fall back to all columns
718834 where_parts = []
719835 for i , col in enumerate (columns ):
720836 if i < len (row_values ):
721837 # If we have PK info, only use PK columns; otherwise use all columns
722- if pk_column_names and col not in pk_column_names :
838+ if pk_column_names and self . _normalize_column_name ( col ) not in pk_column_names :
723839 continue
724840 val = row_values [i ]
725841 if val is None :
0 commit comments