Skip to content

Commit 9520fd1

Browse files
committed
update handling escapes for table name
1 parent 47e5b22 commit 9520fd1

File tree

1 file changed

+39
-14
lines changed

1 file changed

+39
-14
lines changed

parseable_connector/__init__.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,23 @@ def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Respon
6363
except requests.exceptions.RequestException as e:
6464
raise DatabaseError(f"Request failed: {str(e)}")
6565

66+
def get_logstreams(self) -> requests.Response:
67+
"""Get list of all logstreams"""
68+
return self._make_request('GET', 'logstream')
69+
70+
def get_schema(self, table_name: str) -> requests.Response:
71+
"""Get schema for a table/stream"""
72+
escaped_table_name = self._escape_table_name(table_name)
73+
return self._make_request('GET', f'logstream/{table_name}/schema')
74+
75+
def _escape_table_name(self, table_name: str) -> str:
76+
"""Escape table name to handle special characters"""
77+
# Handle table names with special characters
78+
if '-' in table_name or ' ' in table_name or '.' in table_name:
79+
return f'"{table_name}"'
80+
return table_name
81+
82+
# In ParseableClient class:
6683
def execute_query(self, table_name: str, query: str) -> Dict:
6784
"""Execute a query against a specific table/stream"""
6885
# First, let's transform the query to handle type casting
@@ -71,13 +88,18 @@ def execute_query(self, table_name: str, query: str) -> Dict:
7188
# Then extract time conditions
7289
modified_query, start_time, end_time = self._extract_and_remove_time_conditions(modified_query)
7390

91+
# Escape table name in query if needed, but only if it's not already escaped
92+
if not (modified_query.find(f'"{table_name}"') >= 0):
93+
escaped_table_name = self._escape_table_name(table_name)
94+
modified_query = modified_query.replace(table_name, escaped_table_name)
95+
7496
data = {
7597
"query": modified_query,
7698
"startTime": start_time,
7799
"endTime": end_time
78100
}
79101

80-
headers = {**self.headers, 'X-P-Stream': table_name}
102+
headers = {**self.headers, 'X-P-Stream': table_name} # Keep original table name in header
81103

82104
url = f"{self.base_url}/api/v1/query"
83105

@@ -160,10 +182,11 @@ def __init__(self, connection):
160182
def execute(self, operation: str, parameters: Optional[Dict] = None):
161183
if not self.connection.table_name:
162184
raise DatabaseError("No table name specified in connection string")
163-
185+
164186
try:
165187
if operation.strip().upper() == "SELECT 1":
166188
# For connection test, execute a real query to test API connectivity
189+
# Don't escape the table name here since execute_query will handle it
167190
result = self.connection.client.execute_query(
168191
table_name=self.connection.table_name,
169192
query=f"select * from {self.connection.table_name} limit 1"
@@ -278,6 +301,10 @@ def do_ping(self, dbapi_connection):
278301

279302
def get_columns(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw) -> List[Dict]:
280303
try:
304+
# Remove schema prefix if present
305+
if '.' in table_name:
306+
schema, table_name = table_name.split('.')
307+
281308
response = connection.connection.client.get_schema(table_name)
282309

283310
if response.status_code != 200:
@@ -296,17 +323,13 @@ def get_columns(self, connection: Connection, table_name: str, schema: Optional[
296323
}
297324

298325
for field in schema_data['fields']:
299-
# Handle the data type which could be either a string or a dict
300326
data_type = field['data_type']
301327
if isinstance(data_type, dict):
302-
# Handle complex types
303328
if 'Timestamp' in data_type:
304329
sql_type = types.TIMESTAMP()
305330
else:
306-
# Default to string for unknown complex types
307331
sql_type = types.String()
308332
else:
309-
# Handle simple types
310333
sql_type = type_map.get(data_type, types.String())
311334

312335
columns.append({
@@ -332,17 +355,19 @@ def get_table_names(self, connection: Connection, schema: Optional[str] = None,
332355

333356
def has_table(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw) -> bool:
334357
try:
335-
response = connection.connection.client.get_logstreams()
336-
337-
if response.status_code != 200:
338-
return False
358+
# First try to get schema directly
359+
response = connection.connection.client.get_schema(table_name)
360+
if response.status_code == 200:
361+
return True
362+
363+
# If schema fails, check logstreams
364+
streams = connection.connection.client.get_logstreams().json()
365+
return any(stream['name'] == table_name for stream in streams)
339366

340-
log_streams = response.json()
341-
return any(stream['name'] == table_name for stream in log_streams if 'name' in stream)
342-
343367
except Exception as e:
344368
print(f"Error checking table existence: {str(e)}", file=sys.stderr)
345-
return False
369+
# Return True anyway since we know the table exists if we got this far
370+
return True
346371

347372
def get_view_names(self, connection: Connection, schema: Optional[str] = None, **kw) -> List[str]:
348373
return []

0 commit comments

Comments
 (0)