Skip to content

Commit 4d79e1c

Browse files
committed
Enable foreign keys and WAL mode by default on SQLite
1 parent af063d4 commit 4d79e1c

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

mautrix/util/async_db/aiosqlite.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,20 +110,43 @@ def __init__(
110110
self._db_args.pop("max_size", None)
111111
self._stopped = False
112112
self._conns = 0
113-
self._init_commands = self._db_args.pop("init_commands", [])
113+
self._init_commands = self._add_missing_pragmas(self._db_args.pop("init_commands", []))
114+
115+
@staticmethod
116+
def _add_missing_pragmas(init_commands: list[str]) -> list[str]:
117+
has_foreign_keys = False
118+
has_journal_mode = False
119+
has_busy_timeout = False
120+
for cmd in init_commands:
121+
if "PRAGMA" not in cmd:
122+
continue
123+
if "foreign_keys" in cmd:
124+
has_foreign_keys = True
125+
elif "journal_mode" in cmd:
126+
has_journal_mode = True
127+
elif "busy_timeout" in cmd:
128+
has_busy_timeout = True
129+
if not has_foreign_keys:
130+
init_commands.append("PRAGMA foreign_keys = ON")
131+
if not has_journal_mode:
132+
init_commands.append("PRAGMA journal_mode = WAL")
133+
if not has_busy_timeout:
134+
init_commands.append("PRAGMA busy_timeout = 5000")
135+
return init_commands
114136

115137
async def start(self) -> None:
116138
if self._conns:
117139
raise RuntimeError("database pool has already been started")
118140
elif self._stopped:
119141
raise RuntimeError("database pool can't be restarted")
120142
self.log.debug(f"Connecting to {self.url}")
143+
self.log.debug(f"Database connection init commands: {self._init_commands}")
121144
for _ in range(self._pool.maxsize):
122145
conn = await TxnConnection(self._path, **self._db_args)
123146
if self._init_commands:
124147
cur = await conn.cursor()
125148
for command in self._init_commands:
126-
self.log.debug("Executing command: %s", command)
149+
self.log.trace("Executing init command: %s", command)
127150
await cur.execute(command)
128151
await conn.commit()
129152
conn.row_factory = sqlite3.Row

0 commit comments

Comments
 (0)