Skip to content

Commit 87b4ad5

Browse files
authored
feat: vector search for builder (#270)
Implements a vector search helper in the SQL builder. Generates dialect-specific vector searches for Postgres, DuckDB, Oracle, BigQuery, and more.
1 parent ee4ea74 commit 87b4ad5

File tree

17 files changed

+3073
-20
lines changed

17 files changed

+3073
-20
lines changed

AGENTS.md

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,62 @@ config = AsyncpgConfig(
239239
)
240240
```
241241

242+
### Custom SQLGlot Expression Pattern
243+
244+
For dialect-specific SQL generation (e.g., vector distance functions):
245+
246+
```python
247+
# In sqlspec/builder/_custom_expressions.py
248+
from sqlglot import exp
249+
from typing import Any
250+
251+
class CustomExpression(exp.Expression):
252+
"""Custom expression with dialect-aware SQL generation."""
253+
arg_types = {"this": True, "expression": True, "metric": False}
254+
255+
def sql(self, dialect: "Any | None" = None, **opts: Any) -> str:
256+
"""Override sql() method for dialect-specific generation."""
257+
dialect_name = str(dialect).lower() if dialect else "generic"
258+
259+
left_sql = self.left.sql(dialect=dialect, **opts)
260+
right_sql = self.right.sql(dialect=dialect, **opts)
261+
262+
if dialect_name == "postgres":
263+
return self._sql_postgres(left_sql, right_sql)
264+
if dialect_name == "mysql":
265+
return self._sql_mysql(left_sql, right_sql)
266+
return self._sql_generic(left_sql, right_sql)
267+
268+
# Register with SQLGlot generator system
269+
def _register_with_sqlglot() -> None:
270+
from sqlglot.dialects.postgres import Postgres
271+
from sqlglot.generator import Generator
272+
273+
def custom_sql_base(generator: "Generator", expression: "CustomExpression") -> str:
274+
return expression._sql_generic(generator.sql(expression.left), generator.sql(expression.right))
275+
276+
Generator.TRANSFORMS[CustomExpression] = custom_sql_base
277+
Postgres.Generator.TRANSFORMS[CustomExpression] = custom_sql_postgres
278+
279+
_register_with_sqlglot()
280+
```
281+
282+
**Use this pattern when**:
283+
- Database syntax varies significantly across dialects
284+
- Standard SQLGlot expressions don't match any database's native syntax
285+
- Some dialects need operator syntax (e.g., `<->`) while others need function calls (e.g., `DISTANCE()`)
286+
287+
**Key principles**:
288+
- Override `.sql()` method for dialect detection
289+
- Register with SQLGlot's TRANSFORMS for nested expression support
290+
- Store metadata (like metric) as `exp.Identifier` in `arg_types` for runtime access
291+
- Provide generic fallback for unsupported dialects
292+
293+
**Example**: `VectorDistance` in `sqlspec/builder/_vector_expressions.py` generates:
294+
- PostgreSQL: `embedding <-> '[0.1,0.2]'` (operator)
295+
- MySQL: `DISTANCE(embedding, STRING_TO_VECTOR('[0.1,0.2]'), 'EUCLIDEAN')` (function)
296+
- Oracle: `VECTOR_DISTANCE(embedding, TO_VECTOR('[0.1,0.2]'), EUCLIDEAN)` (function)
297+
242298
### Error Handling
243299

244300
- Custom exceptions inherit from `SQLSpecError` in `sqlspec/exceptions.py`

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ release: ## Bump version and create re
117117
@make docs
118118
@make clean
119119
@make build
120-
@uv lock --upgrade-package litestar-vite >/dev/null 2>&1
120+
@uv lock --upgrade-package sqlspec >/dev/null 2>&1
121121
@uv run bump-my-version bump $(bump)
122122
@echo "${OK} Release complete 🎉"
123123

docs/reference/builder.rst

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,238 @@ Base Classes
514514
:undoc-members:
515515
:show-inheritance:
516516

517+
Vector Distance Functions
518+
=========================
519+
520+
The query builder provides portable vector similarity search functions that generate dialect-specific SQL across PostgreSQL (pgvector), MySQL 9+, Oracle 23ai+, BigQuery, DuckDB, and other databases.
521+
522+
.. note::
523+
Vector functions are designed for AI/ML similarity search with embedding vectors. The SQL is generated at ``build(dialect=X)`` time, enabling portable query definitions that execute against multiple database types.
524+
525+
Column Methods
526+
--------------
527+
528+
.. py:method:: Column.vector_distance(other_vector, metric="euclidean")
529+
530+
Calculate vector distance using the specified metric.
531+
532+
Generates dialect-specific SQL for vector distance operations.
533+
534+
:param other_vector: Vector to compare against (list, Column reference, or SQLGlot expression)
535+
:type other_vector: list[float] | Column | exp.Expression
536+
:param metric: Distance metric to use (default: "euclidean")
537+
:type metric: str
538+
:return: FunctionColumn expression for use in SELECT, WHERE, ORDER BY
539+
:rtype: FunctionColumn
540+
541+
**Supported Metrics:**
542+
543+
- ``euclidean`` - L2 distance (default)
544+
- ``cosine`` - Cosine distance
545+
- ``inner_product`` - Negative inner product (for similarity ranking)
546+
- ``euclidean_squared`` - L2² distance (Oracle only)
547+
548+
**Examples:**
549+
550+
.. code-block:: python
551+
552+
from sqlspec import sql
553+
from sqlspec.builder import Column
554+
555+
query_vector = [0.1, 0.2, 0.3]
556+
557+
# Basic distance query
558+
query = (
559+
sql.select("id", "title", Column("embedding").vector_distance(query_vector).alias("distance"))
560+
.from_("documents")
561+
.where(Column("embedding").vector_distance(query_vector) < 0.5)
562+
.order_by("distance")
563+
.limit(10)
564+
)
565+
566+
# Using dynamic attribute access
567+
query = (
568+
sql.select("*")
569+
.from_("docs")
570+
.order_by(sql.embedding.vector_distance(query_vector, metric="cosine"))
571+
.limit(10)
572+
)
573+
574+
# Compare two vector columns
575+
query = (
576+
sql.select("*")
577+
.from_("pairs")
578+
.where(Column("vec1").vector_distance(Column("vec2"), metric="euclidean") < 0.3)
579+
)
580+
581+
.. py:method:: Column.cosine_similarity(other_vector)
582+
583+
Calculate cosine similarity (1 - cosine_distance).
584+
585+
Convenience method that computes similarity instead of distance.
586+
Returns values in the range [-1, 1], where 1 means the vectors point in the same direction (magnitude is ignored).
587+
588+
:param other_vector: Vector to compare against
589+
:type other_vector: list[float] | Column | exp.Expression
590+
:return: FunctionColumn expression computing ``1 - cosine_distance(self, other_vector)``
591+
:rtype: FunctionColumn
592+
593+
**Example:**
594+
595+
.. code-block:: python
596+
597+
from sqlspec import sql
598+
599+
query_vector = [0.5, 0.5, 0.5]
600+
601+
# Find most similar documents
602+
query = (
603+
sql.select("id", "title", sql.embedding.cosine_similarity(query_vector).alias("similarity"))
604+
.from_("documents")
605+
.order_by(sql.column("similarity").desc())
606+
.limit(10)
607+
)
608+
609+
Database Compatibility
610+
----------------------
611+
612+
Vector functions generate dialect-specific SQL:
613+
614+
.. list-table::
615+
:header-rows: 1
616+
:widths: 15 25 25 35
617+
618+
* - Database
619+
- Euclidean
620+
- Cosine
621+
- Inner Product
622+
* - PostgreSQL (pgvector)
623+
- ``<->`` operator
624+
- ``<=>`` operator
625+
- ``<#>`` operator
626+
* - MySQL 9+
627+
- ``DISTANCE(..., 'EUCLIDEAN')``
628+
- ``DISTANCE(..., 'COSINE')``
629+
- ``DISTANCE(..., 'DOT')``
630+
* - Oracle 23ai+
631+
- ``VECTOR_DISTANCE(..., EUCLIDEAN)``
632+
- ``VECTOR_DISTANCE(..., COSINE)``
633+
- ``VECTOR_DISTANCE(..., DOT)``
634+
* - BigQuery
635+
- ``EUCLIDEAN_DISTANCE(...)``
636+
- ``COSINE_DISTANCE(...)``
637+
- ``DOT_PRODUCT(...)``
638+
* - DuckDB (VSS extension)
639+
- ``array_distance(...)``
640+
- ``array_cosine_distance(...)``
641+
- ``array_negative_inner_product(...)``
642+
* - Generic
643+
- ``VECTOR_DISTANCE(..., 'EUCLIDEAN')``
644+
- ``VECTOR_DISTANCE(..., 'COSINE')``
645+
- ``VECTOR_DISTANCE(..., 'INNER_PRODUCT')``
646+
647+
Usage Examples
648+
--------------
649+
650+
**Basic Similarity Search**
651+
652+
.. code-block:: python
653+
654+
from sqlspec import sql
655+
656+
# Find documents similar to query vector
657+
query_vector = [0.1, 0.2, 0.3]
658+
659+
query = (
660+
sql.select("id", "title", sql.embedding.vector_distance(query_vector).alias("distance"))
661+
.from_("documents")
662+
.order_by("distance")
663+
.limit(10)
664+
)
665+
666+
# PostgreSQL generates: SELECT id, title, embedding <-> '[0.1,0.2,0.3]' AS distance ...
667+
# MySQL generates: SELECT id, title, DISTANCE(embedding, STRING_TO_VECTOR('[0.1,0.2,0.3]'), 'EUCLIDEAN') AS distance ...
668+
# Oracle generates: SELECT id, title, VECTOR_DISTANCE(embedding, TO_VECTOR('[0.1,0.2,0.3]'), EUCLIDEAN) AS distance ...
669+
670+
**Threshold Filtering**
671+
672+
.. code-block:: python
673+
674+
# Find documents within distance threshold
675+
query = (
676+
sql.select("*")
677+
.from_("documents")
678+
.where(sql.embedding.vector_distance(query_vector, metric="euclidean") < 0.5)
679+
.order_by(sql.embedding.vector_distance(query_vector))
680+
)
681+
682+
**Similarity Ranking**
683+
684+
.. code-block:: python
685+
686+
# Rank by cosine similarity (higher = more similar)
687+
query = (
688+
sql.select("id", "content", sql.embedding.cosine_similarity(query_vector).alias("score"))
689+
.from_("articles")
690+
.order_by(sql.column("score").desc())
691+
.limit(5)
692+
)
693+
694+
**Multiple Metrics**
695+
696+
.. code-block:: python
697+
698+
# Compare different distance metrics in single query
699+
query = (
700+
sql.select(
701+
"id",
702+
sql.embedding.vector_distance(query_vector, metric="euclidean").alias("l2_dist"),
703+
sql.embedding.vector_distance(query_vector, metric="cosine").alias("cos_dist"),
704+
sql.embedding.cosine_similarity(query_vector).alias("similarity")
705+
)
706+
.from_("documents")
707+
.limit(10)
708+
)
709+
710+
**Combined Filters**
711+
712+
.. code-block:: python
713+
714+
# Vector search with additional filters
715+
query = (
716+
sql.select("*")
717+
.from_("products")
718+
.where("category = ?")
719+
.where("in_stock = TRUE")
720+
.where(sql.embedding.vector_distance(query_vector) < 0.3)
721+
.order_by(sql.embedding.vector_distance(query_vector))
722+
.limit(20)
723+
)
724+
725+
Dialect-Agnostic Construction
726+
------------------------------
727+
728+
Queries can be constructed once and then executed against multiple databases:
729+
730+
.. code-block:: python
731+
732+
from sqlspec import sql
733+
734+
# Define query once
735+
query = (
736+
sql.select("id", "title", sql.embedding.vector_distance([0.1, 0.2, 0.3]).alias("distance"))
737+
.from_("documents")
738+
.order_by("distance")
739+
.limit(10)
740+
)
741+
742+
# Execute with different adapters
743+
pg_result = await pg_session.execute(query) # → PostgreSQL SQL with <-> operator
744+
mysql_result = await mysql_session.execute(query) # → MySQL SQL with DISTANCE()
745+
oracle_result = await oracle_session.execute(query) # → Oracle SQL with VECTOR_DISTANCE()
746+
747+
The dialect is selected at ``build(dialect=X)`` time based on the driver, not at query construction time.
748+
517749
Filter Integration
518750
==================
519751

sqlspec/adapters/adbc/driver.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,10 @@ def get_adbc_statement_config(detected_dialect: str) -> StatementConfig:
826826
"type_coercion_map": type_map,
827827
}
828828

829+
if detected_dialect == "duckdb":
830+
parameter_overrides["preserve_parameter_format"] = False
831+
parameter_overrides["supported_execution_parameter_styles"] = {ParameterStyle.QMARK, ParameterStyle.NUMERIC}
832+
829833
if detected_dialect in {"postgres", "postgresql"}:
830834
parameter_overrides["ast_transformer"] = build_null_pruning_transform(dialect=sqlglot_dialect)
831835

sqlspec/builder/_base.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,11 +319,25 @@ def _parameterize_expression(self, expression: exp.Expression) -> exp.Expression
319319
A new expression with literals replaced by parameter placeholders
320320
"""
321321

322+
from sqlspec.builder._vector_expressions import VectorDistance
323+
322324
def replacer(node: exp.Expression) -> exp.Expression:
323325
if isinstance(node, exp.Literal):
324326
if node.this in {True, False, None}:
325327
return node
326-
param_name = self._add_parameter(node.this, context="where")
328+
329+
parent = node.parent
330+
if isinstance(parent, exp.Array) and node.find_ancestor(VectorDistance) is not None:
331+
return node
332+
333+
value = node.this
334+
if node.is_number and isinstance(node.this, str):
335+
try:
336+
value = float(node.this) if "." in node.this or "e" in node.this.lower() else int(node.this)
337+
except ValueError:
338+
value = node.this
339+
340+
param_name = self._add_parameter(value, context="where")
327341
return exp.Placeholder(this=param_name)
328342
return node
329343

0 commit comments

Comments
 (0)