Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,8 @@ exclude = ["**/node_modules", "**/__pycache__", ".venv", "tools", "docs", "tmp",
include = ["sqlspec", "tests"]
pythonVersion = "3.9"
reportMissingTypeStubs = false
reportPrivateImportUsage = false
reportPrivateUsage = false
reportPrivateImportUsage = true
reportPrivateUsage = true
reportTypedDictNotRequiredAccess = false
reportUnknownArgumentType = false
reportUnnecessaryCast = false
Expand Down
26 changes: 11 additions & 15 deletions sqlspec/_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __call__(self, statement: str, dialect: DialectType = None) -> "Any":
actual_type_str == "WITH" and parsed_expr.this and isinstance(parsed_expr.this, exp.Select)
):
builder = Select(dialect=dialect or self.dialect)
builder._expression = parsed_expr
builder.set_expression(parsed_expr)
return builder

if actual_type_str in {"INSERT", "UPDATE", "DELETE"} and parsed_expr.args.get("returning") is not None:
Expand Down Expand Up @@ -451,7 +451,7 @@ def _populate_insert_from_sql(self, builder: "Insert", sql_string: str) -> "Inse
parsed_expr: exp.Expression = exp.maybe_parse(sql_string, dialect=self.dialect)

if isinstance(parsed_expr, exp.Insert):
builder._expression = parsed_expr
builder.set_expression(parsed_expr)
return builder

if isinstance(parsed_expr, exp.Select):
Expand All @@ -470,7 +470,7 @@ def _populate_select_from_sql(self, builder: "Select", sql_string: str) -> "Sele
parsed_expr: exp.Expression = exp.maybe_parse(sql_string, dialect=self.dialect)

if isinstance(parsed_expr, exp.Select):
builder._expression = parsed_expr
builder.set_expression(parsed_expr)
return builder

logger.warning("Cannot create SELECT from %s statement", type(parsed_expr).__name__)
Expand All @@ -485,7 +485,7 @@ def _populate_update_from_sql(self, builder: "Update", sql_string: str) -> "Upda
parsed_expr: exp.Expression = exp.maybe_parse(sql_string, dialect=self.dialect)

if isinstance(parsed_expr, exp.Update):
builder._expression = parsed_expr
builder.set_expression(parsed_expr)
return builder

logger.warning("Cannot create UPDATE from %s statement", type(parsed_expr).__name__)
Expand All @@ -500,7 +500,7 @@ def _populate_delete_from_sql(self, builder: "Delete", sql_string: str) -> "Dele
parsed_expr: exp.Expression = exp.maybe_parse(sql_string, dialect=self.dialect)

if isinstance(parsed_expr, exp.Delete):
builder._expression = parsed_expr
builder.set_expression(parsed_expr)
return builder

logger.warning("Cannot create DELETE from %s statement", type(parsed_expr).__name__)
Expand All @@ -515,7 +515,7 @@ def _populate_merge_from_sql(self, builder: "Merge", sql_string: str) -> "Merge"
parsed_expr: exp.Expression = exp.maybe_parse(sql_string, dialect=self.dialect)

if isinstance(parsed_expr, exp.Merge):
builder._expression = parsed_expr
builder.set_expression(parsed_expr)
return builder

logger.warning("Cannot create MERGE from %s statement", type(parsed_expr).__name__)
Expand Down Expand Up @@ -724,19 +724,15 @@ def raw(sql_fragment: str, **parameters: Any) -> "Union[exp.Expression, SQL]":
if not parameters:
try:
parsed: exp.Expression = exp.maybe_parse(sql_fragment)
return parsed
if sql_fragment.strip().replace("_", "").replace(".", "").isalnum():
return exp.to_identifier(sql_fragment)
return exp.Literal.string(sql_fragment)
except Exception as e:
msg = f"Failed to parse raw SQL fragment '{sql_fragment}': {e}"
raise SQLBuilderError(msg) from e
return parsed

return SQL(sql_fragment, parameters)

@staticmethod
def count(
column: Union[str, exp.Expression, "ExpressionWrapper", "Case", "Column"] = "*", distinct: bool = False
self, column: Union[str, exp.Expression, "ExpressionWrapper", "Case", "Column"] = "*", distinct: bool = False
) -> AggregateExpression:
"""Create a COUNT expression.

Expand All @@ -750,7 +746,7 @@ def count(
if isinstance(column, str) and column == "*":
expr = exp.Count(this=exp.Star(), distinct=distinct)
else:
col_expr = SQLFactory._extract_expression(column)
col_expr = self._extract_expression(column)
expr = exp.Count(this=col_expr, distinct=distinct)
return AggregateExpression(expr)

Expand Down Expand Up @@ -1068,11 +1064,11 @@ def _extract_expression(value: Any) -> exp.Expression:
if isinstance(value, str):
return exp.column(value)
if isinstance(value, Column):
return value._expression
return value.sqlglot_expression
if isinstance(value, ExpressionWrapper):
return value.expression
if isinstance(value, Case):
return exp.Case(ifs=value._conditions, default=value._default)
return exp.Case(ifs=value.conditions, default=value.default)
if isinstance(value, exp.Expression):
return value
return exp.convert(value)
Expand Down
4 changes: 2 additions & 2 deletions sqlspec/adapters/adbc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def _execute_many(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult":

try:
if not prepared_parameters:
cursor._rowcount = 0
cursor._rowcount = 0 # pyright: ignore[reportPrivateUsage]
row_count = 0
elif isinstance(prepared_parameters, list) and prepared_parameters:
processed_params = []
Expand Down Expand Up @@ -596,7 +596,7 @@ def _execute_script(self, cursor: "Cursor", statement: "SQL") -> "ExecutionResul
Execution result with statement counts
"""
if statement.is_script:
sql = statement._raw_sql
sql = statement.raw_sql
prepared_parameters: list[Any] = []
else:
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
Expand Down
5 changes: 5 additions & 0 deletions sqlspec/adapters/oracledb/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ def _execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
msg = "execute_many requires parameters"
raise ValueError(msg)

# Oracle-specific fix: Ensure parameters are in list format for executemany
# Oracle expects a list of sequences, not a tuple of sequences
if isinstance(prepared_parameters, tuple):
prepared_parameters = list(prepared_parameters)

cursor.executemany(sql, prepared_parameters)

# Calculate affected rows based on parameter count
Expand Down
6 changes: 2 additions & 4 deletions sqlspec/adapters/psycopg/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,7 @@ def _close_pool(self) -> None:
logger.info("Closing Psycopg connection pool", extra={"adapter": "psycopg"})

try:
if hasattr(self.pool_instance, "_closed"):
self.pool_instance._closed = True
self.pool_instance._closed = True # pyright: ignore[reportPrivateUsage]

self.pool_instance.close()
logger.info("Psycopg connection pool closed successfully", extra={"adapter": "psycopg"})
Expand Down Expand Up @@ -350,8 +349,7 @@ async def _close_pool(self) -> None:
return

try:
if hasattr(self.pool_instance, "_closed"):
self.pool_instance._closed = True
self.pool_instance._closed = True # pyright: ignore[reportPrivateUsage]

await self.pool_instance.close()
finally:
Expand Down
7 changes: 3 additions & 4 deletions sqlspec/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
)
from sqlspec.core.cache import (
CacheConfig,
CacheStatsAggregate,
get_cache_config,
get_cache_stats,
get_cache_statistics,
log_cache_stats,
reset_cache_stats,
update_cache_config,
Expand Down Expand Up @@ -532,13 +531,13 @@ def update_cache_config(config: CacheConfig) -> None:
update_cache_config(config)

@staticmethod
def get_cache_stats() -> CacheStatsAggregate:
def get_cache_stats() -> "dict[str, Any]":
"""Get current cache statistics.

Returns:
Cache statistics object with detailed metrics.
"""
return get_cache_stats()
return get_cache_statistics()

@staticmethod
def reset_cache_stats() -> None:
Expand Down
68 changes: 55 additions & 13 deletions sqlspec/builder/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from sqlglot.optimizer import optimize
from typing_extensions import Self

from sqlspec.core.cache import CacheKey, get_cache_config, get_default_cache
from sqlspec.core.cache import get_cache, get_cache_config
from sqlspec.core.hashing import hash_optimized_expression
from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig
from sqlspec.core.statement import SQL, StatementConfig
Expand Down Expand Up @@ -91,6 +91,36 @@ def _initialize_expression(self) -> None:
"QueryBuilder._create_base_expression must return a valid sqlglot expression."
)

def get_expression(self) -> Optional[exp.Expression]:
"""Get expression reference (no copy).

Returns:
The current SQLGlot expression or None if not set
"""
return self._expression

def set_expression(self, expression: exp.Expression) -> None:
"""Set expression with validation.

Args:
expression: SQLGlot expression to set

Raises:
TypeError: If expression is not a SQLGlot Expression
"""
if not isinstance(expression, exp.Expression):
msg = f"Expected Expression, got {type(expression)}"
raise TypeError(msg)
self._expression = expression

def has_expression(self) -> bool:
"""Check if expression exists.

Returns:
True if expression is set, False otherwise
"""
return self._expression is not None

@abstractmethod
def _create_base_expression(self) -> exp.Expression:
"""Create the base sqlglot expression for the specific query type.
Expand Down Expand Up @@ -307,12 +337,13 @@ def with_cte(self: Self, alias: str, query: "Union[QueryBuilder, exp.Select, str
cte_select_expression: exp.Select

if isinstance(query, QueryBuilder):
if query._expression is None:
query_expr = query.get_expression()
if query_expr is None:
self._raise_sql_builder_error("CTE query builder has no expression.")
if not isinstance(query._expression, exp.Select):
msg = f"CTE query builder expression must be a Select, got {type(query._expression).__name__}."
if not isinstance(query_expr, exp.Select):
msg = f"CTE query builder expression must be a Select, got {type(query_expr).__name__}."
self._raise_sql_builder_error(msg)
cte_select_expression = query._expression
cte_select_expression = query_expr
param_mapping = self._merge_cte_parameters(alias, query.parameters)
updated_expression = self._update_placeholders_in_expression(cte_select_expression, param_mapping)
if not isinstance(updated_expression, exp.Select):
Expand Down Expand Up @@ -398,9 +429,8 @@ def _optimize_expression(self, expression: exp.Expression) -> exp.Expression:
expression, dialect=dialect_name, schema=self.schema, optimizer_settings=optimizer_settings
)

cache_key_obj = CacheKey((cache_key,))
unified_cache = get_default_cache()
cached_optimized = unified_cache.get(cache_key_obj)
cache = get_cache()
cached_optimized = cache.get("optimized", cache_key)
if cached_optimized:
return cast("exp.Expression", cached_optimized)

Expand All @@ -409,7 +439,7 @@ def _optimize_expression(self, expression: exp.Expression) -> exp.Expression:
expression, schema=self.schema, dialect=self.dialect_name, optimizer_settings=optimizer_settings
)

unified_cache.put(cache_key_obj, optimized)
cache.put("optimized", cache_key, optimized)

except Exception:
return expression
Expand All @@ -430,15 +460,14 @@ def to_statement(self, config: "Optional[StatementConfig]" = None) -> "SQL":
return self._to_statement(config)

cache_key_str = self._generate_builder_cache_key(config)
cache_key = CacheKey((cache_key_str,))

unified_cache = get_default_cache()
cached_sql = unified_cache.get(cache_key)
cache = get_cache()
cached_sql = cache.get("builder", cache_key_str)
if cached_sql is not None:
return cast("SQL", cached_sql)

sql_statement = self._to_statement(config)
unified_cache.put(cache_key, sql_statement)
cache.put("builder", cache_key_str, sql_statement)

return sql_statement

Expand Down Expand Up @@ -531,3 +560,16 @@ def _merge_sql_object_parameters(self, sql_obj: Any) -> None:
def parameters(self) -> dict[str, Any]:
"""Public access to query parameters."""
return self._parameters

def set_parameters(self, parameters: dict[str, Any]) -> None:
"""Set query parameters (public API)."""
self._parameters = parameters.copy()

@property
def with_ctes(self) -> "dict[str, exp.CTE]":
"""Get WITH clause CTEs (public API)."""
return dict(self._with_ctes)

def generate_unique_parameter_name(self, base_name: str) -> str:
"""Generate unique parameter name (public API)."""
return self._generate_unique_parameter_name(base_name)
9 changes: 9 additions & 0 deletions sqlspec/builder/_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,15 @@ def __hash__(self) -> int:
"""Hash based on table and column name."""
return hash((self.table, self.name))

@property
def sqlglot_expression(self) -> exp.Expression:
"""Get the underlying SQLGlot expression (public API).

Returns:
The SQLGlot expression for this column
"""
return self._expression


class FunctionColumn:
"""Represents the result of a SQL function call on a column."""
Expand Down
14 changes: 7 additions & 7 deletions sqlspec/builder/_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,10 +973,10 @@ def _create_base_expression(self) -> exp.Expression:
select_expr = self._select_query.expression
select_parameters = self._select_query.parameters
elif isinstance(self._select_query, Select):
select_expr = self._select_query._expression
select_parameters = self._select_query._parameters
select_expr = self._select_query.get_expression()
select_parameters = self._select_query.parameters

with_ctes = self._select_query._with_ctes
with_ctes = self._select_query.with_ctes
if with_ctes and select_expr and isinstance(select_expr, exp.Select):
for alias, cte in with_ctes.items():
if has_with_method(select_expr):
Expand Down Expand Up @@ -1100,8 +1100,8 @@ def _create_base_expression(self) -> exp.Expression:
select_expr = self._select_query.expression
select_parameters = self._select_query.parameters
elif isinstance(self._select_query, Select):
select_expr = self._select_query._expression
select_parameters = self._select_query._parameters
select_expr = self._select_query.get_expression()
select_parameters = self._select_query.parameters
elif isinstance(self._select_query, str):
select_expr = exp.maybe_parse(self._select_query)
select_parameters = None
Expand Down Expand Up @@ -1198,8 +1198,8 @@ def _create_base_expression(self) -> exp.Expression:
select_expr = self._select_query.expression
select_parameters = self._select_query.parameters
elif isinstance(self._select_query, Select):
select_expr = self._select_query._expression
select_parameters = self._select_query._parameters
select_expr = self._select_query.get_expression()
select_parameters = self._select_query.parameters
elif isinstance(self._select_query, str):
select_expr = exp.maybe_parse(self._select_query)
select_parameters = None
Expand Down
Loading