Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,62 @@ config = AsyncpgConfig(
)
```

### Custom SQLGlot Expression Pattern

For dialect-specific SQL generation (e.g., vector distance functions):

```python
# In sqlspec/builder/_custom_expressions.py
from sqlglot import exp
from typing import Any

class CustomExpression(exp.Expression):
"""Custom expression with dialect-aware SQL generation."""
arg_types = {"this": True, "expression": True, "metric": False}

def sql(self, dialect: "Any | None" = None, **opts: Any) -> str:
"""Override sql() method for dialect-specific generation."""
dialect_name = str(dialect).lower() if dialect else "generic"

left_sql = self.left.sql(dialect=dialect, **opts)
right_sql = self.right.sql(dialect=dialect, **opts)

if dialect_name == "postgres":
return self._sql_postgres(left_sql, right_sql)
if dialect_name == "mysql":
return self._sql_mysql(left_sql, right_sql)
return self._sql_generic(left_sql, right_sql)

# Register with SQLGlot generator system
def _register_with_sqlglot() -> None:
from sqlglot.dialects.postgres import Postgres
from sqlglot.generator import Generator

def custom_sql_base(generator: "Generator", expression: "CustomExpression") -> str:
return expression._sql_generic(generator.sql(expression.left), generator.sql(expression.right))

Generator.TRANSFORMS[CustomExpression] = custom_sql_base
Postgres.Generator.TRANSFORMS[CustomExpression] = custom_sql_postgres

_register_with_sqlglot()
```

**Use this pattern when**:
- Database syntax varies significantly across dialects
- Standard SQLGlot expressions don't match any database's native syntax
- Need operator syntax (e.g., `<->`) vs function calls (e.g., `DISTANCE()`)

**Key principles**:
- Override `.sql()` method for dialect detection
- Register with SQLGlot's TRANSFORMS for nested expression support
- Store metadata (like metric) as `exp.Identifier` in `arg_types` for runtime access
- Provide generic fallback for unsupported dialects

**Example**: `VectorDistance` in `sqlspec/builder/_vector_expressions.py` generates:
- PostgreSQL: `embedding <-> '[0.1,0.2]'` (operator)
- MySQL: `DISTANCE(embedding, STRING_TO_VECTOR('[0.1,0.2]'), 'EUCLIDEAN')` (function)
- Oracle: `VECTOR_DISTANCE(embedding, TO_VECTOR('[0.1,0.2]'), EUCLIDEAN)` (function)

### Error Handling

- Custom exceptions inherit from `SQLSpecError` in `sqlspec/exceptions.py`
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ release: ## Bump version and create re
@make docs
@make clean
@make build
@uv lock --upgrade-package litestar-vite >/dev/null 2>&1
@uv lock --upgrade-package sqlspec >/dev/null 2>&1
@uv run bump-my-version bump $(bump)
@echo "${OK} Release complete 🎉"

Expand Down
232 changes: 232 additions & 0 deletions docs/reference/builder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,238 @@ Base Classes
:undoc-members:
:show-inheritance:

Vector Distance Functions
=========================

The query builder provides portable vector similarity search functions that generate dialect-specific SQL across PostgreSQL (pgvector), MySQL 9+, Oracle 23ai+, BigQuery, DuckDB, and other databases.

.. note::
Vector functions are designed for AI/ML similarity search with embedding vectors. The SQL is generated at ``build(dialect=X)`` time, enabling portable query definitions that execute against multiple database types.

Column Methods
--------------

.. py:method:: Column.vector_distance(other_vector, metric="euclidean")

Calculate vector distance using the specified metric.

Generates dialect-specific SQL for vector distance operations.

:param other_vector: Vector to compare against (list, Column reference, or SQLGlot expression)
:type other_vector: list[float] | Column | exp.Expression
:param metric: Distance metric to use (default: "euclidean")
:type metric: str
:return: FunctionColumn expression for use in SELECT, WHERE, ORDER BY
:rtype: FunctionColumn

**Supported Metrics:**

- ``euclidean`` - L2 distance (default)
- ``cosine`` - Cosine distance
- ``inner_product`` - Negative inner product (for similarity ranking)
- ``euclidean_squared`` - L2² distance (Oracle only)

**Examples:**

.. code-block:: python

from sqlspec import sql
from sqlspec.builder import Column

query_vector = [0.1, 0.2, 0.3]

# Basic distance query
query = (
sql.select("id", "title", Column("embedding").vector_distance(query_vector).alias("distance"))
.from_("documents")
.where(Column("embedding").vector_distance(query_vector) < 0.5)
.order_by("distance")
.limit(10)
)

# Using dynamic attribute access
query = (
sql.select("*")
.from_("docs")
.order_by(sql.embedding.vector_distance(query_vector, metric="cosine"))
.limit(10)
)

# Compare two vector columns
query = (
sql.select("*")
.from_("pairs")
.where(Column("vec1").vector_distance(Column("vec2"), metric="euclidean") < 0.3)
)

.. py:method:: Column.cosine_similarity(other_vector)

Calculate cosine similarity (1 - cosine_distance).

Convenience method that computes similarity instead of distance.
Returns values in range [-1, 1] where 1 = identical vectors.

:param other_vector: Vector to compare against
:type other_vector: list[float] | Column | exp.Expression
:return: FunctionColumn expression computing ``1 - cosine_distance(self, other_vector)``
:rtype: FunctionColumn

**Example:**

.. code-block:: python

from sqlspec import sql

query_vector = [0.5, 0.5, 0.5]

# Find most similar documents
query = (
sql.select("id", "title", sql.embedding.cosine_similarity(query_vector).alias("similarity"))
.from_("documents")
.order_by(sql.column("similarity").desc())
.limit(10)
)

Database Compatibility
----------------------

Vector functions generate dialect-specific SQL:

.. list-table::
:header-rows: 1
:widths: 15 25 25 35

* - Database
- Euclidean
- Cosine
- Inner Product
* - PostgreSQL (pgvector)
- ``<->`` operator
- ``<=>`` operator
- ``<#>`` operator
* - MySQL 9+
- ``DISTANCE(..., 'EUCLIDEAN')``
- ``DISTANCE(..., 'COSINE')``
- ``DISTANCE(..., 'DOT')``
* - Oracle 23ai+
- ``VECTOR_DISTANCE(..., EUCLIDEAN)``
- ``VECTOR_DISTANCE(..., COSINE)``
- ``VECTOR_DISTANCE(..., DOT)``
* - BigQuery
- ``EUCLIDEAN_DISTANCE(...)``
- ``COSINE_DISTANCE(...)``
- ``DOT_PRODUCT(...)``
* - DuckDB (VSS extension)
- ``array_distance(...)``
- ``array_cosine_distance(...)``
- ``array_negative_inner_product(...)``
* - Generic
- ``VECTOR_DISTANCE(..., 'EUCLIDEAN')``
- ``VECTOR_DISTANCE(..., 'COSINE')``
- ``VECTOR_DISTANCE(..., 'INNER_PRODUCT')``

Usage Examples
--------------

**Basic Similarity Search**

.. code-block:: python

from sqlspec import sql

# Find documents similar to query vector
query_vector = [0.1, 0.2, 0.3]

query = (
sql.select("id", "title", sql.embedding.vector_distance(query_vector).alias("distance"))
.from_("documents")
.order_by("distance")
.limit(10)
)

# PostgreSQL generates: SELECT id, title, embedding <-> '[0.1,0.2,0.3]' AS distance ...
# MySQL generates: SELECT id, title, DISTANCE(embedding, STRING_TO_VECTOR('[0.1,0.2,0.3]'), 'EUCLIDEAN') AS distance ...
# Oracle generates: SELECT id, title, VECTOR_DISTANCE(embedding, TO_VECTOR('[0.1,0.2,0.3]'), EUCLIDEAN) AS distance ...

**Threshold Filtering**

.. code-block:: python

# Find documents within distance threshold
query = (
sql.select("*")
.from_("documents")
.where(sql.embedding.vector_distance(query_vector, metric="euclidean") < 0.5)
.order_by(sql.embedding.vector_distance(query_vector))
)

**Similarity Ranking**

.. code-block:: python

# Rank by cosine similarity (higher = more similar)
query = (
sql.select("id", "content", sql.embedding.cosine_similarity(query_vector).alias("score"))
.from_("articles")
.order_by(sql.column("score").desc())
.limit(5)
)

**Multiple Metrics**

.. code-block:: python

# Compare different distance metrics in single query
query = (
sql.select(
"id",
sql.embedding.vector_distance(query_vector, metric="euclidean").alias("l2_dist"),
sql.embedding.vector_distance(query_vector, metric="cosine").alias("cos_dist"),
sql.embedding.cosine_similarity(query_vector).alias("similarity")
)
.from_("documents")
.limit(10)
)

**Combined Filters**

.. code-block:: python

# Vector search with additional filters
query = (
sql.select("*")
.from_("products")
.where("category = ?")
.where("in_stock = TRUE")
.where(sql.embedding.vector_distance(query_vector) < 0.3)
.order_by(sql.embedding.vector_distance(query_vector))
.limit(20)
)

Dialect-Agnostic Construction
------------------------------

Queries are constructed once and executed against multiple databases:

.. code-block:: python

from sqlspec import sql

# Define query once
query = (
sql.select("id", "title", sql.embedding.vector_distance([0.1, 0.2, 0.3]).alias("distance"))
.from_("documents")
.order_by("distance")
.limit(10)
)

# Execute with different adapters
pg_result = await pg_session.execute(query) # → PostgreSQL SQL with <-> operator
mysql_result = await mysql_session.execute(query) # → MySQL SQL with DISTANCE()
oracle_result = await oracle_session.execute(query) # → Oracle SQL with VECTOR_DISTANCE()

The dialect is selected at ``build(dialect=X)`` time based on the driver, not at query construction time.

Filter Integration
==================

Expand Down
4 changes: 4 additions & 0 deletions sqlspec/adapters/adbc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,10 @@ def get_adbc_statement_config(detected_dialect: str) -> StatementConfig:
"type_coercion_map": type_map,
}

if detected_dialect == "duckdb":
parameter_overrides["preserve_parameter_format"] = False
parameter_overrides["supported_execution_parameter_styles"] = {ParameterStyle.QMARK, ParameterStyle.NUMERIC}

if detected_dialect in {"postgres", "postgresql"}:
parameter_overrides["ast_transformer"] = build_null_pruning_transform(dialect=sqlglot_dialect)

Expand Down
16 changes: 15 additions & 1 deletion sqlspec/builder/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,11 +319,25 @@ def _parameterize_expression(self, expression: exp.Expression) -> exp.Expression
A new expression with literals replaced by parameter placeholders
"""

from sqlspec.builder._vector_expressions import VectorDistance

def replacer(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Literal):
if node.this in {True, False, None}:
return node
param_name = self._add_parameter(node.this, context="where")

parent = node.parent
if isinstance(parent, exp.Array) and node.find_ancestor(VectorDistance) is not None:
return node

value = node.this
if node.is_number and isinstance(node.this, str):
try:
value = float(node.this) if "." in node.this or "e" in node.this.lower() else int(node.this)
except ValueError:
value = node.this

param_name = self._add_parameter(value, context="where")
return exp.Placeholder(this=param_name)
return node

Expand Down
Loading