 of concerns and proper type safety.
 """

+import hashlib
 import inspect
+import re
 import time
 from abc import ABC, abstractmethod
 from pathlib import Path
 from sqlspec.migrations.loaders import get_migration_loader
 from sqlspec.utils.logging import get_logger
 from sqlspec.utils.sync_tools import async_, await_
+from sqlspec.utils.version import parse_version

 if TYPE_CHECKING:
     from collections.abc import Awaitable, Callable, Coroutine

 logger = get_logger("migrations.runner")

+class _CachedMigrationMetadata:
+    """Cached migration metadata keyed by file path."""
+
+    __slots__ = ("metadata", "mtime_ns", "size")
+
+    def __init__(self, metadata: "dict[str, Any]", mtime_ns: int, size: int) -> None:
+        self.metadata = metadata
+        self.mtime_ns = mtime_ns
+        self.size = size
+
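+    # clone() hands out a shallow copy so callers can adjust per-call fields
+    # (e.g. file_path) without mutating the cached dict.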
+    def clone(self) -> "dict[str, Any]":
+        return dict(self.metadata)
+
+
+class _MigrationFileEntry:
+    """Represents a migration file discovered during directory scanning."""
+
+    __slots__ = ("extension_name", "path")
+
+    def __init__(self, path: Path, extension_name: "str | None") -> None:
+        self.path = path
+        self.extension_name = extension_name
+
+
 class BaseMigrationRunner(ABC):
     """Base migration runner with common functionality shared between sync and async implementations."""

@@ -52,6 +79,100 @@ def __init__(
         self.project_root: Path | None = None
         self.context = context
         self.extension_configs = extension_configs or {}
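+        # Two-level cache: a digest-validated file listing plus per-file metadata
+        # keyed by (mtime_ns, size), so unchanged files are never re-read.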
+        self._listing_digest: str | None = None
+        self._listing_cache: list[tuple[str, Path]] | None = None
+        self._listing_signatures: dict[str, tuple[int, int]] = {}
+        self._metadata_cache: dict[str, _CachedMigrationMetadata] = {}
+
+    def _iter_directory_entries(self, base_path: Path, extension_name: "str | None") -> "list[_MigrationFileEntry]":
+        """Collect migration files discovered under a base path."""
+
+        if not base_path.exists():
+            return []
+
+        entries: list[_MigrationFileEntry] = []
+        for pattern in ("*.sql", "*.py"):
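+            # sorted() keeps discovery order deterministic; raw glob order is filesystem-dependent.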
+            for file_path in sorted(base_path.glob(pattern)):
+                if file_path.name.startswith("."):
+                    continue
+                entries.append(_MigrationFileEntry(path=file_path, extension_name=extension_name))
+        return entries
+
+    def _collect_listing_entries(self) -> "tuple[list[_MigrationFileEntry], dict[str, tuple[int, int]], str]":
+        """Gather migration files, stat signatures, and digest for cache validation."""
+
+        entries: list[_MigrationFileEntry] = []
+        signatures: dict[str, tuple[int, int]] = {}
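+        # md5 serves only as a fast change fingerprint here, not as a security hash.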
+        digest_source = hashlib.md5(usedforsecurity=False)
+
+        for entry in self._iter_directory_entries(self.migrations_path, None):
+            self._record_entry(entry, entries, signatures, digest_source)
+
+        for ext_name, ext_path in self.extension_migrations.items():
+            for entry in self._iter_directory_entries(ext_path, ext_name):
+                self._record_entry(entry, entries, signatures, digest_source)
+
+        return entries, signatures, digest_source.hexdigest()
+
+    def _record_entry(
+        self,
+        entry: _MigrationFileEntry,
+        entries: "list[_MigrationFileEntry]",
+        signatures: "dict[str, tuple[int, int]]",
+        digest_source: Any,
+    ) -> None:
+        """Record entry metadata for cache decisions."""
+
+        try:
+            stat_result = entry.path.stat()
+        except FileNotFoundError:
+            return
+
+        path_str = str(entry.path)
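+        # (st_mtime_ns, st_size) is the freshness token: changing either field
+        # alters the digest built below and invalidates the listing cache.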
+        token = (stat_result.st_mtime_ns, stat_result.st_size)
+        signatures[path_str] = token
+        digest_source.update(path_str.encode("utf-8"))
+        digest_source.update(f"{token[0]}:{token[1]}".encode())
+        entries.append(entry)
+
+    def _build_sorted_listing(self, entries: "list[_MigrationFileEntry]") -> "list[tuple[str, Path]]":
+        """Construct sorted migration listing from directory entries."""
+
+        migrations: list[tuple[str, Path]] = []
+
+        for entry in entries:
+            version = self._extract_version(entry.path.name)
+            if not version:
+                continue
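+            # Prefix extension migrations (ext_<name>_<version>) so their versions
+            # cannot collide with core migration versions.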
+            if entry.extension_name:
+                version = f"ext_{entry.extension_name}_{version}"
+            migrations.append((version, entry.path))
+
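+        # Versions that parse_version cannot parse fall back to plain string ordering.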
+        def version_sort_key(migration_tuple: "tuple[str, Path]") -> "Any":
+            version_str = migration_tuple[0]
+            try:
+                return parse_version(version_str)
+            except ValueError:
+                return version_str
+
+        return sorted(migrations, key=version_sort_key)
+
+    def _log_listing_invalidation(
+        self, previous: "dict[str, tuple[int, int]]", current: "dict[str, tuple[int, int]]"
+    ) -> None:
+        """Log cache invalidation details at INFO level."""
+
+        prev_keys = set(previous)
+        curr_keys = set(current)
+        added = curr_keys - prev_keys
+        removed = prev_keys - curr_keys
+        modified = {key for key in prev_keys & curr_keys if previous[key] != current[key]}
+        logger.info(
+            "Migration listing cache invalidated (added=%d, removed=%d, modified=%d)",
+            len(added),
+            len(removed),
+            len(modified),
+        )

     def _extract_version(self, filename: str) -> "str | None":
         """Extract version from filename.
@@ -95,9 +216,6 @@ def _calculate_checksum(self, content: str) -> str:
         Returns:
             MD5 checksum hex string.
         """
-        import hashlib
-        import re
-
         canonical_content = re.sub(r"^--\s*name:\s*migrate-[^-]+-(?:up|down)\s*$", "", content, flags=re.MULTILINE)

         return hashlib.md5(canonical_content.encode()).hexdigest()  # noqa: S324
@@ -114,57 +232,33 @@ def load_migration(self, file_path: Path) -> Union["dict[str, Any]", "Coroutine[
         For async implementations, returns a coroutine.
         """

-    def _get_migration_files_sync(self) -> "list[tuple[str, Path]]":
-        """Get all migration files sorted by version.
+    def _load_migration_listing(self) -> "list[tuple[str, Path]]":
+        """Build the cached migration listing shared by sync/async runners."""
+        entries, signatures, digest = self._collect_listing_entries()
+        cached_listing = self._listing_cache
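+        # Equal digests mean no migration file was added, removed, or modified,
+        # so the previously sorted listing is still valid.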

-        Returns:
-            List of tuples containing (version, file_path).
-        """
-
-        migrations = []
+        if cached_listing is not None and self._listing_digest == digest:
+            logger.debug("Migration listing cache hit (%d files)", len(cached_listing))
+            return cached_listing

-        # Scan primary migration path
-        if self.migrations_path.exists():
-            for pattern in ("*.sql", "*.py"):
-                for file_path in self.migrations_path.glob(pattern):
-                    if file_path.name.startswith("."):
-                        continue
-                    version = self._extract_version(file_path.name)
-                    if version:
-                        migrations.append((version, file_path))
+        files = self._build_sorted_listing(entries)
+        previous_digest = self._listing_digest
+        previous_signatures = self._listing_signatures

-        # Scan extension migration paths
-        for ext_name, ext_path in self.extension_migrations.items():
-            if ext_path.exists():
-                for pattern in ("*.sql", "*.py"):
-                    for file_path in ext_path.glob(pattern):
-                        if file_path.name.startswith("."):
-                            continue
-                        # Prefix extension migrations to avoid version conflicts
-                        version = self._extract_version(file_path.name)
-                        if version:
-                            # Use ext_ prefix to distinguish extension migrations
-                            prefixed_version = f"ext_{ext_name}_{version}"
-                            migrations.append((prefixed_version, file_path))
-
-        from sqlspec.utils.version import parse_version
+        self._listing_cache = files
+        self._listing_signatures = signatures
+        self._listing_digest = digest

-        def version_sort_key(migration_tuple: "tuple[str, Path]") -> "Any":
-            version_str = migration_tuple[0]
-            try:
-                return parse_version(version_str)
-            except ValueError:
-                return version_str
+        if previous_digest is None:
+            logger.debug("Primed migration listing cache with %d files", len(files))
+        else:
+            self._log_listing_invalidation(previous_signatures, signatures)

-        return sorted(migrations, key=version_sort_key)
+        return files

-    def get_migration_files(self) -> "list[tuple[str, Path]]":
-        """Get all migration files sorted by version.
-
-        Returns:
-            List of (version, path) tuples sorted by version.
-        """
-        return self._get_migration_files_sync()
+    @abstractmethod
+    def get_migration_files(self) -> "list[tuple[str, Path]] | Awaitable[list[tuple[str, Path]]]":
+        """Get all migration files sorted by version."""

     def _load_migration_metadata_common(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
         """Load common migration metadata that doesn't require async operations.
@@ -176,7 +270,18 @@ def _load_migration_metadata_common(self, file_path: Path, version: "str | None"
         Returns:
             Partial migration metadata dictionary.
         """
-        import re
+        cache_key = str(file_path)
+        stat_result = file_path.stat()
+        cached_metadata = self._metadata_cache.get(cache_key)
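+        # Serve from cache only while both mtime_ns and size match the file on disk.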
+        if (
+            cached_metadata
+            and cached_metadata.mtime_ns == stat_result.st_mtime_ns
+            and cached_metadata.size == stat_result.st_size
+        ):
+            logger.debug("Migration metadata cache hit: %s", cache_key)
+            metadata = cached_metadata.clone()
+            metadata["file_path"] = file_path
+            return metadata

         content = file_path.read_text(encoding="utf-8")
         checksum = self._calculate_checksum(content)
@@ -191,14 +296,22 @@ def _load_migration_metadata_common(self, file_path: Path, version: "str | None"
         if transactional_match:
             transactional = transactional_match.group(1).lower() == "true"

-        return {
+        metadata = {
             "version": version,
             "description": description,
             "file_path": file_path,
             "checksum": checksum,
             "content": content,
             "transactional": transactional,
         }
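+        # Store a copy of the dict; callers may mutate the returned object.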
+        self._metadata_cache[cache_key] = _CachedMigrationMetadata(
+            metadata=dict(metadata), mtime_ns=stat_result.st_mtime_ns, size=stat_result.st_size
+        )
+        if cached_metadata:
+            logger.info("Migration metadata cache invalidated: %s", cache_key)
+        else:
+            logger.debug("Cached migration metadata: %s", cache_key)
+        return metadata

     def _get_context_for_migration(self, file_path: Path) -> "MigrationContext | None":
         """Get the appropriate context for a migration file.
@@ -263,6 +376,14 @@ def should_use_transaction(self, migration: "dict[str, Any]", config: Any) -> bo
 class SyncMigrationRunner(BaseMigrationRunner):
     """Synchronous migration runner with pure sync methods."""

+    def get_migration_files(self) -> "list[tuple[str, Path]]":
+        """Get all migration files sorted by version.
+
+        Returns:
+            List of (version, path) tuples sorted by version.
+        """
+        return self._load_migration_listing()
+
     def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
         """Load a migration file and extract its components.

@@ -287,7 +408,7 @@ def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict
             has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
         else:
             try:
-                has_downgrade = bool(self._get_migration_sql_sync({"loader": loader, "file_path": file_path}, "down"))
+                has_downgrade = bool(self._get_migration_sql({"loader": loader, "file_path": file_path}, "down"))
             except Exception:
                 has_downgrade = False

@@ -313,7 +434,7 @@ def execute_upgrade(
         Returns:
             Tuple of (sql_content, execution_time_ms).
         """
-        upgrade_sql_list = self._get_migration_sql_sync(migration, "up")
+        upgrade_sql_list = self._get_migration_sql(migration, "up")
         if upgrade_sql_list is None:
             return None, 0

@@ -365,7 +486,7 @@ def execute_downgrade(
         Returns:
             Tuple of (sql_content, execution_time_ms).
         """
-        downgrade_sql_list = self._get_migration_sql_sync(migration, "down")
+        downgrade_sql_list = self._get_migration_sql(migration, "down")
         if downgrade_sql_list is None:
             return None, 0

@@ -398,7 +519,7 @@ def execute_downgrade(

         return None, execution_time

-    def _get_migration_sql_sync(self, migration: "dict[str, Any]", direction: str) -> "list[str] | None":
+    def _get_migration_sql(self, migration: "dict[str, Any]", direction: str) -> "list[str] | None":
         """Get migration SQL for given direction (sync version).

         Args:
@@ -475,13 +596,13 @@ def load_all_migrations(self) -> "dict[str, SQL]":
 class AsyncMigrationRunner(BaseMigrationRunner):
     """Asynchronous migration runner with pure async methods."""

-    async def get_migration_files(self) -> "list[tuple[str, Path]]":  # type: ignore[override]
+    async def get_migration_files(self) -> "list[tuple[str, Path]]":
         """Get all migration files sorted by version.

         Returns:
             List of (version, path) tuples sorted by version.
         """
-        return self._get_migration_files_sync()
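+        # async_ wraps the shared sync listing helper so both runners reuse one implementation.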
+        return await async_(self._load_migration_listing)()

     async def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
         """Load a migration file and extract its components.