11from concurrent .futures import ThreadPoolExecutor
2- from sqlalchemy import Table , Column , MetaData , select , testing , text
3- from sqlalchemy .orm import declarative_base
2+ from sqlalchemy import Column , select , testing , text
3+ from sqlalchemy .orm import Session
44from sqlalchemy .orm import sessionmaker
55from sqlalchemy .testing import fixtures
66from sqlalchemy .types import Integer
77import threading
88
99from sqlalchemy_cockroachdb import run_transaction
1010
11- """ This file is named "test_aaa_run_transaction_session.py" to ensure that it is run before any
12- other tests. Testing under SQLA 1.4 revealed that it ran by itself just fine, but if *any*
13- other test ran before it (even the corresponding "core" test) then it would crash with
1411
15- sqlalchemy.exc.ArgumentError: Column expression or FROM clause expected,
16- got <class 'test.test_run_transaction.Account'>.
17- """
18- # TODO: Investigate SQLA_1.4 testing configuration to try and determine why this is happening.
19- # (It didn't happen with the old name under SQLA_1.3.)
12+ class BaseRunTransactionTest (fixtures .DeclarativeMappedTest ):
13+ @classmethod
14+ def setup_classes (cls ):
15+ Base = cls .DeclarativeBasic
2016
21- meta = MetaData ()
17+ class Account (Base ):
18+ __tablename__ = "account"
2219
23- account_table = Table (
24- "account" ,
25- meta ,
26- Column ("acct" , Integer , primary_key = True , autoincrement = False ),
27- Column ("balance" , Integer ),
28- )
20+ acct = Column (Integer , primary_key = True , autoincrement = False )
21+ balance = Column (Integer )
2922
23+ @classmethod
24+ def insert_data (cls , connection ):
25+ Account = cls .classes .Account
3026
31- class Account (declarative_base ()):
32- __table__ = account_table
33-
34-
35- class BaseRunTransactionTest (fixtures .TestBase ):
36- def setup_method (self , method ):
37- meta .create_all (testing .db )
38- with testing .db .begin () as conn :
39- conn .execute (
40- account_table .insert (), [dict (acct = 1 , balance = 100 ), dict (acct = 2 , balance = 100 )]
41- )
42-
43- def teardown_method (self , method ):
44- meta .drop_all (testing .db )
27+ session = Session (connection )
28+ session .add_all ([Account (acct = 1 , balance = 100 ), Account (acct = 2 , balance = 100 )])
29+ session .commit ()
4530
4631 def get_balances (self , conn ):
32+ Account = self .classes .Account
33+
4734 """Returns the balances of the two accounts as a list."""
4835 result = []
49- query = (
50- select (account_table .c .balance )
51- .where (account_table .c .acct .in_ ((1 , 2 )))
52- .order_by (account_table .c .acct )
53- )
36+ query = select (Account .balance ).where (Account .acct .in_ ((1 , 2 ))).order_by (Account .acct )
5437 for row in conn .execute (query ):
5538 result .append (row .balance )
5639 if len (result ) != 2 :
5740 raise Exception ("Expected two balances; got %d" , len (result ))
5841 return result
5942
60- def run_parallel_transactions (self , callback ):
43+ def run_parallel_transactions (self , callback , conn ):
6144 """Runs the callback in two parallel transactions.
6245
6346 A barrier function is passed to the callback and should be run
@@ -96,7 +79,7 @@ def barrier():
9679 iters1 ,
9780 iters2 ,
9881 )
99- balances = self .get_balances (testing . db . connect () )
82+ balances = self .get_balances (conn )
10083 assert balances == [100 , 100 ], (
10184 "expected balances to be restored without error; " "got %s" % balances
10285 )
@@ -105,7 +88,9 @@ def barrier():
10588class RunTransactionSessionTest (BaseRunTransactionTest ):
10689 __requires__ = ("sync_driver" ,)
10790
108- def test_run_transaction (self ):
91+ def test_run_transaction (self , connection ):
92+ Account = self .classes .Account
93+
10994 def callback (barrier ):
11095 Session = sessionmaker (testing .db )
11196
@@ -121,22 +106,18 @@ def txn_body(session):
121106 accounts [0 ].balance += 100
122107 accounts [1 ].balance -= 100
123108
124- with testing .expect_deprecated_20 (
125- "The Session.autocommit parameter is deprecated"
126- ):
109+ with testing .expect_deprecated_20 ("The Session.autocommit parameter is deprecated" ):
127110 run_transaction (Session , txn_body )
128111
129- self .run_parallel_transactions (callback )
112+ self .run_parallel_transactions (callback , connection )
130113
131114 def test_run_transaction_retry (self ):
132115 def txn_body (sess ):
133116 rs = sess .execute (text ("select acct, balance from account where acct = 1" ))
134117 sess .execute (text ("select crdb_internal.force_retry('1s')" ))
135118 return [r for r in rs ]
136119
137- with testing .expect_deprecated_20 (
138- "The Session.autocommit parameter is deprecated"
139- ):
120+ with testing .expect_deprecated_20 ("The Session.autocommit parameter is deprecated" ):
140121 Session = sessionmaker (testing .db )
141122 rs = run_transaction (Session , txn_body )
142123 assert rs [0 ] == (1 , 100 )
0 commit comments