Skip to content

Commit c84a56c

Browse files
committed
Fix dbapi schema handling
1 parent 04c9f62 commit c84a56c

File tree

2 files changed

+31
-19
lines changed

2 files changed

+31
-19
lines changed

chdb/dbapi/connections.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def cursor(self, cursor=None):
5757
return Cursor(self)
5858
return Cursor(self)
5959

60-
def query(self, sql, fmt="ArrowStream"):
60+
def query(self, sql, fmt="CSV"):
6161
"""Execute a query and return the raw result."""
6262
if self._closed:
6363
raise err.InterfaceError("Connection closed")

chdb/dbapi/cursors.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
# executemany only supports simple bulk insert.
66
# You can use it to load large dataset.
77
RE_INSERT_VALUES = re.compile(
8-
r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)" +
9-
r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
10-
r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
11-
re.IGNORECASE | re.DOTALL)
8+
r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)"
9+
+ r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))"
10+
+ r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
11+
re.IGNORECASE | re.DOTALL,
12+
)
1213

1314

1415
class Cursor(object):
@@ -131,13 +132,17 @@ def execute(self, query, args=None):
131132

132133
self._cursor.execute(query)
133134

134-
# Get description from Arrow schema
135-
if self._cursor._current_table is not None:
135+
# Get description from column names and types
136+
if hasattr(self._cursor, "_column_names") and self._cursor._column_names:
136137
self.description = [
137-
(field.name, field.type.to_pandas_dtype(), None, None, None, None, None)
138-
for field in self._cursor._current_table.schema
138+
(name, type_info, None, None, None, None, None)
139+
for name, type_info in zip(
140+
self._cursor._column_names, self._cursor._column_types
141+
)
139142
]
140-
self.rowcount = self._cursor._current_table.num_rows
143+
self.rowcount = (
144+
len(self._cursor._current_table) if self._cursor._current_table else -1
145+
)
141146
else:
142147
self.description = None
143148
self.rowcount = -1
@@ -164,16 +169,23 @@ def executemany(self, query, args):
164169
if m:
165170
q_prefix = m.group(1) % ()
166171
q_values = m.group(2).rstrip()
167-
q_postfix = m.group(3) or ''
168-
assert q_values[0] == '(' and q_values[-1] == ')'
169-
return self._do_execute_many(q_prefix, q_values, q_postfix, args,
170-
self.max_stmt_length,
171-
self._get_db().encoding)
172+
q_postfix = m.group(3) or ""
173+
assert q_values[0] == "(" and q_values[-1] == ")"
174+
return self._do_execute_many(
175+
q_prefix,
176+
q_values,
177+
q_postfix,
178+
args,
179+
self.max_stmt_length,
180+
self._get_db().encoding,
181+
)
172182

173183
self.rowcount = sum(self.execute(query, arg) for arg in args)
174184
return self.rowcount
175185

176-
def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
186+
def _do_execute_many(
187+
self, prefix, values, postfix, args, max_stmt_length, encoding
188+
):
177189
conn = self._get_db()
178190
escape = self._escape_args
179191
if isinstance(prefix, str):
@@ -184,18 +196,18 @@ def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encod
184196
args = iter(args)
185197
v = values % escape(next(args), conn)
186198
if isinstance(v, str):
187-
v = v.encode(encoding, 'surrogateescape')
199+
v = v.encode(encoding, "surrogateescape")
188200
sql += v
189201
rows = 0
190202
for arg in args:
191203
v = values % escape(arg, conn)
192204
if isinstance(v, str):
193-
v = v.encode(encoding, 'surrogateescape')
205+
v = v.encode(encoding, "surrogateescape")
194206
if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
195207
rows += self.execute(sql + postfix)
196208
sql = prefix
197209
else:
198-
sql += ','.encode(encoding)
210+
sql += ",".encode(encoding)
199211
sql += v
200212
rows += self.execute(sql + postfix)
201213
self.rowcount = rows

0 commit comments

Comments
 (0)