22from decimal import Decimal
33
44import pytest
5- from sqlalchemy import create_engine
5+ from sqlalchemy import create_engine , text
66from sqlalchemy .engine .base import Connection , Engine
77from sqlalchemy .exc import OperationalError
88
@@ -15,11 +15,11 @@ def test_create_ex_table(
1515 ex_table_query : str ,
1616 ex_table_name : str ,
1717 ):
18- connection .execute (ex_table_query )
19- assert engine .dialect .has_table (engine , ex_table_name )
18+ connection .execute (text ( ex_table_query ) )
19+ assert engine .dialect .has_table (connection , ex_table_name )
2020 # Cleanup
21- connection .execute (f"DROP TABLE { ex_table_name } " )
22- assert not engine .dialect .has_table (engine , ex_table_name )
21+ connection .execute (text ( f"DROP TABLE { ex_table_name } " ) )
22+ assert not engine .dialect .has_table (connection , ex_table_name )
2323
2424 def test_set_params (
2525 self , username : str , password : str , database_name : str , engine_name : str
@@ -28,81 +28,95 @@ def test_set_params(
2828 f"firebolt://{ username } :{ password } @{ database_name } /{ engine_name } "
2929 )
3030 with engine .connect () as connection :
31- connection .execute ("SET advanced_mode=1" )
32- connection .execute ("SET use_standard_sql=0" )
33- result = connection .execute ("SELECT sleepEachRow(1) from numbers(1)" )
31+ connection .execute (text ( "SET advanced_mode=1" ) )
32+ connection .execute (text ( "SET use_standard_sql=0" ) )
33+ result = connection .execute (text ( "SELECT sleepEachRow(1) from numbers(1)" ) )
3434 assert len (result .fetchall ()) == 1
3535 engine .dispose ()
3636
3737 def test_data_write (self , connection : Connection , fact_table_name : str ):
3838 connection .execute (
39- f"INSERT INTO { fact_table_name } (idx, dummy) VALUES (1, 'some_text')"
39+ text (f"INSERT INTO { fact_table_name } (idx, dummy) VALUES (1, 'some_text')" )
40+ )
41+ result = connection .execute (
42+ text (f"SELECT * FROM { fact_table_name } WHERE idx=1" )
4043 )
41- result = connection .execute (f"SELECT * FROM { fact_table_name } WHERE idx=?" , 1 )
4244 assert result .fetchall () == [(1 , "some_text" )]
43- result = connection .execute (f"SELECT * FROM { fact_table_name } " )
45+ result = connection .execute (text ( f"SELECT * FROM { fact_table_name } " ) )
4446 assert len (result .fetchall ()) == 1
4547 # Update not supported
4648 with pytest .raises (OperationalError ):
4749 connection .execute (
48- f"UPDATE { fact_table_name } SET dummy='some_other_text' WHERE idx=1"
50+ text (
51+ f"UPDATE { fact_table_name } SET dummy='some_other_text' WHERE idx=1"
52+ )
4953 )
5054 # Delete works but is not officially supported yet
5155 # with pytest.raises(OperationalError):
5256 # connection.execute(f"DELETE FROM {fact_table_name} WHERE idx=1")
5357
5458 def test_firebolt_types (self , connection : Connection ):
55- result = connection .execute ("SELECT '1896-01-01' :: DATE_EXT" )
59+ result = connection .execute (text ( "SELECT '1896-01-01' :: DATE_EXT" ) )
5660 assert result .fetchall () == [(date (1896 , 1 , 1 ),)]
57- result = connection .execute ("SELECT '1896-01-01 00:01:00' :: TIMESTAMP_EXT" )
61+ result = connection .execute (
62+ text ("SELECT '1896-01-01 00:01:00' :: TIMESTAMP_EXT" )
63+ )
5864 assert result .fetchall () == [(datetime (1896 , 1 , 1 , 0 , 1 , 0 , 0 ),)]
59- result = connection .execute ("SELECT 100.76 :: DECIMAL(5, 2)" )
65+ result = connection .execute (text ( "SELECT 100.76 :: DECIMAL(5, 2)" ) )
6066 assert result .fetchall () == [(Decimal ("100.76" ),)]
6167
6268 def test_agg_index (self , connection : Connection , fact_table_name : str ):
6369 # Test if sql parsing allows it
6470 agg_index = "idx_agg_max"
6571 connection .execute (
66- f"""
72+ text (
73+ f"""
6774 CREATE AGGREGATING INDEX { agg_index } ON { fact_table_name } (
6875 dummy,
6976 max(idx)
7077 );
7178 """
79+ )
7280 )
73- connection .execute (f"DROP AGGREGATING INDEX { agg_index } " )
81+ connection .execute (text ( f"DROP AGGREGATING INDEX { agg_index } " ) )
7482
7583 def test_join_index (self , connection : Connection , dimension_table_name : str ):
7684 # Test if sql parsing allows it
7785 join_index = "idx_join"
7886 connection .execute (
79- f"""
87+ text (
88+ f"""
8089 CREATE JOIN INDEX { join_index } ON { dimension_table_name } (
8190 idx,
8291 dummy
8392 );
8493 """
94+ )
8595 )
86- connection .execute (f"DROP JOIN INDEX { join_index } " )
96+ connection .execute (text ( f"DROP JOIN INDEX { join_index } " ) )
8797
8898 def test_get_schema_names (self , engine : Engine , database_name : str ):
8999 results = engine .dialect .get_schema_names (engine )
90100 assert "public" in results
91101
92- def test_has_table (self , engine : Engine , fact_table_name : str ):
93- results = engine .dialect .has_table (engine , fact_table_name )
102+ def test_has_table (
103+ self , engine : Engine , connection : Connection , fact_table_name : str
104+ ):
105+ results = engine .dialect .has_table (connection , fact_table_name )
94106 assert results == 1
95107
96- def test_get_table_names (self , engine : Engine ):
97- results = engine .dialect .get_table_names (engine )
108+ def test_get_table_names (self , engine : Engine , connection : Connection ):
109+ results = engine .dialect .get_table_names (connection )
98110 assert len (results ) > 0
99- results = engine .dialect .get_table_names (engine , "public" )
111+ results = engine .dialect .get_table_names (connection , "public" )
100112 assert len (results ) > 0
101- results = engine .dialect .get_table_names (engine , "non_existing_schema" )
113+ results = engine .dialect .get_table_names (connection , "non_existing_schema" )
102114 assert len (results ) == 0
103115
104- def test_get_columns (self , engine : Engine , fact_table_name : str ):
105- results = engine .dialect .get_columns (engine , fact_table_name )
116+ def test_get_columns (
117+ self , engine : Engine , connection : Connection , fact_table_name : str
118+ ):
119+ results = engine .dialect .get_columns (connection , fact_table_name )
106120 assert len (results ) > 0
107121 row = results [0 ]
108122 assert isinstance (row , dict )
@@ -113,5 +127,5 @@ def test_get_columns(self, engine: Engine, fact_table_name: str):
113127 assert row_keys [3 ] == "default"
114128
115129 def test_service_account_connect (self , connection_service_account : Connection ):
116- result = connection_service_account .execute ("SELECT 1" )
130+ result = connection_service_account .execute (text ( "SELECT 1" ) )
117131 assert result .fetchall () == [(1 ,)]
0 commit comments