@@ -63,6 +63,23 @@ def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Respon
63
63
except requests .exceptions .RequestException as e :
64
64
raise DatabaseError (f"Request failed: { str (e )} " )
65
65
66
+ def get_logstreams (self ) -> requests .Response :
67
+ """Get list of all logstreams"""
68
+ return self ._make_request ('GET' , 'logstream' )
69
+
70
+ def get_schema (self , table_name : str ) -> requests .Response :
71
+ """Get schema for a table/stream"""
72
+ escaped_table_name = self ._escape_table_name (table_name )
73
+ return self ._make_request ('GET' , f'logstream/{ table_name } /schema' )
74
+
75
+ def _escape_table_name (self , table_name : str ) -> str :
76
+ """Escape table name to handle special characters"""
77
+ # Handle table names with special characters
78
+ if '-' in table_name or ' ' in table_name or '.' in table_name :
79
+ return f'"{ table_name } "'
80
+ return table_name
81
+
82
+ # In ParseableClient class:
66
83
def execute_query (self , table_name : str , query : str ) -> Dict :
67
84
"""Execute a query against a specific table/stream"""
68
85
# First, let's transform the query to handle type casting
@@ -71,13 +88,18 @@ def execute_query(self, table_name: str, query: str) -> Dict:
71
88
# Then extract time conditions
72
89
modified_query , start_time , end_time = self ._extract_and_remove_time_conditions (modified_query )
73
90
91
+ # Escape table name in query if needed, but only if it's not already escaped
92
+ if not (modified_query .find (f'"{ table_name } "' ) >= 0 ):
93
+ escaped_table_name = self ._escape_table_name (table_name )
94
+ modified_query = modified_query .replace (table_name , escaped_table_name )
95
+
74
96
data = {
75
97
"query" : modified_query ,
76
98
"startTime" : start_time ,
77
99
"endTime" : end_time
78
100
}
79
101
80
- headers = {** self .headers , 'X-P-Stream' : table_name }
102
+ headers = {** self .headers , 'X-P-Stream' : table_name } # Keep original table name in header
81
103
82
104
url = f"{ self .base_url } /api/v1/query"
83
105
@@ -160,10 +182,11 @@ def __init__(self, connection):
160
182
def execute (self , operation : str , parameters : Optional [Dict ] = None ):
161
183
if not self .connection .table_name :
162
184
raise DatabaseError ("No table name specified in connection string" )
163
-
185
+
164
186
try :
165
187
if operation .strip ().upper () == "SELECT 1" :
166
188
# For connection test, execute a real query to test API connectivity
189
+ # Don't escape the table name here since execute_query will handle it
167
190
result = self .connection .client .execute_query (
168
191
table_name = self .connection .table_name ,
169
192
query = f"select * from { self .connection .table_name } limit 1"
@@ -278,6 +301,10 @@ def do_ping(self, dbapi_connection):
278
301
279
302
def get_columns (self , connection : Connection , table_name : str , schema : Optional [str ] = None , ** kw ) -> List [Dict ]:
280
303
try :
304
+ # Remove schema prefix if present
305
+ if '.' in table_name :
306
+ schema , table_name = table_name .split ('.' )
307
+
281
308
response = connection .connection .client .get_schema (table_name )
282
309
283
310
if response .status_code != 200 :
@@ -296,17 +323,13 @@ def get_columns(self, connection: Connection, table_name: str, schema: Optional[
296
323
}
297
324
298
325
for field in schema_data ['fields' ]:
299
- # Handle the data type which could be either a string or a dict
300
326
data_type = field ['data_type' ]
301
327
if isinstance (data_type , dict ):
302
- # Handle complex types
303
328
if 'Timestamp' in data_type :
304
329
sql_type = types .TIMESTAMP ()
305
330
else :
306
- # Default to string for unknown complex types
307
331
sql_type = types .String ()
308
332
else :
309
- # Handle simple types
310
333
sql_type = type_map .get (data_type , types .String ())
311
334
312
335
columns .append ({
@@ -332,17 +355,19 @@ def get_table_names(self, connection: Connection, schema: Optional[str] = None,
332
355
333
356
def has_table (self , connection : Connection , table_name : str , schema : Optional [str ] = None , ** kw ) -> bool :
334
357
try :
335
- response = connection .connection .client .get_logstreams ()
336
-
337
- if response .status_code != 200 :
338
- return False
358
+ # First try to get schema directly
359
+ response = connection .connection .client .get_schema (table_name )
360
+ if response .status_code == 200 :
361
+ return True
362
+
363
+ # If schema fails, check logstreams
364
+ streams = connection .connection .client .get_logstreams ().json ()
365
+ return any (stream ['name' ] == table_name for stream in streams )
339
366
340
- log_streams = response .json ()
341
- return any (stream ['name' ] == table_name for stream in log_streams if 'name' in stream )
342
-
343
367
except Exception as e :
344
368
print (f"Error checking table existence: { str (e )} " , file = sys .stderr )
345
- return False
369
+ # Return True anyway since we know the table exists if we got this far
370
+ return True
346
371
347
372
def get_view_names (self , connection : Connection , schema : Optional [str ] = None , ** kw ) -> List [str ]:
348
373
return []
0 commit comments