Skip to content

Commit 7ceecdc

Browse files
authored
Fix a Constraint Violation Bug in postgres_persistence (#139)
* fix a constraint violation bug in postgres_persistence. * update postgres unit tests * modify id values to use class-level constant instead of hardcoded value. add pre_ping for graceful disconnect handling. add session.remove() and replaced instances of .close() with .remove(). also updated tests accordingly. * replace f-strings with parametrized queries to avoid codacy error
1 parent e4b7922 commit 7ceecdc

File tree

2 files changed

+50
-41
lines changed

2 files changed

+50
-41
lines changed

ptbcontrib/postgres_persistence/postgrespersistence.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class PostgresPersistence(DictPersistence):
4747
persistence instance.
4848
"""
4949

50+
PERSISTENCE_ID = 1
51+
5052
def __init__(
5153
self,
5254
url: str = None,
@@ -57,7 +59,7 @@ def __init__(
5759
if url:
5860
if not url.startswith("postgresql://"):
5961
raise TypeError(f"{url} isn't a valid PostgreSQL database URL.")
60-
engine = create_engine(url, client_encoding="utf8")
62+
engine = create_engine(url, client_encoding="utf8", pool_pre_ping=True)
6163
self._session = scoped_session(sessionmaker(bind=engine, autoflush=False))
6264

6365
elif session:
@@ -90,9 +92,11 @@ def __init__(
9092
# `UPDATE` operations if column have some data already present inside it.
9193
if not data:
9294
upsert_qry = """
93-
INSERT INTO persistence (data) VALUES (:jsondata)
95+
INSERT INTO persistence (id, data) VALUES (:id, :jsondata)
9496
ON CONFLICT (id) DO UPDATE SET data = :jsondata"""
95-
self._session.execute(text(upsert_qry), {"jsondata": "{}"})
97+
self._session.execute(
98+
text(upsert_qry), {"id": self.PERSISTENCE_ID, "jsondata": "{}"}
99+
)
96100
self._session.commit()
97101

98102
super().__init__(
@@ -104,7 +108,7 @@ def __init__(
104108
conversations_json=conversations_json,
105109
)
106110
finally:
107-
self._session.close()
111+
self._session.remove()
108112

109113
def __init_database(self) -> None:
110114
"""
@@ -113,11 +117,11 @@ def __init_database(self) -> None:
113117
runs schema migration if necessary.
114118
"""
115119
try:
116-
create_table_qry = """
120+
create_table_qry = f"""
117121
CREATE TABLE IF NOT EXISTS persistence(
118-
id INT PRIMARY KEY DEFAULT 1,
122+
id INT PRIMARY KEY DEFAULT {self.PERSISTENCE_ID},
119123
data json NOT NULL,
120-
CONSTRAINT single_row CHECK (id = 1));"""
124+
CONSTRAINT single_row CHECK (id = {self.PERSISTENCE_ID}));"""
121125
self._session.execute(text(create_table_qry))
122126

123127
# Check if id column exists, is an integer type, and is a primary key
@@ -141,26 +145,28 @@ def __init_database(self) -> None:
141145
data_valid = False
142146
if column_valid:
143147
check_data_qry = """
144-
SELECT 1 FROM persistence WHERE id = 1;"""
145-
data_valid = self._session.execute(text(check_data_qry)).first() is not None
148+
SELECT 1 FROM persistence WHERE id = :id;"""
149+
result = self._session.execute(text(check_data_qry), {"id": self.PERSISTENCE_ID})
150+
data_valid = result.first() is not None
146151

147152
needs_migration = not (column_valid and data_valid)
148153

149154
if needs_migration:
150155
self.logger.info("Old database schema detected. Running migration...")
151156
migration_commands = [
152157
"ALTER TABLE persistence ADD COLUMN id INT;",
153-
"""
154-
UPDATE persistence SET id = 1 WHERE ctid = (
155-
SELECT ctid FROM persistence LIMIT 1
156-
);""",
158+
"""UPDATE persistence SET id = :id WHERE ctid = ("
159+
"SELECT ctid FROM persistence LIMIT 1);""",
157160
"DELETE FROM persistence WHERE id IS NULL;",
158161
"ALTER TABLE persistence ALTER COLUMN id SET NOT NULL;",
159162
"ALTER TABLE persistence ADD PRIMARY KEY (id);",
160-
"ALTER TABLE persistence ADD CONSTRAINT single_row CHECK (id = 1);",
163+
"ALTER TABLE persistence ADD CONSTRAINT single_row CHECK (id = :id);",
161164
]
162165
for command in migration_commands:
163-
self._session.execute(text(command))
166+
if ":id" in command:
167+
self._session.execute(text(command), {"id": self.PERSISTENCE_ID})
168+
else:
169+
self._session.execute(text(command))
164170
self.logger.info("Database migration successful!")
165171

166172
self._session.commit()
@@ -187,9 +193,9 @@ def _update_database(self) -> None:
187193
self.logger.debug("Updating database...")
188194
try:
189195
upsert_qry = """
190-
INSERT INTO persistence (data) VALUES (:jsondata)
196+
INSERT INTO persistence (id, data) VALUES (:id, :jsondata)
191197
ON CONFLICT (id) DO UPDATE SET data = :jsondata"""
192-
params = {"jsondata": self._dump_into_json()}
198+
params = {"id": self.PERSISTENCE_ID, "jsondata": self._dump_into_json()}
193199
self._session.execute(text(upsert_qry), params)
194200
self._session.commit()
195201
except Exception as excp: # pylint: disable=W0703
@@ -198,6 +204,8 @@ def _update_database(self) -> None:
198204
exc_info=excp,
199205
)
200206
self._session.rollback()
207+
finally:
208+
self._session.remove()
201209

202210
async def update_conversation(
203211
self, name: str, key: Tuple[int, ...], new_state: Optional[object]

tests/test_postgres_persistence.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ async def test_with_handler(self, bot, update, monkeypatch):
8181
session = scoped_session("a")
8282
monkeypatch.setattr(session, "execute", self.mocked_execute)
8383
monkeypatch.setattr(session, "commit", self.mock_commit)
84-
monkeypatch.setattr(session, "close", self.mock_ses_close)
84+
monkeypatch.setattr(session, "remove", self.mock_ses_close)
8585

8686
app = (
8787
Application.builder()
@@ -128,7 +128,7 @@ async def test_on_flush(self, bot, update, monkeypatch, on_flush, expected):
128128
session = scoped_session("a")
129129
monkeypatch.setattr(session, "execute", self.mocked_execute)
130130
monkeypatch.setattr(session, "commit", self.mock_commit)
131-
monkeypatch.setattr(session, "close", self.mock_ses_close)
131+
monkeypatch.setattr(session, "remove", self.mock_ses_close)
132132

133133
persistence = PostgresPersistence(session=session, on_flush=on_flush)
134134

@@ -169,13 +169,13 @@ def test_load_on_boot(self, monkeypatch):
169169
session = scoped_session("a")
170170
monkeypatch.setattr(session, "execute", self.mocked_execute)
171171
monkeypatch.setattr(session, "commit", self.mock_commit)
172-
monkeypatch.setattr(session, "close", self.mock_ses_close)
172+
monkeypatch.setattr(session, "remove", self.mock_ses_close)
173173

174174
PostgresPersistence(session=session)
175175
# Check that either SELECT or UPSERT query was executed (upsert for fresh db)
176176
executed_text = self.executed.text.strip()
177177
assert "SELECT data FROM persistence" in executed_text or (
178-
"INSERT INTO persistence (data) VALUES (:jsondata)" in executed_text
178+
"INSERT INTO persistence (id, data) VALUES (:id, :jsondata)" in executed_text
179179
and "ON CONFLICT (id) DO UPDATE SET data = :jsondata" in executed_text
180180
)
181181
assert self.commited == 555
@@ -185,7 +185,7 @@ async def test_flush(self, monkeypatch):
185185
session = scoped_session("a")
186186
monkeypatch.setattr(session, "execute", self.mocked_execute)
187187
monkeypatch.setattr(session, "commit", self.mock_commit)
188-
monkeypatch.setattr(session, "close", self.mock_ses_close)
188+
monkeypatch.setattr(session, "remove", self.mock_ses_close)
189189

190190
await PostgresPersistence(session=session).flush()
191191
assert self.executed != ""
@@ -202,7 +202,7 @@ def mock_execute(query, params=None): # pylint: disable=unused-argument
202202
session = scoped_session("a")
203203
monkeypatch.setattr(session, "execute", mock_execute)
204204
monkeypatch.setattr(session, "commit", self.mock_commit)
205-
monkeypatch.setattr(session, "close", self.mock_ses_close)
205+
monkeypatch.setattr(session, "remove", self.mock_ses_close)
206206

207207
PostgresPersistence(session=session)
208208

@@ -252,22 +252,22 @@ def mock_execute(query, params=None):
252252
return FakeExecResultValidPK()
253253

254254
# Check for data validation query (id=1 exists)
255-
if "WHERE id = 1" in query.text and "information_schema" not in query.text:
255+
if "WHERE id = :id" in query.text and "information_schema" not in query.text:
256256
return FakeExecResultValidData()
257257

258258
return FakeExecResult()
259259

260260
session = scoped_session("a")
261261
monkeypatch.setattr(session, "execute", mock_execute)
262262
monkeypatch.setattr(session, "commit", self.mock_commit)
263-
monkeypatch.setattr(session, "close", self.mock_ses_close)
263+
monkeypatch.setattr(session, "remove", self.mock_ses_close)
264264

265265
PostgresPersistence(session=session)
266266

267267
# Verify no migration commands were run
268268
migration_commands = [
269269
"ALTER TABLE persistence ADD COLUMN id INT",
270-
"UPDATE persistence SET id = 1",
270+
"UPDATE persistence SET id = :id",
271271
"DELETE FROM persistence WHERE id IS NULL",
272272
]
273273
for migration_cmd in migration_commands:
@@ -310,7 +310,7 @@ def mock_execute(query, params=None):
310310
session = scoped_session("a")
311311
monkeypatch.setattr(session, "execute", mock_execute)
312312
monkeypatch.setattr(session, "commit", self.mock_commit)
313-
monkeypatch.setattr(session, "close", self.mock_ses_close)
313+
monkeypatch.setattr(session, "remove", self.mock_ses_close)
314314

315315
PostgresPersistence(session=session)
316316

@@ -340,18 +340,18 @@ def mock_execute(query, params=None): # pylint: disable=unused-argument
340340
session = scoped_session("a")
341341
monkeypatch.setattr(session, "execute", mock_execute)
342342
monkeypatch.setattr(session, "commit", self.mock_commit)
343-
monkeypatch.setattr(session, "close", self.mock_ses_close)
343+
monkeypatch.setattr(session, "remove", self.mock_ses_close)
344344

345345
PostgresPersistence(session=session)
346346

347347
# Verify migration commands were run in correct order
348348
expected_migration_steps = [
349349
"ALTER TABLE persistence ADD COLUMN id INT",
350-
"UPDATE persistence SET id = 1",
350+
"UPDATE persistence SET id = :id",
351351
"DELETE FROM persistence WHERE id IS NULL",
352352
"ALTER TABLE persistence ALTER COLUMN id SET NOT NULL",
353353
"ALTER TABLE persistence ADD PRIMARY KEY (id)",
354-
"ALTER TABLE persistence ADD CONSTRAINT single_row CHECK (id = 1)",
354+
"ALTER TABLE persistence ADD CONSTRAINT single_row CHECK (id = :id)",
355355
]
356356

357357
for expected_step in expected_migration_steps:
@@ -385,7 +385,7 @@ def mock_execute(query, params=None):
385385
session = scoped_session("a")
386386
monkeypatch.setattr(session, "execute", mock_execute)
387387
monkeypatch.setattr(session, "commit", self.mock_commit)
388-
monkeypatch.setattr(session, "close", self.mock_ses_close)
388+
monkeypatch.setattr(session, "remove", self.mock_ses_close)
389389

390390
persistence = PostgresPersistence(session=session)
391391

@@ -403,13 +403,13 @@ def mock_execute(query, params=None): # pylint: disable=unused-argument
403403
session = scoped_session("a")
404404
monkeypatch.setattr(session, "execute", mock_execute)
405405
monkeypatch.setattr(session, "commit", self.mock_commit)
406-
monkeypatch.setattr(session, "close", self.mock_ses_close)
406+
monkeypatch.setattr(session, "remove", self.mock_ses_close)
407407

408408
PostgresPersistence(session=session)
409409

410410
# Check that upsert query was used for initialization
411411
upsert_found = any(
412-
"INSERT INTO persistence (data) VALUES (:jsondata)" in query
412+
"INSERT INTO persistence (id, data) VALUES (:id, :jsondata)" in query
413413
and "ON CONFLICT (id) DO UPDATE SET data = :jsondata" in query
414414
for query in executed_queries
415415
)
@@ -429,7 +429,7 @@ def mock_execute(query, params=None):
429429
session = scoped_session("a")
430430
monkeypatch.setattr(session, "execute", mock_execute)
431431
monkeypatch.setattr(session, "commit", self.mock_commit)
432-
monkeypatch.setattr(session, "close", self.mock_ses_close)
432+
monkeypatch.setattr(session, "remove", self.mock_ses_close)
433433
monkeypatch.setattr(session, "rollback", lambda: None)
434434

435435
persistence = PostgresPersistence(session=session)
@@ -443,7 +443,7 @@ def mock_execute(query, params=None):
443443

444444
# Verify upsert query was used
445445
upsert_found = any(
446-
"INSERT INTO persistence (data) VALUES (:jsondata)" in query
446+
"INSERT INTO persistence (id, data) VALUES (:id, :jsondata)" in query
447447
and "ON CONFLICT (id) DO UPDATE SET data = :jsondata" in query
448448
for query in executed_queries
449449
)
@@ -452,6 +452,7 @@ def mock_execute(query, params=None):
452452
# Verify parameters were passed
453453
assert len(executed_params) > 0
454454
assert "jsondata" in executed_params[0]
455+
assert "id" in executed_params[0]
455456

456457
def test_single_row_constraint_in_schema(self, monkeypatch):
457458
"""Test that single_row constraint is present in schema"""
@@ -464,7 +465,7 @@ def mock_execute(query, params=None): # pylint: disable=unused-argument
464465
session = scoped_session("a")
465466
monkeypatch.setattr(session, "execute", mock_execute)
466467
monkeypatch.setattr(session, "commit", self.mock_commit)
467-
monkeypatch.setattr(session, "close", self.mock_ses_close)
468+
monkeypatch.setattr(session, "remove", self.mock_ses_close)
468469

469470
PostgresPersistence(session=session)
470471

@@ -493,13 +494,13 @@ def mock_execute(query, params=None): # pylint: disable=unused-argument
493494
session = scoped_session("a")
494495
monkeypatch.setattr(session, "execute", mock_execute)
495496
monkeypatch.setattr(session, "commit", self.mock_commit)
496-
monkeypatch.setattr(session, "close", self.mock_ses_close)
497+
monkeypatch.setattr(session, "remove", self.mock_ses_close)
497498

498499
PostgresPersistence(session=session)
499500

500501
# Verify that migration includes step to update first row to id=1
501502
update_first_row = any(
502-
"UPDATE persistence SET id = 1" in query and "LIMIT 1" in query
503+
"UPDATE persistence SET id = :id" in query and "LIMIT 1" in query
503504
for query in executed_queries
504505
)
505506
assert update_first_row
@@ -515,7 +516,7 @@ async def test_data_persistence_with_upsert(self, bot, update, monkeypatch):
515516
session = scoped_session("a")
516517
monkeypatch.setattr(session, "execute", self.mocked_execute)
517518
monkeypatch.setattr(session, "commit", self.mock_commit)
518-
monkeypatch.setattr(session, "close", self.mock_ses_close)
519+
monkeypatch.setattr(session, "remove", self.mock_ses_close)
519520

520521
app = (
521522
Application.builder()
@@ -563,7 +564,7 @@ def mock_rollback():
563564
session = scoped_session("a")
564565
monkeypatch.setattr(session, "execute", mock_execute_with_error)
565566
monkeypatch.setattr(session, "commit", self.mock_commit)
566-
monkeypatch.setattr(session, "close", self.mock_ses_close)
567+
monkeypatch.setattr(session, "remove", self.mock_ses_close)
567568
monkeypatch.setattr(session, "rollback", mock_rollback)
568569

569570
# Should not raise exception, but handle it gracefully
@@ -589,7 +590,7 @@ def mock_execute(query, params=None):
589590
session = scoped_session("a")
590591
monkeypatch.setattr(session, "execute", mock_execute)
591592
monkeypatch.setattr(session, "commit", self.mock_commit)
592-
monkeypatch.setattr(session, "close", self.mock_ses_close)
593+
monkeypatch.setattr(session, "remove", self.mock_ses_close)
593594

594595
PostgresPersistence(session=session)
595596

0 commit comments

Comments
 (0)