diff --git a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDialect.py b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDialect.py index 3d4ee192ad03..912e23e9f7a1 100644 --- a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDialect.py +++ b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDialect.py @@ -18,13 +18,14 @@ from sqlalchemy import types, util from sqlalchemy.engine import default +from sqlalchemy.sql import text from sqlalchemy.sql.sqltypes import String from iotdb import dbapi +from .IoTDBIdentifierPreparer import IoTDBIdentifierPreparer from .IoTDBSQLCompiler import IoTDBSQLCompiler from .IoTDBTypeCompiler import IoTDBTypeCompiler -from .IoTDBIdentifierPreparer import IoTDBIdentifierPreparer TYPES_MAP = { "BOOLEAN": types.Boolean, @@ -68,6 +69,10 @@ def create_connect_args(self, url): opts.update({"sqlalchemy_mode": True}) return [[], opts] + @classmethod + def import_dbapi(cls): + return dbapi + @classmethod def dbapi(cls): return dbapi @@ -79,17 +84,19 @@ def has_table(self, connection, table_name, schema=None, **kw): return table_name in self.get_table_names(connection, schema=schema) def get_schema_names(self, connection, **kw): - cursor = connection.execute("SHOW DATABASES") + cursor = connection.execute(text("SHOW DATABASES")) return [row[0] for row in cursor.fetchall()] def get_table_names(self, connection, schema=None, **kw): cursor = connection.execute( - "SHOW DEVICES %s.**" % (schema or self.default_schema_name) + text("SHOW DEVICES %s.**" % (schema or self.default_schema_name)) ) return [row[0].replace(schema + ".", "", 1) for row in cursor.fetchall()] def get_columns(self, connection, table_name, schema=None, **kw): - cursor = connection.execute("SHOW TIMESERIES %s.%s.*" % (schema, table_name)) + cursor = connection.execute( + text("SHOW TIMESERIES %s.%s.*" % (schema, table_name)) + ) columns = [self._general_time_column_info()] for row in cursor.fetchall(): columns.append(self._create_column_info(row, schema, table_name)) diff --git a/iotdb-client/client-py/requirements.txt b/iotdb-client/client-py/requirements.txt index 490393d157ce..0741cf6db67a 100644 --- a/iotdb-client/client-py/requirements.txt +++ b/iotdb-client/client-py/requirements.txt @@ -21,5 +21,5 @@ pandas>=1.0.0 numpy>=1.0.0 thrift>=0.14.1 # SQLAlchemy Dialect -sqlalchemy<1.5,>=1.4 +sqlalchemy>=1.4 sqlalchemy-utils>=0.37.8 diff --git a/iotdb-client/client-py/tests/integration/sqlalchemy/test_dialect.py b/iotdb-client/client-py/tests/integration/sqlalchemy/test_dialect.py index f6100f0a11f2..35a16ff366e6 100644 --- a/iotdb-client/client-py/tests/integration/sqlalchemy/test_dialect.py +++ b/iotdb-client/client-py/tests/integration/sqlalchemy/test_dialect.py @@ -20,7 +20,8 @@ from sqlalchemy import create_engine, inspect from sqlalchemy.dialects import registry - +from sqlalchemy.orm import Session +from sqlalchemy.sql import text from tests.integration.iotdb_container import IoTDBContainer final_flag = True @@ -52,17 +53,26 @@ def test_dialect(): ) registry.register("iotdb", "iotdb.sqlalchemy.IoTDBDialect", "IoTDBDialect") eng = create_engine(url) - eng.execute("create database root.cursor") - eng.execute("create database root.cursor_s1") - eng.execute( - "create timeseries root.cursor.device1.temperature with datatype=FLOAT,encoding=RLE" - ) - eng.execute( - "create timeseries root.cursor.device1.status with datatype=FLOAT,encoding=RLE" - ) - eng.execute( - "create timeseries root.cursor.device2.temperature with datatype=FLOAT,encoding=RLE" - ) + + with Session(eng) as session: + session.execute(text("create database root.cursor")) + session.execute(text("create database root.cursor_s1")) + session.execute( + text( + "create timeseries root.cursor.device1.temperature with datatype=FLOAT,encoding=RLE" + ) + ) + session.execute( + text( + "create timeseries root.cursor.device1.status with datatype=FLOAT,encoding=RLE" + ) + ) + session.execute( + text( + "create timeseries root.cursor.device2.temperature with datatype=FLOAT,encoding=RLE" + ) + ) + insp = inspect(eng) # test get_schema_names schema_names = insp.get_schema_names() @@ -79,8 +89,11 @@ def test_dialect(): if len(columns) != 3: test_fail() print_message("test get_columns failed!") - eng.execute("delete database root.cursor") - eng.execute("delete database root.cursor_s1") + + with Session(eng) as session: + session.execute(text("delete database root.cursor")) + session.execute(text("delete database root.cursor_s1")) + # close engine eng.dispose()