
Commit 83cb6b5

feat: enhance caching mechanism for migration metadata and add tests (#228)
Improve the caching mechanism for migration metadata to reduce redundant checksum calculations and ensure cache invalidation when file content changes. Include tests to verify the functionality of the caching behavior.
1 parent db22434 commit 83cb6b5
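In outline, the caching pattern this commit applies is stat-signature validation: each cached entry records the source file's st_mtime_ns and st_size, and the entry is reused only while both values still match, so an edit that changes the file (and therefore its mtime or size) forces the metadata and checksum to be recomputed. Below is a minimal standalone sketch of that idea, assuming nothing about sqlspec internals; MetadataCache and load_metadata are hypothetical names for illustration, not the library's API.

from pathlib import Path


class MetadataCache:
    """Illustrative sketch (not the sqlspec API): reuse parsed metadata while a file's stat signature is unchanged."""

    def __init__(self) -> None:
        # Maps file path -> ((st_mtime_ns, st_size), cached metadata dict).
        self._entries: dict[str, tuple[tuple[int, int], dict]] = {}

    def load_metadata(self, path: Path) -> dict:
        stat = path.stat()
        signature = (stat.st_mtime_ns, stat.st_size)  # cheap proxy for "content unchanged"
        key = str(path)
        cached = self._entries.get(key)
        if cached is not None and cached[0] == signature:
            return dict(cached[1])  # cache hit: no file read, no checksum recomputation
        content = path.read_text(encoding="utf-8")
        metadata = {"content": content, "length": len(content)}  # stand-in for parsed metadata
        self._entries[key] = (signature, metadata)  # prime, or replace the stale entry
        return dict(metadata)

Checking mtime_ns and size avoids re-reading and re-hashing unchanged files on every call; the trade-off, as with any stat-based cache, is that a modification preserving both values would go undetected, which is normally acceptable for local migration files.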

File tree

2 files changed: +240, -56 lines

sqlspec/migrations/runner.py

Lines changed: 177 additions & 56 deletions
@@ -4,7 +4,9 @@
 of concerns and proper type safety.
 """
 
+import hashlib
 import inspect
+import re
 import time
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -15,6 +17,7 @@
 from sqlspec.migrations.loaders import get_migration_loader
 from sqlspec.utils.logging import get_logger
 from sqlspec.utils.sync_tools import async_, await_
+from sqlspec.utils.version import parse_version
 
 if TYPE_CHECKING:
     from collections.abc import Awaitable, Callable, Coroutine
@@ -26,6 +29,30 @@
 logger = get_logger("migrations.runner")
 
 
+class _CachedMigrationMetadata:
+    """Cached migration metadata keyed by file path."""
+
+    __slots__ = ("metadata", "mtime_ns", "size")
+
+    def __init__(self, metadata: "dict[str, Any]", mtime_ns: int, size: int) -> None:
+        self.metadata = metadata
+        self.mtime_ns = mtime_ns
+        self.size = size
+
+    def clone(self) -> "dict[str, Any]":
+        return dict(self.metadata)
+
+
+class _MigrationFileEntry:
+    """Represents a migration file discovered during directory scanning."""
+
+    __slots__ = ("extension_name", "path")
+
+    def __init__(self, path: Path, extension_name: "str | None") -> None:
+        self.path = path
+        self.extension_name = extension_name
+
+
 class BaseMigrationRunner(ABC):
     """Base migration runner with common functionality shared between sync and async implementations."""
 
@@ -52,6 +79,100 @@ def __init__(
         self.project_root: Path | None = None
         self.context = context
         self.extension_configs = extension_configs or {}
+        self._listing_digest: str | None = None
+        self._listing_cache: list[tuple[str, Path]] | None = None
+        self._listing_signatures: dict[str, tuple[int, int]] = {}
+        self._metadata_cache: dict[str, _CachedMigrationMetadata] = {}
+
+    def _iter_directory_entries(self, base_path: Path, extension_name: "str | None") -> "list[_MigrationFileEntry]":
+        """Collect migration files discovered under a base path."""
+
+        if not base_path.exists():
+            return []
+
+        entries: list[_MigrationFileEntry] = []
+        for pattern in ("*.sql", "*.py"):
+            for file_path in sorted(base_path.glob(pattern)):
+                if file_path.name.startswith("."):
+                    continue
+                entries.append(_MigrationFileEntry(path=file_path, extension_name=extension_name))
+        return entries
+
+    def _collect_listing_entries(self) -> "tuple[list[_MigrationFileEntry], dict[str, tuple[int, int]], str]":
+        """Gather migration files, stat signatures, and digest for cache validation."""
+
+        entries: list[_MigrationFileEntry] = []
+        signatures: dict[str, tuple[int, int]] = {}
+        digest_source = hashlib.md5(usedforsecurity=False)
+
+        for entry in self._iter_directory_entries(self.migrations_path, None):
+            self._record_entry(entry, entries, signatures, digest_source)
+
+        for ext_name, ext_path in self.extension_migrations.items():
+            for entry in self._iter_directory_entries(ext_path, ext_name):
+                self._record_entry(entry, entries, signatures, digest_source)
+
+        return entries, signatures, digest_source.hexdigest()
+
+    def _record_entry(
+        self,
+        entry: _MigrationFileEntry,
+        entries: "list[_MigrationFileEntry]",
+        signatures: "dict[str, tuple[int, int]]",
+        digest_source: Any,
+    ) -> None:
+        """Record entry metadata for cache decisions."""
+
+        try:
+            stat_result = entry.path.stat()
+        except FileNotFoundError:
+            return
+
+        path_str = str(entry.path)
+        token = (stat_result.st_mtime_ns, stat_result.st_size)
+        signatures[path_str] = token
+        digest_source.update(path_str.encode("utf-8"))
+        digest_source.update(f"{token[0]}:{token[1]}".encode())
+        entries.append(entry)
+
+    def _build_sorted_listing(self, entries: "list[_MigrationFileEntry]") -> "list[tuple[str, Path]]":
+        """Construct sorted migration listing from directory entries."""
+
+        migrations: list[tuple[str, Path]] = []
+
+        for entry in entries:
+            version = self._extract_version(entry.path.name)
+            if not version:
+                continue
+            if entry.extension_name:
+                version = f"ext_{entry.extension_name}_{version}"
+            migrations.append((version, entry.path))
+
+        def version_sort_key(migration_tuple: "tuple[str, Path]") -> "Any":
+            version_str = migration_tuple[0]
+            try:
+                return parse_version(version_str)
+            except ValueError:
+                return version_str
+
+        return sorted(migrations, key=version_sort_key)
+
+    def _log_listing_invalidation(
+        self, previous: "dict[str, tuple[int, int]]", current: "dict[str, tuple[int, int]]"
+    ) -> None:
+        """Log cache invalidation details at INFO level."""
+
+        prev_keys = set(previous)
+        curr_keys = set(current)
+        added = curr_keys - prev_keys
+        removed = prev_keys - curr_keys
+        modified = {key for key in prev_keys & curr_keys if previous[key] != current[key]}
+        logger.info(
+            "Migration listing cache invalidated (added=%d, removed=%d, modified=%d)",
+            len(added),
+            len(removed),
+            len(modified),
+        )
 
     def _extract_version(self, filename: str) -> "str | None":
         """Extract version from filename.
@@ -95,9 +216,6 @@ def _calculate_checksum(self, content: str) -> str:
         Returns:
             MD5 checksum hex string.
         """
-        import hashlib
-        import re
-
         canonical_content = re.sub(r"^--\s*name:\s*migrate-[^-]+-(?:up|down)\s*$", "", content, flags=re.MULTILINE)
 
         return hashlib.md5(canonical_content.encode()).hexdigest()  # noqa: S324
@@ -114,57 +232,33 @@ def load_migration(self, file_path: Path) -> Union["dict[str, Any]", "Coroutine[
         For async implementations, returns a coroutine.
         """
 
-    def _get_migration_files_sync(self) -> "list[tuple[str, Path]]":
-        """Get all migration files sorted by version.
+    def _load_migration_listing(self) -> "list[tuple[str, Path]]":
+        """Build the cached migration listing shared by sync/async runners."""
+        entries, signatures, digest = self._collect_listing_entries()
+        cached_listing = self._listing_cache
 
-        Returns:
-            List of tuples containing (version, file_path).
-        """
-
-        migrations = []
+        if cached_listing is not None and self._listing_digest == digest:
+            logger.debug("Migration listing cache hit (%d files)", len(cached_listing))
+            return cached_listing
 
-        # Scan primary migration path
-        if self.migrations_path.exists():
-            for pattern in ("*.sql", "*.py"):
-                for file_path in self.migrations_path.glob(pattern):
-                    if file_path.name.startswith("."):
-                        continue
-                    version = self._extract_version(file_path.name)
-                    if version:
-                        migrations.append((version, file_path))
+        files = self._build_sorted_listing(entries)
+        previous_digest = self._listing_digest
+        previous_signatures = self._listing_signatures
 
-        # Scan extension migration paths
-        for ext_name, ext_path in self.extension_migrations.items():
-            if ext_path.exists():
-                for pattern in ("*.sql", "*.py"):
-                    for file_path in ext_path.glob(pattern):
-                        if file_path.name.startswith("."):
-                            continue
-                        # Prefix extension migrations to avoid version conflicts
-                        version = self._extract_version(file_path.name)
-                        if version:
-                            # Use ext_ prefix to distinguish extension migrations
-                            prefixed_version = f"ext_{ext_name}_{version}"
-                            migrations.append((prefixed_version, file_path))
-
-        from sqlspec.utils.version import parse_version
+        self._listing_cache = files
+        self._listing_signatures = signatures
+        self._listing_digest = digest
 
-        def version_sort_key(migration_tuple: "tuple[str, Path]") -> "Any":
-            version_str = migration_tuple[0]
-            try:
-                return parse_version(version_str)
-            except ValueError:
-                return version_str
+        if previous_digest is None:
+            logger.debug("Primed migration listing cache with %d files", len(files))
+        else:
+            self._log_listing_invalidation(previous_signatures, signatures)
 
-        return sorted(migrations, key=version_sort_key)
+        return files
 
-    def get_migration_files(self) -> "list[tuple[str, Path]]":
-        """Get all migration files sorted by version.
-
-        Returns:
-            List of (version, path) tuples sorted by version.
-        """
-        return self._get_migration_files_sync()
+    @abstractmethod
+    def get_migration_files(self) -> "list[tuple[str, Path]] | Awaitable[list[tuple[str, Path]]]":
        """Get all migration files sorted by version."""
 
     def _load_migration_metadata_common(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
         """Load common migration metadata that doesn't require async operations.
@@ -176,7 +270,18 @@ def _load_migration_metadata_common(self, file_path: Path, version: "str | None"
         Returns:
             Partial migration metadata dictionary.
         """
-        import re
+        cache_key = str(file_path)
+        stat_result = file_path.stat()
+        cached_metadata = self._metadata_cache.get(cache_key)
+        if (
+            cached_metadata
+            and cached_metadata.mtime_ns == stat_result.st_mtime_ns
+            and cached_metadata.size == stat_result.st_size
+        ):
+            logger.debug("Migration metadata cache hit: %s", cache_key)
+            metadata = cached_metadata.clone()
+            metadata["file_path"] = file_path
+            return metadata
 
         content = file_path.read_text(encoding="utf-8")
         checksum = self._calculate_checksum(content)
@@ -191,14 +296,22 @@ def _load_migration_metadata_common(self, file_path: Path, version: "str | None"
         if transactional_match:
             transactional = transactional_match.group(1).lower() == "true"
 
-        return {
+        metadata = {
             "version": version,
             "description": description,
             "file_path": file_path,
             "checksum": checksum,
             "content": content,
             "transactional": transactional,
         }
+        self._metadata_cache[cache_key] = _CachedMigrationMetadata(
+            metadata=dict(metadata), mtime_ns=stat_result.st_mtime_ns, size=stat_result.st_size
+        )
+        if cached_metadata:
+            logger.info("Migration metadata cache invalidated: %s", cache_key)
+        else:
+            logger.debug("Cached migration metadata: %s", cache_key)
+        return metadata
 
     def _get_context_for_migration(self, file_path: Path) -> "MigrationContext | None":
         """Get the appropriate context for a migration file.
@@ -263,6 +376,14 @@ def should_use_transaction(self, migration: "dict[str, Any]", config: Any) -> bo
 class SyncMigrationRunner(BaseMigrationRunner):
     """Synchronous migration runner with pure sync methods."""
 
+    def get_migration_files(self) -> "list[tuple[str, Path]]":
+        """Get all migration files sorted by version.
+
+        Returns:
+            List of (version, path) tuples sorted by version.
+        """
+        return self._load_migration_listing()
+
     def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
         """Load a migration file and extract its components.
 
@@ -287,7 +408,7 @@ def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict
             has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
         else:
             try:
-                has_downgrade = bool(self._get_migration_sql_sync({"loader": loader, "file_path": file_path}, "down"))
+                has_downgrade = bool(self._get_migration_sql({"loader": loader, "file_path": file_path}, "down"))
             except Exception:
                 has_downgrade = False
 
@@ -313,7 +434,7 @@ def execute_upgrade(
         Returns:
             Tuple of (sql_content, execution_time_ms).
         """
-        upgrade_sql_list = self._get_migration_sql_sync(migration, "up")
+        upgrade_sql_list = self._get_migration_sql(migration, "up")
         if upgrade_sql_list is None:
             return None, 0
 
@@ -365,7 +486,7 @@ def execute_downgrade(
         Returns:
             Tuple of (sql_content, execution_time_ms).
         """
-        downgrade_sql_list = self._get_migration_sql_sync(migration, "down")
+        downgrade_sql_list = self._get_migration_sql(migration, "down")
         if downgrade_sql_list is None:
             return None, 0
 
@@ -398,7 +519,7 @@ def execute_downgrade(
 
         return None, execution_time
 
-    def _get_migration_sql_sync(self, migration: "dict[str, Any]", direction: str) -> "list[str] | None":
+    def _get_migration_sql(self, migration: "dict[str, Any]", direction: str) -> "list[str] | None":
         """Get migration SQL for given direction (sync version).
 
         Args:
@@ -475,13 +596,13 @@ def load_all_migrations(self) -> "dict[str, SQL]":
 class AsyncMigrationRunner(BaseMigrationRunner):
     """Asynchronous migration runner with pure async methods."""
 
-    async def get_migration_files(self) -> "list[tuple[str, Path]]":  # type: ignore[override]
+    async def get_migration_files(self) -> "list[tuple[str, Path]]":
         """Get all migration files sorted by version.
 
         Returns:
             List of (version, path) tuples sorted by version.
         """
-        return self._get_migration_files_sync()
+        return await async_(self._load_migration_listing)()
 
     async def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
         """Load a migration file and extract its components.
