Skip to content

Commit 1540c41

Browse files
gordthompsonrafiss
authored andcommitted
Adjust test_run_transaction_session.py for 1.4
Restructure tests to avoid requirement that they must be run before any other tests. Rename .py files accordingly.
1 parent 27f82b6 commit 1540c41

File tree

2 files changed

+27
-50
lines changed

2 files changed

+27
-50
lines changed
Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,6 @@
66

77
from sqlalchemy_cockroachdb import run_transaction
88

9-
""" This file is named "test_aab_run_transaction_core.py" to keep it close to its more
10-
temperamental "session" sibling.
11-
"""
12-
139
meta = MetaData()
1410

1511
account_table = Table(
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,46 @@
11
from concurrent.futures import ThreadPoolExecutor
2-
from sqlalchemy import Table, Column, MetaData, select, testing, text
3-
from sqlalchemy.orm import declarative_base
2+
from sqlalchemy import Column, select, testing, text
3+
from sqlalchemy.orm import Session
44
from sqlalchemy.orm import sessionmaker
55
from sqlalchemy.testing import fixtures
66
from sqlalchemy.types import Integer
77
import threading
88

99
from sqlalchemy_cockroachdb import run_transaction
1010

11-
""" This file is named "test_aaa_run_transaction_session.py" to ensure that it is run before any
12-
other tests. Testing under SQLA 1.4 revealed that it ran by itself just fine, but if *any*
13-
other test ran before it (even the corresponding "core" test) then it would crash with
1411

15-
sqlalchemy.exc.ArgumentError: Column expression or FROM clause expected,
16-
got <class 'test.test_run_transaction.Account'>.
17-
"""
18-
# TODO: Investigate SQLA_1.4 testing configuration to try and determine why this is happening.
19-
# (It didn't happen with the old name under SQLA_1.3.)
12+
class BaseRunTransactionTest(fixtures.DeclarativeMappedTest):
13+
@classmethod
14+
def setup_classes(cls):
15+
Base = cls.DeclarativeBasic
2016

21-
meta = MetaData()
17+
class Account(Base):
18+
__tablename__ = "account"
2219

23-
account_table = Table(
24-
"account",
25-
meta,
26-
Column("acct", Integer, primary_key=True, autoincrement=False),
27-
Column("balance", Integer),
28-
)
20+
acct = Column(Integer, primary_key=True, autoincrement=False)
21+
balance = Column(Integer)
2922

23+
@classmethod
24+
def insert_data(cls, connection):
25+
Account = cls.classes.Account
3026

31-
class Account(declarative_base()):
32-
__table__ = account_table
33-
34-
35-
class BaseRunTransactionTest(fixtures.TestBase):
36-
def setup_method(self, method):
37-
meta.create_all(testing.db)
38-
with testing.db.begin() as conn:
39-
conn.execute(
40-
account_table.insert(), [dict(acct=1, balance=100), dict(acct=2, balance=100)]
41-
)
42-
43-
def teardown_method(self, method):
44-
meta.drop_all(testing.db)
27+
session = Session(connection)
28+
session.add_all([Account(acct=1, balance=100), Account(acct=2, balance=100)])
29+
session.commit()
4530

4631
def get_balances(self, conn):
32+
Account = self.classes.Account
33+
4734
"""Returns the balances of the two accounts as a list."""
4835
result = []
49-
query = (
50-
select(account_table.c.balance)
51-
.where(account_table.c.acct.in_((1, 2)))
52-
.order_by(account_table.c.acct)
53-
)
36+
query = select(Account.balance).where(Account.acct.in_((1, 2))).order_by(Account.acct)
5437
for row in conn.execute(query):
5538
result.append(row.balance)
5639
if len(result) != 2:
5740
raise Exception("Expected two balances; got %d", len(result))
5841
return result
5942

60-
def run_parallel_transactions(self, callback):
43+
def run_parallel_transactions(self, callback, conn):
6144
"""Runs the callback in two parallel transactions.
6245
6346
A barrier function is passed to the callback and should be run
@@ -96,7 +79,7 @@ def barrier():
9679
iters1,
9780
iters2,
9881
)
99-
balances = self.get_balances(testing.db.connect())
82+
balances = self.get_balances(conn)
10083
assert balances == [100, 100], (
10184
"expected balances to be restored without error; " "got %s" % balances
10285
)
@@ -105,7 +88,9 @@ def barrier():
10588
class RunTransactionSessionTest(BaseRunTransactionTest):
10689
__requires__ = ("sync_driver",)
10790

108-
def test_run_transaction(self):
91+
def test_run_transaction(self, connection):
92+
Account = self.classes.Account
93+
10994
def callback(barrier):
11095
Session = sessionmaker(testing.db)
11196

@@ -121,22 +106,18 @@ def txn_body(session):
121106
accounts[0].balance += 100
122107
accounts[1].balance -= 100
123108

124-
with testing.expect_deprecated_20(
125-
"The Session.autocommit parameter is deprecated"
126-
):
109+
with testing.expect_deprecated_20("The Session.autocommit parameter is deprecated"):
127110
run_transaction(Session, txn_body)
128111

129-
self.run_parallel_transactions(callback)
112+
self.run_parallel_transactions(callback, connection)
130113

131114
def test_run_transaction_retry(self):
132115
def txn_body(sess):
133116
rs = sess.execute(text("select acct, balance from account where acct = 1"))
134117
sess.execute(text("select crdb_internal.force_retry('1s')"))
135118
return [r for r in rs]
136119

137-
with testing.expect_deprecated_20(
138-
"The Session.autocommit parameter is deprecated"
139-
):
120+
with testing.expect_deprecated_20("The Session.autocommit parameter is deprecated"):
140121
Session = sessionmaker(testing.db)
141122
rs = run_transaction(Session, txn_body)
142123
assert rs[0] == (1, 100)

0 commit comments

Comments
 (0)