Skip to content
This repository was archived by the owner on May 5, 2022. It is now read-only.

Commit b8e9aff

Browse files
authored
Merge pull request #10 from dungdm93/support-pandas
Supporting pandas
2 parents 0a0cfcd + 39526be commit b8e9aff

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

sqlalchemy_trino/compiler.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,44 @@ class TrinoDDLCompiler(compiler.DDLCompiler):
8585

8686

8787
class TrinoTypeCompiler(compiler.GenericTypeCompiler):
88-
pass
88+
def visit_FLOAT(self, type_, **kw):
89+
precision = type_.precision or 32
90+
if 0 <= precision <= 32:
91+
return self.visit_REAL(type_, **kw)
92+
elif 32 < precision <= 64:
93+
return self.visit_DOUBLE(type_, **kw)
94+
else:
95+
raise ValueError(f"type.precision={type_.precision} is invalid")
96+
97+
def visit_DOUBLE(self, type_, **kw):
98+
return "DOUBLE"
99+
100+
def visit_NUMERIC(self, type_, **kw):
101+
return self.visit_DECIMAL(type_, **kw)
102+
103+
def visit_NCHAR(self, type_, **kw):
104+
return self.visit_CHAR(type_, **kw)
105+
106+
def visit_NVARCHAR(self, type_, **kw):
107+
return self.visit_VARCHAR(type_, **kw)
108+
109+
def visit_TEXT(self, type_, **kw):
110+
return self.visit_VARCHAR(type_, **kw)
111+
112+
def visit_BINARY(self, type_, **kw):
113+
return self.visit_VARBINARY(type_, **kw)
114+
115+
def visit_CLOB(self, type_, **kw):
116+
return self.visit_VARCHAR(type_, **kw)
117+
118+
def visit_NCLOB(self, type_, **kw):
119+
return self.visit_VARCHAR(type_, **kw)
120+
121+
def visit_BLOB(self, type_, **kw):
122+
return self.visit_VARBINARY(type_, **kw)
123+
124+
def visit_DATETIME(self, type_, **kw):
125+
return self.visit_TIMESTAMP(type_, **kw)
89126

90127

91128
class TrinoIdentifierPreparer(compiler.IdentifierPreparer):

sqlalchemy_trino/dialect.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
88
from sqlalchemy.engine.url import URL
99
from trino.auth import BasicAuthentication
10+
from trino.client import TrinoQuery
11+
from trino.dbapi import Cursor
1012

1113
from . import compiler
1214
from . import datatype
@@ -58,6 +60,7 @@ class TrinoDialect(DefaultDialect):
5860
# DML
5961
supports_empty_insert = False
6062
supports_multivalues_insert = True
63+
postfetch_lastrowid = False
6164

6265
# Version parser
6366
__version_pattern = re.compile(r'(\d+).*')
@@ -272,6 +275,19 @@ def _get_default_schema_name(self, connection: Connection) -> Optional[str]:
272275
dbapi_connection: trino_dbapi.Connection = connection.connection
273276
return dbapi_connection.schema
274277

278+
def do_execute(self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...],
279+
context: DefaultExecutionContext = None):
280+
cursor.execute(statement, parameters)
281+
if context and context.should_autocommit:
282+
# SQL statement only submitted to Trino server when cursor.fetch*() is called.
283+
# For DDL (CREATE/ALTER/DROP) and DML (INSERT/UPDATE/DELETE) statement, call cursor.description
284+
# to force submit statement immediately.
285+
d = cursor.description
286+
# old trino client does not support eager-loading cursor.description
287+
if d is None:
288+
query: TrinoQuery = cursor._query # noqa
289+
query._result._rows += query.fetch() # noqa
290+
275291
def do_rollback(self, dbapi_connection):
276292
if dbapi_connection.transaction is not None:
277293
dbapi_connection.rollback()

0 commit comments

Comments
 (0)