Skip to content

Commit bb6fa6e

Browse files
committed
feat: cache corrections
1 parent 7bf399b commit bb6fa6e

File tree

3 files changed

+121
-1
lines changed

3 files changed

+121
-1
lines changed

sqlspec/driver/_common.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
from sqlspec.parameters.types import TypedParameter
1111
from sqlspec.statement import SQLResult, Statement, StatementFilter
1212
from sqlspec.statement.builder import QueryBuilder
13+
from sqlspec.statement.cache import get_cache_config, sql_cache
1314
from sqlspec.statement.pipeline import SQLTransformContext, create_pipeline_from_config
1415
from sqlspec.statement.result import OperationType
1516
from sqlspec.statement.splitter import split_sql_script
1617
from sqlspec.statement.sql import SQL, StatementConfig
1718
from sqlspec.utils.logging import get_logger
19+
from sqlspec.utils.statement_hashing import hash_sql_statement
1820

1921
if TYPE_CHECKING:
2022
from sqlspec.typing import StatementParameters
@@ -469,7 +471,7 @@ def _apply_pipeline_transformations(
469471
def _get_compiled_sql(
470472
self, statement: "SQL", statement_config: "StatementConfig", flatten_single_params: bool = False
471473
) -> tuple[str, Any]:
472-
"""Get compiled SQL with optimal parameter style (only converts when needed).
474+
"""Get compiled SQL with optimal parameter style and caching support.
473475
474476
Args:
475477
statement: SQL statement to compile
@@ -479,6 +481,26 @@ def _get_compiled_sql(
479481
Returns:
480482
Tuple of (compiled_sql, parameters)
481483
"""
484+
485+
# Check if caching is enabled
486+
cache_config = get_cache_config()
487+
cache_key = None
488+
if cache_config.compiled_cache_enabled and statement_config.enable_caching:
489+
# Generate cache key that includes parameter style context
490+
cache_key = self._generate_compilation_cache_key(statement, statement_config, flatten_single_params)
491+
492+
# Try cache first
493+
cached_result = sql_cache.get(cache_key)
494+
if cached_result is not None:
495+
sql, params = cached_result
496+
# Apply driver parameter preparation to cached params
497+
prepared_params = self.prepare_driver_parameters(params, statement_config, is_many=statement.is_many)
498+
# Apply output_transformer if configured
499+
if statement_config.parameter_config.output_transformer:
500+
sql, prepared_params = statement_config.parameter_config.output_transformer(sql, prepared_params)
501+
return sql, prepared_params
502+
503+
# Determine target parameter style (existing logic)
482504
if statement.is_script and not statement_config.parameter_config.needs_static_script_compilation:
483505
target_style = ParameterStyle.STATIC
484506
elif statement_config.parameter_config.supported_execution_parameter_styles is not None:
@@ -496,7 +518,15 @@ def _get_compiled_sql(
496518
else:
497519
# No execution style configuration, use default parameter style for explicit compilation
498520
target_style = statement_config.parameter_config.default_parameter_style
521+
522+
# Compile the SQL
499523
sql, params = statement.compile(placeholder_style=target_style, flatten_single_params=flatten_single_params)
524+
525+
# Cache the compilation result if caching is enabled (before driver preparation)
526+
if cache_key is not None:
527+
sql_cache.set(cache_key, (sql, params))
528+
529+
# Prepare parameters for driver
500530
prepared_params = self.prepare_driver_parameters(params, statement_config, is_many=statement.is_many)
501531

502532
# Apply output_transformer if configured
@@ -505,6 +535,32 @@ def _get_compiled_sql(
505535

506536
return sql, prepared_params
507537

538+
def _generate_compilation_cache_key(
539+
self, statement: "SQL", config: "StatementConfig", flatten_single_params: bool
540+
) -> str:
541+
"""Generate cache key that includes all compilation context.
542+
543+
This method creates a deterministic cache key that includes all factors
544+
that affect SQL compilation, preventing cache contamination between
545+
different compilation contexts.
546+
"""
547+
548+
# Include all factors that affect compilation
549+
context_hash = hash(
550+
(
551+
config.parameter_config.hash(),
552+
config.dialect,
553+
statement.is_script,
554+
statement.is_many,
555+
flatten_single_params,
556+
bool(config.parameter_config.output_transformer),
557+
bool(config.parameter_config.needs_static_script_compilation),
558+
)
559+
)
560+
561+
base_hash = hash_sql_statement(statement)
562+
return f"compiled:{base_hash}:{context_hash}"
563+
508564
def _create_count_query(self, original_sql: "SQL") -> "SQL":
509565
"""Create a COUNT query from the original SQL statement.
510566

sqlspec/parameters/config.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,26 @@ def __init__(
4848
self.output_transformer = output_transformer
4949
self.needs_static_script_compilation = needs_static_script_compilation
5050
self.allow_mixed_parameter_styles = allow_mixed_parameter_styles
51+
52+
def hash(self) -> int:
53+
"""Generate hash for cache key generation.
54+
55+
This method creates a deterministic hash of the parameter configuration
56+
for use in cache keys, ensuring different parameter configurations
57+
don't share cache entries.
58+
"""
59+
# Create tuple of all configuration values that affect compilation
60+
config_tuple = (
61+
self.default_parameter_style.value if self.default_parameter_style else None,
62+
tuple(sorted(s.value for s in self.supported_parameter_styles)) if self.supported_parameter_styles else (),
63+
tuple(sorted(s.value for s in self.supported_execution_parameter_styles))
64+
if self.supported_execution_parameter_styles
65+
else (),
66+
self.default_execution_parameter_style.value if self.default_execution_parameter_style else None,
67+
self.has_native_list_expansion,
68+
bool(self.output_transformer), # Don't hash the function object itself
69+
self.needs_static_script_compilation,
70+
self.allow_mixed_parameter_styles,
71+
tuple(sorted(str(k) for k in self.type_coercion_map)) if self.type_coercion_map else (),
72+
)
73+
return hash(config_tuple)

sqlspec/statement/sql.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from sqlspec.typing import Empty
3434
from sqlspec.utils.logging import get_logger
35+
from sqlspec.utils.statement_hashing import hash_sql_statement
3536
from sqlspec.utils.type_guards import (
3637
can_append_to_statement,
3738
can_extract_parameters,
@@ -1690,6 +1691,46 @@ def _convert_named_pyformat_parameters(self, params: dict[str, Any], norm_state:
16901691

16911692
return params
16921693

1694+
def generate_cache_key_with_config(self, config: "Optional[StatementConfig]" = None) -> str:
1695+
"""Generate cache key that includes StatementConfig context.
1696+
1697+
This method creates a deterministic cache key that includes both the SQL content
1698+
and the StatementConfig settings to prevent cross-contamination between different
1699+
configurations.
1700+
1701+
Args:
1702+
config: Optional StatementConfig to use for key generation.
1703+
Uses self.statement_config if not provided.
1704+
1705+
Returns:
1706+
String cache key that includes both SQL and configuration context
1707+
"""
1708+
1709+
effective_config = config or self.statement_config
1710+
1711+
# Create hash of configuration values that affect processing
1712+
config_hash = hash(
1713+
(
1714+
effective_config.enable_parsing,
1715+
effective_config.enable_validation,
1716+
effective_config.enable_transformations,
1717+
effective_config.enable_analysis,
1718+
effective_config.enable_expression_simplification,
1719+
effective_config.enable_parameter_type_wrapping,
1720+
effective_config.enable_caching,
1721+
effective_config.dialect,
1722+
effective_config.parameter_config.hash(),
1723+
tuple(effective_config.pre_process_steps) if effective_config.pre_process_steps else (),
1724+
tuple(effective_config.post_process_steps) if effective_config.post_process_steps else (),
1725+
)
1726+
)
1727+
1728+
# Include filter context in the cache key
1729+
filter_hash = hash(tuple(str(f) for f in self._filters)) if self._filters else 0
1730+
1731+
base_hash = hash_sql_statement(self)
1732+
return f"sql:{base_hash}:{config_hash}:{filter_hash}"
1733+
16931734
def _apply_placeholder_style(self, sql: "str", params: Any, placeholder_style: "str") -> "tuple[str, Any]":
16941735
"""Apply placeholder style conversion using ParameterConverter."""
16951736
target_style = ParameterStyle(placeholder_style) if isinstance(placeholder_style, str) else placeholder_style

0 commit comments

Comments
 (0)