Skip to content

Commit cba61a9

Browse files
authored
Check SQLite3 WAL size before initialization (#14)
1 parent 33bc02a commit cba61a9

File tree

4 files changed

+65
-21
lines changed

4 files changed

+65
-21
lines changed

dissect/database/sqlite3/sqlite3.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
if TYPE_CHECKING:
2121
from collections.abc import Iterator
22+
from types import TracebackType
2223

24+
from typing_extensions import Self
2325

2426
ENCODING = {
2527
1: "utf-8",
@@ -78,13 +80,11 @@ def __init__(
7880
wal: WAL | Path | BinaryIO | None = None,
7981
checkpoint: Checkpoint | int | None = None,
8082
):
81-
# Use the provided file handle or try to open the file path.
82-
if hasattr(fh, "read"):
83-
name = getattr(fh, "name", None)
84-
path = Path(name) if name else None
85-
else:
83+
if isinstance(fh, Path):
8684
path = fh
8785
fh = path.open("rb")
86+
else:
87+
path = None
8888

8989
self.fh = fh
9090
self.path = path
@@ -105,12 +105,21 @@ def __init__(
105105
raise InvalidDatabase("Usable page size is too small")
106106

107107
if wal:
108-
self.wal = WAL(wal) if not isinstance(wal, WAL) else wal
109-
elif path:
108+
self.wal = wal if isinstance(wal, WAL) else WAL(wal)
109+
else:
110110
# Check for WAL sidecar next to the DB.
111-
wal_path = path.with_name(f"{path.name}-wal")
112-
if wal_path.exists():
113-
self.wal = WAL(wal_path)
111+
# If we have a path, we can deduce the WAL path.
112+
# If we don't have a path, we can try to get it from the file handle.
113+
if path is None:
114+
# By deducing the path at this point and not earlier, we can keep the original passed
115+
# path to indicate if we should close the file handle later on.
116+
name = getattr(fh, "name", None)
117+
path = Path(name) if name else None
118+
119+
if path is not None:
120+
wal_path = path.with_name(f"{path.name}-wal")
121+
if wal_path.exists() and wal_path.stat().st_size > 0:
122+
self.wal = WAL(wal_path)
114123

115124
# If a checkpoint index was provided, resolve it to a Checkpoint object.
116125
if self.wal and isinstance(checkpoint, int):
@@ -122,6 +131,23 @@ def __init__(
122131

123132
self.page = lru_cache(256)(self.page)
124133

134+
def __enter__(self) -> Self:
135+
"""Return ``self`` upon entering the runtime context."""
136+
return self
137+
138+
def __exit__(self, _: type[BaseException] | None, __: BaseException | None, ___: TracebackType | None) -> bool:
139+
self.close()
140+
return False
141+
142+
def close(self) -> None:
143+
"""Close the database and WAL."""
144+
# Only close DB handle if we opened it using a path
145+
if self.path is not None:
146+
self.fh.close()
147+
148+
if self.wal is not None:
149+
self.wal.close()
150+
125151
def checkpoints(self) -> Iterator[SQLite3]:
126152
"""Yield instances of the database at all available checkpoints in the WAL file, if applicable."""
127153
if not self.wal:

dissect/database/sqlite3/wal.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,13 @@
2323

2424

2525
class WAL:
26-
def __init__(self, fh: WAL | Path | BinaryIO):
26+
def __init__(self, fh: Path | BinaryIO):
2727
# Use the provided WAL file handle or try to open a sidecar WAL file.
28-
if hasattr(fh, "read"):
29-
name = getattr(fh, "name", None)
30-
path = Path(name) if name else None
31-
else:
32-
if not isinstance(fh, Path):
33-
fh = Path(fh)
28+
if isinstance(fh, Path):
3429
path = fh
3530
fh = path.open("rb")
31+
else:
32+
path = None
3633

3734
self.fh = fh
3835
self.path = path
@@ -45,6 +42,12 @@ def __init__(self, fh: WAL | Path | BinaryIO):
4542

4643
self.frame = lru_cache(1024)(self.frame)
4744

45+
def close(self) -> None:
46+
"""Close the WAL."""
47+
# Only close WAL handle if we opened it using a path
48+
if self.path is not None:
49+
self.fh.close()
50+
4851
def frame(self, frame_idx: int) -> Frame:
4952
frame_size = len(c_sqlite3.wal_frame) + self.header.page_size
5053
offset = len(c_sqlite3.wal_header) + frame_idx * frame_size

tests/sqlite3/test_sqlite3.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,19 @@
1212

1313

1414
@pytest.mark.parametrize(
15-
("db_as_path"),
16-
[pytest.param(True, id="db_as_path"), pytest.param(False, id="db_as_fh")],
15+
("open_as_path"),
16+
[pytest.param(True, id="as_path"), pytest.param(False, id="as_fh")],
1717
)
18-
def test_sqlite(sqlite_db: Path, db_as_path: bool) -> None:
19-
db = sqlite3.SQLite3(sqlite_db) if db_as_path else sqlite3.SQLite3(sqlite_db.open("rb"))
18+
def test_sqlite(sqlite_db: Path, open_as_path: bool) -> None:
19+
db = sqlite3.SQLite3(sqlite_db if open_as_path else sqlite_db.open("rb"))
20+
_assert_sqlite_db(db)
21+
db.close()
2022

23+
with sqlite3.SQLite3(sqlite_db if open_as_path else sqlite_db.open("rb")) as db:
24+
_assert_sqlite_db(db)
25+
26+
27+
def _assert_sqlite_db(db: sqlite3.SQLite3) -> None:
2128
assert db.header.magic == sqlite3.SQLITE3_HEADER_MAGIC
2229

2330
tables = list(db.tables())
@@ -67,6 +74,8 @@ def test_sqlite(sqlite_db: Path, db_as_path: bool) -> None:
6774
assert table.row(0).__dict__ == rows[0].__dict__
6875
assert list(rows[0]) == [("id", 1), ("name", "testing"), ("value", 1337)]
6976

77+
db.close()
78+
7079

7180
@pytest.mark.parametrize(
7281
("input", "encoding", "expected_output"),

tests/sqlite3/test_wal.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,26 @@ def test_sqlite_wal(sqlite_db: Path, sqlite_wal: Path, db_as_path: bool, wal_as_
2626
)
2727
_assert_checkpoint_1(db)
2828

29+
db.close()
30+
2931
db = sqlite3.SQLite3(
3032
sqlite_db if db_as_path else sqlite_db.open("rb"),
3133
sqlite_wal if wal_as_path else sqlite_wal.open("rb"),
3234
checkpoint=2,
3335
)
3436
_assert_checkpoint_2(db)
3537

38+
db.close()
39+
3640
db = sqlite3.SQLite3(
3741
sqlite_db if db_as_path else sqlite_db.open("rb"),
3842
sqlite_wal if wal_as_path else sqlite_wal.open("rb"),
3943
checkpoint=3,
4044
)
4145
_assert_checkpoint_3(db)
4246

47+
db.close()
48+
4349

4450
def _assert_checkpoint_1(s: sqlite3.SQLite3) -> None:
4551
# After the first checkpoint the "after checkpoint" entries are present

0 commit comments

Comments
 (0)