1616)
1717
1818DbArgResolver = Callable [[str | None ], str | None ]
19+ ReconnectFn = Callable [[str | None ], bool ]
1920ResultT = TypeVar ("ResultT" )
2021
2122
@@ -24,18 +25,54 @@ class ExplorerSchemaService:
2425 session : ConnectionSession
2526 object_cache : dict [str , dict [str , Any ]]
2627 db_arg_resolver : DbArgResolver | None = None
28+ reconnect : ReconnectFn | None = None
2729
2830 def _resolve_db_arg (self , database : str | None ) -> str | None :
2931 if self .db_arg_resolver :
3032 return self .db_arg_resolver (database )
3133 return database
3234
33- def _run (self , fn : Callable [..., ResultT ], * args : Any ) -> ResultT :
34- return self .session .executor .submit (fn , * args ).result ()
35+ def _run (self , fn : Callable [[], ResultT ]) -> ResultT :
36+ return self .session .executor .submit (fn ).result ()
37+
38+ def _resolve_reconnect_db (self , database : str | None ) -> str | None :
39+ if database :
40+ return database
41+ endpoint = self .session .config .tcp_endpoint
42+ if endpoint and endpoint .database :
43+ return endpoint .database
44+ return database
45+
46+ def _reconnect (self , database : str | None ) -> bool :
47+ target_db = self ._resolve_reconnect_db (database )
48+ if self .reconnect is not None :
49+ if self .reconnect (target_db ):
50+ self .object_cache .clear ()
51+ return True
52+ return False
53+ if not self .session .config .tcp_endpoint :
54+ return False
55+ try :
56+ self .session .switch_database (target_db or "" )
57+ self .object_cache .clear ()
58+ return True
59+ except Exception :
60+ return False
61+
62+ def _run_with_retry (self , fn : Callable [[], ResultT ], database : str | None ) -> ResultT :
63+ try :
64+ return self ._run (fn )
65+ except Exception :
66+ if not self ._reconnect (database ):
67+ raise
68+ return self ._run (fn )
3569
3670 def list_databases (self ) -> list [str ]:
3771 inspector = self .session .provider .schema_inspector
38- return self ._run (inspector .get_databases , self .session .connection )
72+ return self ._run_with_retry (
73+ lambda : inspector .get_databases (self .session .connection ),
74+ None ,
75+ )
3976
4077 def list_columns (
4178 self ,
@@ -45,7 +82,10 @@ def list_columns(
4582 ) -> list [ColumnInfo ]:
4683 inspector = self .session .provider .schema_inspector
4784 db_arg = self ._resolve_db_arg (database )
48- return self ._run (inspector .get_columns , self .session .connection , name , db_arg , schema )
85+ return self ._run_with_retry (
86+ lambda : inspector .get_columns (self .session .connection , name , db_arg , schema ),
87+ database ,
88+ )
4989
5090 def list_folder_items (self , folder_type : str , database : str | None ) -> list [tuple [str , str , str ]]:
5191 inspector = self .session .provider .schema_inspector
@@ -64,36 +104,61 @@ def cached(key: str, loader: Callable[[], Any]) -> Any:
64104 return data
65105
66106 if folder_type == "tables" :
67- raw_data = cached ("tables" , lambda : self ._run (inspector .get_tables , self .session .connection , db_arg ))
107+ raw_data = cached (
108+ "tables" ,
109+ lambda : self ._run_with_retry (
110+ lambda : inspector .get_tables (self .session .connection , db_arg ),
111+ database ,
112+ ),
113+ )
68114 return [("table" , schema , name ) for schema , name in raw_data ]
69115 if folder_type == "views" :
70- raw_data = cached ("views" , lambda : self ._run (inspector .get_views , self .session .connection , db_arg ))
116+ raw_data = cached (
117+ "views" ,
118+ lambda : self ._run_with_retry (
119+ lambda : inspector .get_views (self .session .connection , db_arg ),
120+ database ,
121+ ),
122+ )
71123 return [("view" , schema , name ) for schema , name in raw_data ]
72124 if folder_type == "indexes" :
73125 if caps .supports_indexes and isinstance (inspector , IndexInspector ):
74126 return [
75127 ("index" , item .name , item .table_name )
76- for item in self ._run (inspector .get_indexes , self .session .connection , db_arg )
128+ for item in self ._run_with_retry (
129+ lambda : inspector .get_indexes (self .session .connection , db_arg ),
130+ database ,
131+ )
77132 ]
78133 return []
79134 if folder_type == "triggers" :
80135 if caps .supports_triggers and isinstance (inspector , TriggerInspector ):
81136 return [
82137 ("trigger" , item .name , item .table_name )
83- for item in self ._run (inspector .get_triggers , self .session .connection , db_arg )
138+ for item in self ._run_with_retry (
139+ lambda : inspector .get_triggers (self .session .connection , db_arg ),
140+ database ,
141+ )
84142 ]
85143 return []
86144 if folder_type == "sequences" :
87145 if caps .supports_sequences and isinstance (inspector , SequenceInspector ):
88146 return [
89147 ("sequence" , item .name , "" )
90- for item in self ._run (inspector .get_sequences , self .session .connection , db_arg )
148+ for item in self ._run_with_retry (
149+ lambda : inspector .get_sequences (self .session .connection , db_arg ),
150+ database ,
151+ )
91152 ]
92153 return []
93154 if folder_type == "procedures" :
94155 if caps .supports_stored_procedures and isinstance (inspector , ProcedureInspector ):
95156 raw_data = cached (
96- "procedures" , lambda : self ._run (inspector .get_procedures , self .session .connection , db_arg )
157+ "procedures" ,
158+ lambda : self ._run_with_retry (
159+ lambda : inspector .get_procedures (self .session .connection , db_arg ),
160+ database ,
161+ ),
97162 )
98163 return [("procedure" , "" , name ) for name in raw_data ]
99164 return []
@@ -104,18 +169,27 @@ def get_index_definition(self, database: str | None, name: str, table_name: str)
104169 if not isinstance (inspector , IndexInspector ):
105170 return None
106171 db_arg = self ._resolve_db_arg (database )
107- return self ._run (inspector .get_index_definition , self .session .connection , name , table_name , db_arg )
172+ return self ._run_with_retry (
173+ lambda : inspector .get_index_definition (self .session .connection , name , table_name , db_arg ),
174+ database ,
175+ )
108176
109177 def get_trigger_definition (self , database : str | None , name : str , table_name : str ) -> dict [str , Any ] | None :
110178 inspector = self .session .provider .schema_inspector
111179 if not isinstance (inspector , TriggerInspector ):
112180 return None
113181 db_arg = self ._resolve_db_arg (database )
114- return self ._run (inspector .get_trigger_definition , self .session .connection , name , table_name , db_arg )
182+ return self ._run_with_retry (
183+ lambda : inspector .get_trigger_definition (self .session .connection , name , table_name , db_arg ),
184+ database ,
185+ )
115186
116187 def get_sequence_definition (self , database : str | None , name : str ) -> dict [str , Any ] | None :
117188 inspector = self .session .provider .schema_inspector
118189 if not isinstance (inspector , SequenceInspector ):
119190 return None
120191 db_arg = self ._resolve_db_arg (database )
121- return self ._run (inspector .get_sequence_definition , self .session .connection , name , db_arg )
192+ return self ._run_with_retry (
193+ lambda : inspector .get_sequence_definition (self .session .connection , name , db_arg ),
194+ database ,
195+ )
0 commit comments