diff --git a/libs/community/langchain_community/vectorstores/yellowbrick.py b/libs/community/langchain_community/vectorstores/yellowbrick.py index dbf9df917..5f6a406cf 100644 --- a/libs/community/langchain_community/vectorstores/yellowbrick.py +++ b/libs/community/langchain_community/vectorstores/yellowbrick.py @@ -8,6 +8,7 @@ import uuid from contextlib import contextmanager from io import StringIO +from threading import local from typing import ( TYPE_CHECKING, Any, @@ -44,6 +45,7 @@ class IndexType(str, enum.Enum): NONE = "none" LSH = "lsh" + IVF = "ivf" class IndexParams: """Parameters for configuring a Yellowbrick index.""" @@ -98,6 +100,8 @@ def __init__( self.LSH_INDEX_TABLE: str = "_lsh_index" self.LSH_HYPERPLANE_TABLE: str = "_lsh_hyperplane" + self.IVF_INDEX_TABLE: str = "_ivf_index" + self.IVF_CENTROID_TABLE: str = "_ivf_centroid" self.CONTENT_TABLE: str = "_content" self.connection_string = connection_string @@ -126,7 +130,7 @@ def __init__( class DatabaseConnection: _instance = None _connection_string: str - _connection: Optional["PgConnection"] = None + _thread_local = local() # Thread-local storage for connections _logger: logging.Logger def __new__( @@ -139,18 +143,33 @@ def __new__( return cls._instance def close_connection(self) -> None: - if self._connection and not self._connection.closed: - self._connection.close() - self._connection = None + connection = getattr(self._thread_local, "connection", None) + if connection and not connection.closed: + connection.close() + self._thread_local.connection = None def get_connection(self) -> "PgConnection": import psycopg2 - if not self._connection or self._connection.closed: - self._connection = psycopg2.connect(self._connection_string) - self._connection.autocommit = False + connection = getattr(self._thread_local, "connection", None) + try: + if not connection or connection.closed: + connection = psycopg2.connect(self._connection_string) + connection.autocommit = False + self._thread_local.connection = connection + else: + cursor = connection.cursor() + cursor.execute("SELECT 1") + cursor.close() + except (Exception, psycopg2.DatabaseError) as error: + self._logger.error( + f"Error detected, reconnecting: {error}", exc_info=False + ) + connection = psycopg2.connect(self._connection_string) + connection.autocommit = False + self._thread_local.connection = connection - return self._connection + return connection @contextmanager def get_managed_connection(self) -> Generator["PgConnection", None, None]: @@ -209,7 +228,7 @@ def _create_table(self, cursor: "PgCursor") -> None: CREATE TABLE IF NOT EXISTS {t} ( doc_id UUID NOT NULL, text VARCHAR(60000) NOT NULL, - metadata VARCHAR(1024) NOT NULL, + metadata JSONB NOT NULL, CONSTRAINT {c} PRIMARY KEY (doc_id)) DISTRIBUTE ON (doc_id) SORT ON (doc_id) """ @@ -354,7 +373,7 @@ def add_texts( if current_batch_size > 0: self._copy_to_db(cursor, content_io, embeddings_io) - if index_params.index_type == Yellowbrick.IndexType.LSH: + if index_params: self._update_index(index_params, uuid.UUID(doc_uuid)) return results @@ -531,6 +550,11 @@ def similarity_search_with_score_by_vector( from psycopg2.extras import execute_values index_params = kwargs.get("index_params") or Yellowbrick.IndexParams() + where_clause = "1=1" + filter_value = kwargs.get("filter") + if filter_value is not None: + filter_dict = json.loads(filter_value) + where_clause = jsonFilter2sqlWhere(filter_dict, "v3.metadata") with self.connection.get_cursor() as cursor: tmp_embeddings_table = "tmp_" + self._table @@ -555,10 +579,11 
@@ def similarity_search_with_score_by_vector( ).format(sql.Identifier(tmp_embeddings_table)) execute_values(cursor, insert_query, data_input) - v1 = sql.Identifier(tmp_embeddings_table) schema_prefix = (self._schema,) if self._schema else () - v2 = sql.Identifier(*schema_prefix, self._table) - v3 = sql.Identifier(*schema_prefix, self._table + self.CONTENT_TABLE) + embeddings_table = sql.Identifier(*schema_prefix, self._table) + content_table = sql.Identifier( + *schema_prefix, self._table + self.CONTENT_TABLE + ) if index_params.index_type == Yellowbrick.IndexType.LSH: tmp_hash_table = self._table + "_tmp_hash" self._generate_tmp_lsh_hashes( @@ -593,26 +618,28 @@ def similarity_search_with_score_by_vector( (SQRT(SUM(v1.embedding * v1.embedding)) * SQRT(SUM(v2.embedding * v2.embedding))) AS score FROM - {v1} v1 + {tmp_embeddings_table} v1 INNER JOIN - {v2} v2 + {embeddings_table} v2 ON v1.embedding_id = v2.embedding_id INNER JOIN - {v3} v3 + {content_table} v3 ON v2.doc_id = v3.doc_id INNER JOIN index_docs v4 ON v2.doc_id = v4.doc_id + where {where_clause} GROUP BY v3.doc_id, v3.text, v3.metadata ORDER BY score DESC LIMIT %s """ ).format( - v1=v1, - v2=v2, - v3=v3, + tmp_embeddings_table=sql.Identifier(tmp_embeddings_table), + embeddings_table=embeddings_table, + content_table=content_table, lsh_index=lsh_index, input_hash_table=input_hash_table, + where_clause=sql.SQL(where_clause), hamming_distance=sql.Literal( index_params.get_param("hamming_distance", 0) ), @@ -622,6 +649,113 @@ def similarity_search_with_score_by_vector( (k,), ) results = cursor.fetchall() + elif index_params.index_type == Yellowbrick.IndexType.IVF: + centroids_table = sql.Identifier( + *schema_prefix, self._table + self.IVF_CENTROID_TABLE + ) + ivf_index_table = sql.Identifier( + *schema_prefix, self._table + self.IVF_INDEX_TABLE + ) + quantization = index_params.get_param("quantization", False) + + tmp_quantized_embedding_sql = sql.SQL( + """ + (SELECT doc_id, embedding_id, + ROUND(((embedding + 1) / 2) * 255) AS embedding + FROM {embedding_table}) + """ + ).format( + embedding_table=sql.Identifier(tmp_embeddings_table), + ) + + tmp_embedding_sql = sql.SQL( + """ + (SELECT doc_id, embedding_id, embedding FROM {embedding_table}) + """ + ).format( + embedding_table=sql.Identifier(tmp_embeddings_table), + ) + + if quantization: + tmp_embedding = tmp_quantized_embedding_sql + else: + tmp_embedding = tmp_embedding_sql + + centroid_sql = sql.SQL( + """ + WITH CentroidDistances AS ( + SELECT + e.doc_id AS edoc_id, + c.id AS cdoc_id, + SUM(e.embedding * c.centroid) / + (SQRT(SUM(c.centroid * c.centroid)) * + SQRT(SUM(e.embedding * e.embedding))) AS cosine_similarity + FROM {tmp_embedding} e + JOIN {centroids_table} c ON e.embedding_id = c.centroid_id + GROUP BY edoc_id, cdoc_id + ), + MaxSimilarities AS ( + SELECT + edoc_id, + cdoc_id, + ROW_NUMBER() OVER ( + PARTITION BY edoc_id ORDER BY cosine_similarity DESC + ) AS rank + FROM CentroidDistances + ), + Centroid AS ( + SELECT + cdoc_id + FROM MaxSimilarities + WHERE rank = 1 + ) + """ + ).format( + centroids_table=centroids_table, + tmp_embedding=tmp_embedding_sql, + ) + + sql_query = sql.SQL( + """ + {centroid_sql} + SELECT + text, + metadata, + score + FROM + (SELECT + v5.doc_id doc_id, + SUM(v1.embedding * v5.embedding) / + (SQRT(SUM((v1.embedding * v1.embedding)::float)) * + SQRT(SUM((v5.embedding * v5.embedding)::float))) AS score + FROM + {tmp_embedding} v1 + INNER JOIN + {ivf_index_table} v5 + ON v1.embedding_id = v5.embedding_id + INNER JOIN + Centroid c + ON v5.id 
= c.cdoc_id + GROUP BY v5.doc_id + ORDER BY score DESC + ) v4 + INNER JOIN + {content_table} v3 + ON v4.doc_id = v3.doc_id + where {where_clause} + ORDER BY score DESC + LIMIT %s + """ + ).format( + centroid_sql=centroid_sql, + content_table=content_table, + ivf_index_table=ivf_index_table, + tmp_embedding=tmp_embedding, + where_clause=sql.SQL(where_clause), + ) + cursor.execute(sql_query, (k,)) + results = cursor.fetchall() + else: sql_query = sql.SQL( """ @@ -636,29 +770,32 @@ def similarity_search_with_score_by_vector( (SQRT(SUM(v1.embedding * v1.embedding)) * SQRT(SUM(v2.embedding * v2.embedding))) AS score FROM - {v1} v1 + {tmp_embeddings_table} v1 INNER JOIN - {v2} v2 + {embeddings_table} v2 ON v1.embedding_id = v2.embedding_id GROUP BY v2.doc_id - ORDER BY score DESC LIMIT %s + ORDER BY score DESC ) v4 INNER JOIN - {v3} v3 + {content_table} v3 ON v4.doc_id = v3.doc_id + where {where_clause} ORDER BY score DESC + LIMIT %s """ ).format( - v1=v1, - v2=v2, - v3=v3, + tmp_embeddings_table=sql.Identifier(tmp_embeddings_table), + embeddings_table=embeddings_table, + content_table=content_table, + where_clause=sql.SQL(where_clause), ) cursor.execute(sql_query, (k,)) results = cursor.fetchall() documents: List[Tuple[Document, float]] = [] for result in results: - metadata = json.loads(result[1]) or {} + metadata = result[1] or {} doc = Document(page_content=result[0], metadata=metadata) documents.append((doc, result[2])) @@ -726,6 +863,292 @@ def similarity_search_by_vector( ) return [doc for doc, _ in documents] + def migrate_schema_v1_to_v2(self) -> None: + from psycopg2 import sql + + try: + with self.connection.get_cursor() as cursor: + schema_prefix = (self._schema,) if self._schema else () + embeddings = sql.Identifier(*schema_prefix, self._table) + # For the RENAME TO statement the destination must be unqualified + # (no schema). + # Build both a schema-qualified identifier for later references and an + # unqualified identifier to use as the target of the RENAME TO. + old_embeddings_qualified = sql.Identifier( + *schema_prefix, self._table + "_v1" + ) + old_embeddings_unqualified = sql.Identifier(self._table + "_v1") + content = sql.Identifier( + *schema_prefix, self._table + self.CONTENT_TABLE + ) + alter_table_query = sql.SQL("ALTER TABLE {t1} RENAME TO {t2}").format( + t1=embeddings, + # must supply an unqualified name for the RENAME TO target + t2=old_embeddings_unqualified, + ) + cursor.execute(alter_table_query) + + self._create_table(cursor) + + insert_query = sql.SQL( + """ + INSERT INTO {t1} (doc_id, embedding_id, embedding) + SELECT id, embedding_id, embedding FROM {t2} + """ + ).format( + t1=embeddings, + # reference the schema-qualified name when selecting from the old + # table + t2=old_embeddings_qualified, + ) + cursor.execute(insert_query) + + insert_content_query = sql.SQL( + """ + INSERT INTO {t1} (doc_id, text, metadata) + SELECT DISTINCT id, text, metadata FROM {t2} + """ + ).format(t1=content, t2=old_embeddings_qualified) + cursor.execute(insert_content_query) + except Exception as e: + raise RuntimeError(f"Failed to migrate schema: {e}") from e + + def migrate_schema_v2_to_v3(self) -> None: + """Migrate schema from v2 to v3. + + Difference: in the content table the `metadata` column changed from TEXT + to JSONB. 
+ This routine follows the same pattern as `migrate_schema_v1_to_v2`: + - rename the existing embeddings/content tables to a `_v2` suffix (using + an unqualified target for the RENAME TO), + - recreate the v3 tables via `_create_table`, + - copy/convert data into the new tables, casting `metadata` from TEXT -> JSONB. + """ + from psycopg2 import sql + + try: + with self.connection.get_cursor() as cursor: + schema_prefix = (self._schema,) if self._schema else () + + old_content_qualified = sql.Identifier( + *schema_prefix, self._table + self.CONTENT_TABLE + ) + + alter_table_query = sql.SQL("ALTER TABLE {t1} RENAME TO {t2}").format( + t1=old_content_qualified, + t2=sql.Identifier(self._table + self.CONTENT_TABLE + "_v2"), + ) + try: + self.logger.info( + "Renaming content table %s to %s", + old_content_qualified.as_string(cursor), + (self._table + self.CONTENT_TABLE + "_v2"), + ) + cursor.execute(alter_table_query) + self.logger.debug("ALTER TABLE rename executed successfully") + except Exception as e: + self.logger.exception("ALTER TABLE RENAME TO failed: %s", e) + raise + + try: + old_constraint_name = ( + self._table + self.CONTENT_TABLE + "_pk_doc_id" + ) + new_constraint_name = old_constraint_name + "_v2" + renamed_table_noschema = sql.Identifier( + self._table + self.CONTENT_TABLE + "_v2" + ) + + rename_constraint_sql = sql.SQL( + "ALTER TABLE {t} RENAME CONSTRAINT {old} TO {new}" + ).format( + t=renamed_table_noschema, + old=sql.Identifier(old_constraint_name), + new=sql.Identifier(new_constraint_name), + ) + self.logger.debug( + "Attempting to rename constraint %s -> %s", + old_constraint_name, + new_constraint_name, + ) + cursor.execute(rename_constraint_sql) + self.logger.debug("Constraint rename executed successfully") + except Exception: + self.logger.exception( + "Failed to rename old primary key constraint; " + "continuing migration" + ) + + # Recreate the v3 tables (this will create the content table with + # JSONB metadata) + self._create_table(cursor) + + import json + + from psycopg2.extras import Json, execute_values + + content = sql.Identifier( + *schema_prefix, self._table + self.CONTENT_TABLE + ) + old_content_v2 = sql.Identifier( + *schema_prefix, self._table + self.CONTENT_TABLE + "_v2" + ) + + select_sql = sql.SQL("SELECT doc_id, text, metadata FROM {t}").format( + t=old_content_v2 + ) + batch_size = 1000 + named_cursor = cursor.connection.cursor(name="yb_migrate_v2_to_v3") + try: + named_cursor.execute(select_sql.as_string(cursor)) + + while True: + rows = named_cursor.fetchmany(batch_size) + if not rows: + break + + to_insert = [] + for doc_id, text, metadata in rows: + if metadata is None or ( + isinstance(metadata, str) and metadata.strip() == "" + ): + md = {} + else: + if not isinstance(metadata, str): + md = metadata + else: + try: + md = json.loads(metadata) + except Exception: + md = {"old_metadata": metadata} + to_insert.append((doc_id, text, Json(md))) + + if to_insert: + insert_sql = sql.SQL( + "INSERT INTO {t} (doc_id, text, metadata) VALUES %s" + ).format(t=content) + execute_values( + cursor, + insert_sql.as_string(cursor), + to_insert, + page_size=100, + ) + finally: + try: + named_cursor.close() + except Exception: + pass + except Exception as e: + raise RuntimeError(f"Failed to migrate v2 to v3: {e}") from e + + def create_index(self, index_params: Yellowbrick.IndexParams) -> None: + """Create index from existing vectors""" + if index_params.index_type == Yellowbrick.IndexType.LSH: + with self.connection.get_cursor() as cursor: + 
self._drop_lsh_index_tables(cursor) + self._create_lsh_index_tables(cursor) + self._populate_hyperplanes( + cursor, index_params.get_param("num_hyperplanes", 128) + ) + self._update_lsh_hashes(cursor) + + if index_params.index_type == Yellowbrick.IndexType.IVF: + with self.connection.get_cursor() as cursor: + self._drop_ivf_index_table(cursor) + self._drop_ivf_centroid_tables(cursor) + self._create_ivf_centroid_tables(cursor) + self._create_ivf_index_table(cursor, index_params) + self._populate_centroids( + cursor, index_params.get_param("num_centroids", 40) + ) + self._k_means( + cursor, + num_centroids=index_params.get_param("num_centroids", 40), + max_iter=index_params.get_param("max_iter", 40), + threshold=index_params.get_param("threshold", 1e-4), + ) + self._update_ivf_index(cursor, index_params) + + def drop_index(self, index_params: Yellowbrick.IndexParams) -> None: + """Drop an index""" + if index_params.index_type == Yellowbrick.IndexType.LSH: + with self.connection.get_cursor() as cursor: + self._drop_lsh_index_tables(cursor) + if index_params.index_type == Yellowbrick.IndexType.IVF: + with self.connection.get_cursor() as cursor: + self._drop_ivf_index_table(cursor) + self._drop_ivf_centroid_tables(cursor) + + def _update_index( + self, index_params: Yellowbrick.IndexParams, doc_id: uuid.UUID + ) -> None: + """Update an index with a new or modified embedding in the embeddings table""" + if index_params.index_type == Yellowbrick.IndexType.LSH: + with self.connection.get_cursor() as cursor: + self._update_lsh_hashes(cursor, doc_id) + + if index_params.index_type == Yellowbrick.IndexType.IVF: + with self.connection.get_cursor() as cursor: + self._update_ivf_index(cursor, index_params, doc_id) + + def _create_lsh_index_tables(self, cursor: "PgCursor") -> None: + """Create LSH index and hyperplane tables""" + from psycopg2 import sql + + schema_prefix = (self._schema,) if self._schema else () + t1 = sql.Identifier(*schema_prefix, self._table + self.LSH_INDEX_TABLE) + t2 = sql.Identifier(*schema_prefix, self._table + self.CONTENT_TABLE) + c1 = sql.Identifier(self._table + self.LSH_INDEX_TABLE + "_pk_doc_id") + c2 = sql.Identifier(self._table + self.LSH_INDEX_TABLE + "_fk_doc_id") + cursor.execute( + sql.SQL( + """ + CREATE TABLE IF NOT EXISTS {t1} ( + doc_id UUID NOT NULL, + hash_index SMALLINT NOT NULL, + hash SMALLINT NOT NULL, + CONSTRAINT {c1} PRIMARY KEY (doc_id, hash_index), + CONSTRAINT {c2} FOREIGN KEY (doc_id) REFERENCES {t2}(doc_id)) + DISTRIBUTE ON (doc_id) SORT ON (doc_id) + """ + ).format( + t1=t1, + t2=t2, + c1=c1, + c2=c2, + ) + ) + + schema_prefix = (self._schema,) if self._schema else () + t = sql.Identifier(*schema_prefix, self._table + self.LSH_HYPERPLANE_TABLE) + c = sql.Identifier(self._table + self.LSH_HYPERPLANE_TABLE + "_pk_id_hp_id") + cursor.execute( + sql.SQL( + """ + CREATE TABLE IF NOT EXISTS {t} ( + id SMALLINT NOT NULL, + hyperplane_id SMALLINT NOT NULL, + hyperplane FLOAT NOT NULL, + CONSTRAINT {c} PRIMARY KEY (id, hyperplane_id)) + DISTRIBUTE REPLICATE SORT ON (id) + """ + ).format( + t=t, + c=c, + ) + ) + + def _drop_lsh_index_tables(self, cursor: "PgCursor") -> None: + """Drop LSH index tables""" + self.drop( + schema=self._schema, table=self._table + self.LSH_INDEX_TABLE, cursor=cursor + ) + self.drop( + schema=self._schema, + table=self._table + self.LSH_HYPERPLANE_TABLE, + cursor=cursor, + ) + def _update_lsh_hashes( self, cursor: "PgCursor", @@ -786,18 +1209,18 @@ def _generate_tmp_lsh_hashes( query_prefix = sql.SQL("CREATE TEMPORARY TABLE {} ON 
COMMIT DROP AS").format( tmp_hash_table_id ) - group_by = sql.SQL("GROUP BY 1") + group_by = sql.SQL("GROUP BY 1, 2") input_query = sql.SQL( """ {query_prefix} SELECT + e.doc_id, h.id as hash_index, CASE WHEN SUM(e.embedding * h.hyperplane) > 0 THEN 1 ELSE 0 END as hash FROM {embedding_table} e INNER JOIN {hyperplanes} h ON e.embedding_id = h.hyperplane_id {group_by} - DISTRIBUTE REPLICATE """ ).format( query_prefix=query_prefix, @@ -849,125 +1272,689 @@ def _populate_hyperplanes(self, cursor: "PgCursor", num_hyperplanes: int) -> Non ) cursor.execute(insert_query) - def _create_lsh_index_tables(self, cursor: "PgCursor") -> None: - """Create LSH index and hyperplane tables""" + def _create_ivf_index_table( + self, cursor: "PgCursor", index_params: IndexParams + ) -> None: + """Create IVF index and centroid tables""" from psycopg2 import sql schema_prefix = (self._schema,) if self._schema else () - t1 = sql.Identifier(*schema_prefix, self._table + self.LSH_INDEX_TABLE) - t2 = sql.Identifier(*schema_prefix, self._table + self.CONTENT_TABLE) - c1 = sql.Identifier(self._table + self.LSH_INDEX_TABLE + "_pk_doc_id") - c2 = sql.Identifier(self._table + self.LSH_INDEX_TABLE + "_fk_doc_id") - cursor.execute( - sql.SQL( - """ - CREATE TABLE IF NOT EXISTS {t1} ( - doc_id UUID NOT NULL, - hash_index SMALLINT NOT NULL, - hash SMALLINT NOT NULL, - CONSTRAINT {c1} PRIMARY KEY (doc_id, hash_index), - CONSTRAINT {c2} FOREIGN KEY (doc_id) REFERENCES {t2}(doc_id)) - DISTRIBUTE ON (doc_id) SORT ON (doc_id) + index_table = sql.Identifier(*schema_prefix, self._table + self.IVF_INDEX_TABLE) + content_table = sql.Identifier(*schema_prefix, self._table + self.CONTENT_TABLE) + centroid_table = sql.Identifier( + *schema_prefix, self._table + self.IVF_CENTROID_TABLE + ) + c1 = sql.Identifier(self._table + self.IVF_INDEX_TABLE + "_pk_doc_id") + c2 = sql.Identifier(self._table + self.IVF_INDEX_TABLE + "_fk_doc_id") + + quantization = index_params.get_param("quantization", False) + if quantization: + quantization_sql = sql.SQL("embedding SMALLINT NOT NULL") + else: + quantization_sql = sql.SQL("embedding FLOAT NOT NULL") + + index_table_sql = sql.SQL( """ - ).format( - t1=t1, - t2=t2, - c1=c1, - c2=c2, + DROP TABLE IF EXISTS {index_table}; + CREATE TABLE {index_table} ( + id INT NOT NULL, + doc_id UUID NOT NULL, + embedding_id SMALLINT NOT NULL, + {quantization_sql}, + CONSTRAINT {c1} PRIMARY KEY (id, doc_id), + CONSTRAINT {c2} FOREIGN KEY (doc_id) REFERENCES {content_table}(doc_id) ) + DISTRIBUTE ON (doc_id) SORT ON (id) + """ + ).format( + index_table=index_table, + content_table=content_table, + centroid_table=centroid_table, + quantization_sql=quantization_sql, + c1=c1, + c2=c2, ) + cursor.execute(index_table_sql) + + def _drop_ivf_index_table(self, cursor: "PgCursor") -> None: + """Drop IVF index tables""" + self.drop( + schema=self._schema, table=self._table + self.IVF_INDEX_TABLE, cursor=cursor + ) + self.drop( + schema=self._schema, + table=self._table + self.IVF_CENTROID_TABLE, + cursor=cursor, + ) + + def _create_ivf_centroid_tables(self, cursor: "PgCursor") -> None: + from psycopg2 import sql schema_prefix = (self._schema,) if self._schema else () - t = sql.Identifier(*schema_prefix, self._table + self.LSH_HYPERPLANE_TABLE) - c = sql.Identifier(self._table + self.LSH_HYPERPLANE_TABLE + "_pk_id_hp_id") - cursor.execute( - sql.SQL( - """ - CREATE TABLE IF NOT EXISTS {t} ( - id SMALLINT NOT NULL, - hyperplane_id SMALLINT NOT NULL, - hyperplane FLOAT NOT NULL, - CONSTRAINT {c} PRIMARY KEY (id, hyperplane_id)) - 
DISTRIBUTE REPLICATE SORT ON (id) + centroid_table = sql.Identifier( + *schema_prefix, self._table + self.IVF_CENTROID_TABLE + ) + # content_table = sql.Identifier( # Unused variable + # *schema_prefix, self._table + self.CONTENT_TABLE + # ) + c1 = sql.Identifier(self._table + self.IVF_CENTROID_TABLE + "_pk_doc_id") + + centroid_table_sql = sql.SQL( """ - ).format( - t=t, - c=c, + CREATE TABLE IF NOT EXISTS {centroid_table} ( + id INT NOT NULL, + centroid_id SMALLINT NOT NULL, + centroid FLOAT NOT NULL, + CONSTRAINT {c1} PRIMARY KEY (id, centroid_id) ) + DISTRIBUTE REPLICATE + """ + ).format( + centroid_table=centroid_table, + c1=c1, ) + cursor.execute(centroid_table_sql) - def _drop_lsh_index_tables(self, cursor: "PgCursor") -> None: - """Drop LSH index tables""" + new_centroid_table = sql.Identifier( + *schema_prefix, self._table + self.IVF_CENTROID_TABLE + "_new" + ) + c1 = sql.Identifier( + self._table + self.IVF_CENTROID_TABLE + "_new" + "_pk_doc_id" + ) + + new_centroid_table_sql = sql.SQL( + """ + CREATE TABLE IF NOT EXISTS {new_centroid_table} ( + id INT NOT NULL, + centroid_id SMALLINT NOT NULL, + centroid FLOAT NOT NULL, + CONSTRAINT {c1} PRIMARY KEY (id, centroid_id) + ) + DISTRIBUTE REPLICATE + """ + ).format( + new_centroid_table=new_centroid_table, + c1=c1, + ) + cursor.execute(new_centroid_table_sql) + + def _drop_ivf_centroid_tables(self, cursor: "PgCursor") -> None: + """Drop IVF centroid tables""" self.drop( - schema=self._schema, table=self._table + self.LSH_INDEX_TABLE, cursor=cursor + schema=self._schema, + table=self._table + self.IVF_CENTROID_TABLE, + cursor=cursor, ) self.drop( schema=self._schema, - table=self._table + self.LSH_HYPERPLANE_TABLE, + table=self._table + self.IVF_CENTROID_TABLE + "_new", cursor=cursor, ) - def create_index(self, index_params: Yellowbrick.IndexParams) -> None: - """Create index from existing vectors""" - if index_params.index_type == Yellowbrick.IndexType.LSH: - with self.connection.get_cursor() as cursor: - self._drop_lsh_index_tables(cursor) - self._create_lsh_index_tables(cursor) - self._populate_hyperplanes( - cursor, index_params.get_param("num_hyperplanes", 128) - ) - self._update_lsh_hashes(cursor) + def _update_centroids(self, cursor: "PgCursor") -> None: + from psycopg2 import sql - def drop_index(self, index_params: Yellowbrick.IndexParams) -> None: - """Drop an index""" - if index_params.index_type == Yellowbrick.IndexType.LSH: - with self.connection.get_cursor() as cursor: - self._drop_lsh_index_tables(cursor) + schema_prefix = (self._schema,) if self._schema else () + embeddings_table = sql.Identifier(*schema_prefix, self._table) + centroids_table = sql.Identifier( + *schema_prefix, self._table + self.IVF_CENTROID_TABLE + ) + new_centroids_table = sql.Identifier( + *schema_prefix, self._table + self.IVF_CENTROID_TABLE + "_new" + ) - def _update_index( - self, index_params: Yellowbrick.IndexParams, doc_id: uuid.UUID + self._create_ivf_centroid_tables(cursor) + + update_centroid_sql = sql.SQL( + """ + SET enable_rowpacket_compression_in_distribution=TRUE; + WITH CentroidDistances AS ( + SELECT + e.doc_id AS edoc_id, + c.id AS cdoc_id, + SUM(e.embedding * c.centroid) / + (SQRT(SUM(c.centroid * c.centroid)) * + SQRT(SUM(e.embedding * e.embedding))) AS cosine_similarity + FROM {embeddings_table} e + JOIN {centroids_table} c ON e.embedding_id = c.centroid_id + GROUP BY edoc_id, cdoc_id + ), + MaxSimilarities AS ( + SELECT + edoc_id, + cdoc_id, + ROW_NUMBER() OVER ( + PARTITION BY edoc_id ORDER BY cosine_similarity DESC + ) AS 
rank + FROM CentroidDistances + ), + AssignedClusters AS ( + SELECT + edoc_id, + cdoc_id + FROM MaxSimilarities + WHERE rank = 1 + ), + ClusterAverages AS ( + SELECT + ac.cdoc_id AS id, + e.embedding_id AS centroid_id, + AVG(e.embedding) AS centroid + FROM AssignedClusters ac + JOIN {embeddings_table} e ON ac.edoc_id = e.doc_id + GROUP BY ac.cdoc_id, e.embedding_id + ) + INSERT INTO {new_centroids_table} + SELECT + ca.id, + ca.centroid_id, + ca.centroid + FROM ClusterAverages ca + ORDER BY 1 ASC + """ + ).format( + centroids_table=centroids_table, + new_centroids_table=new_centroids_table, + embeddings_table=embeddings_table, + ) + cursor.execute(update_centroid_sql) + + def _centroid_shift(self, cursor: "PgCursor") -> float: + from psycopg2 import sql + + max_shift = float("inf") + schema_prefix = (self._schema,) if self._schema else () + centroids_table = sql.Identifier( + *schema_prefix, self._table + self.IVF_CENTROID_TABLE + ) + centroids_table_noschema = sql.Identifier(self._table + self.IVF_CENTROID_TABLE) + + new_centroids_table = sql.Identifier( + *schema_prefix, self._table + self.IVF_CENTROID_TABLE + "_new" + ) + centroid_shift_sql = sql.SQL( + """ + WITH CentroidPairs AS ( + SELECT + c1.id AS centroid1, + c2.id AS centroid2, + c1.centroid_id AS dim, + c1.centroid, + c2.centroid, + (c1.centroid - c2.centroid) * (c1.centroid - c2.centroid) AS sq_diff + FROM {centroids_table} c1 + JOIN {new_centroids_table} c2 + ON c1.centroid_id = c2.centroid_id AND c1.id = c2.id + ) + SELECT MAX(euclidean_distance) AS max_shift from ( + SELECT + centroid1, + centroid2, + SQRT(SUM(sq_diff)) AS euclidean_distance + FROM CentroidPairs + GROUP BY 1,2 + ) shifts + """ + ).format( + centroids_table=centroids_table, new_centroids_table=new_centroids_table + ) + cursor.execute(centroid_shift_sql) + max_shift = cursor.fetchone()[0] + + c1 = sql.Identifier( + self._table + self.IVF_CENTROID_TABLE + "_new" + "_pk_doc_id" + ) + c2 = sql.Identifier(self._table + self.IVF_CENTROID_TABLE + "_pk_doc_id") + swap_sql = sql.SQL( + """ + DROP TABLE {centroids_table} CASCADE; + ALTER TABLE {new_centroids_table} DROP CONSTRAINT {c1}; + ALTER TABLE {new_centroids_table} RENAME TO {centroids_table_noschema}; + ALTER TABLE {centroids_table} ADD CONSTRAINT {c2} PRIMARY KEY (id); + """ + ).format( + centroids_table=centroids_table, + new_centroids_table=new_centroids_table, + c1=c1, + c2=c2, + centroids_table_noschema=centroids_table_noschema, + ) + cursor.execute(swap_sql) + + return max_shift + + def _populate_centroids(self, cursor: "PgCursor", num_centroids: int) -> None: + from psycopg2 import sql + + schema_prefix = (self._schema,) if self._schema else () + centroids_table = sql.Identifier( + *schema_prefix, self._table + self.IVF_CENTROID_TABLE + ) + + t = sql.Identifier(*schema_prefix, self._table) + cursor.execute(sql.SQL("SELECT MAX(embedding_id) FROM {t}").format(t=t)) + num_dimensions = cursor.fetchone()[0] + num_dimensions += 1 + + centroids_insert_sql = sql.SQL( + """ + WITH parameters AS ( + SELECT {num_centroids} AS num_centroids, + {dims_per_centroid} AS dims_per_centroid + ) + INSERT INTO {centroids_table} (id, centroid_id, centroid) + SELECT id, centroid_id, (random() * 2 - 1) AS centroid + FROM + (SELECT range-1 id FROM sys.rowgenerator + WHERE range BETWEEN 1 AND + (SELECT num_centroids FROM parameters) AND + worker_lid = 0 AND thread_id = 0) a, + (SELECT range-1 centroid_id FROM sys.rowgenerator + WHERE range BETWEEN 1 AND + (SELECT dims_per_centroid FROM parameters) AND + worker_lid = 0 AND 
thread_id = 0) b + ORDER BY 1 ASC + """ + ).format( + num_centroids=sql.Literal(num_centroids), + dims_per_centroid=sql.Literal(num_dimensions), + centroids_table=centroids_table, + ) + cursor.execute(centroids_insert_sql) + + def _k_means( + self, cursor: "PgCursor", num_centroids: int, max_iter: int = 10, threshold: Optional[float] = 1e-4 ) -> None: - """Update an index with a new or modified embedding in the embeddings table""" - if index_params.index_type == Yellowbrick.IndexType.LSH: - with self.connection.get_cursor() as cursor: - self._update_lsh_hashes(cursor, doc_id) + self._populate_centroids(cursor, num_centroids) - def migrate_schema_v1_to_v2(self) -> None: + for _ in range(max_iter): + self._update_centroids(cursor) + max_shift = self._centroid_shift(cursor) + if threshold is not None and max_shift < threshold: + break + + def _update_ivf_index( + self, + cursor: "PgCursor", + index_params: IndexParams, + doc_id: Optional[uuid.UUID] = None, + ) -> None: from psycopg2 import sql - try: - with self.connection.get_cursor() as cursor: - schema_prefix = (self._schema,) if self._schema else () - embeddings = sql.Identifier(*schema_prefix, self._table) - old_embeddings = sql.Identifier(*schema_prefix, self._table + "_v1") - content = sql.Identifier( - *schema_prefix, self._table + self.CONTENT_TABLE - ) - alter_table_query = sql.SQL("ALTER TABLE {t1} RENAME TO {t2}").format( - t1=embeddings, - t2=old_embeddings, - ) - cursor.execute(alter_table_query) + schema_prefix = (self._schema,) if self._schema else () + embeddings_table = sql.Identifier(*schema_prefix, self._table) + ivf_index_table = sql.Identifier( + *schema_prefix, self._table + self.IVF_INDEX_TABLE + ) + centroids_table = sql.Identifier( + *schema_prefix, self._table + self.IVF_CENTROID_TABLE + ) + quantization = index_params.get_param("quantization", False) + if quantization: + quantization_sql = sql.SQL( + "ROUND(((e.embedding + 1) / 2) * 255) as embedding" + ) + else: + quantization_sql = sql.SQL("e.embedding") - self._create_table(cursor) + if doc_id: + # Ensure doc_id is safely composed into the SQL using a Literal + where_sql = sql.SQL("WHERE edoc_id = {}") + where_sql = where_sql.format(sql.Literal(str(doc_id))) + else: + where_sql = sql.SQL("WHERE 1=1") + insert_index_sql = sql.SQL( + """ + SET enable_rowpacket_compression_in_distribution=TRUE; + WITH CentroidDistances AS ( + SELECT + e.doc_id AS edoc_id, + c.id AS cdoc_id, + SUM(e.embedding * c.centroid) / + (SQRT(SUM(c.centroid * c.centroid)) * + SQRT(SUM(e.embedding * e.embedding))) AS cosine_similarity + FROM {embeddings_table} e + JOIN {centroids_table} c ON e.embedding_id = c.centroid_id + {where_sql} + GROUP BY edoc_id, cdoc_id + ), + MaxSimilarities AS ( + SELECT + edoc_id, + cdoc_id, + ROW_NUMBER() OVER ( + PARTITION BY edoc_id ORDER BY cosine_similarity DESC + ) AS rank + FROM CentroidDistances + ) + INSERT INTO {ivf_index_table} + SELECT + cdoc_id, + edoc_id, + e.embedding_id, + {quantization_sql} + FROM MaxSimilarities ms + JOIN {embeddings_table} e ON e.doc_id = ms.edoc_id + WHERE ms.rank = 1 + ORDER BY cdoc_id, edoc_id ASC + """ + ).format( + ivf_index_table=ivf_index_table, + embeddings_table=embeddings_table, + centroids_table=centroids_table, + where_sql=where_sql, + quantization_sql=quantization_sql, + ) + cursor.execute(insert_index_sql) - insert_query = sql.SQL( - """ - INSERT INTO {t1} (doc_id, embedding_id, embedding) - SELECT id, embedding_id, embedding FROM {t2} - """ - ).format( - t1=embeddings, - t2=old_embeddings, - ) - 
cursor.execute(insert_query) + def _find_centroid(self, cursor: "PgCursor", query_embedding_table: str) -> int: + from psycopg2 import sql - insert_content_query = sql.SQL( - """ - INSERT INTO {t1} (doc_id, text, metadata) - SELECT DISTINCT id, text, metadata FROM {t2} - """ - ).format(t1=content, t2=old_embeddings) - cursor.execute(insert_content_query) - except Exception as e: - raise RuntimeError(f"Failed to migrate schema: {e}") from e + schema_prefix = (self._schema,) if self._schema else () + embedding_table = sql.Identifier(query_embedding_table) + ivf_index_table = sql.Identifier( + *schema_prefix, self._table + self.IVF_INDEX_TABLE + ) + centroids_table = sql.Identifier( + *schema_prefix, self._table + self.IVF_CENTROID_TABLE + ) + search_index_sql = sql.SQL( + """ + SET enable_rowpacket_compression_in_distribution=TRUE; + + WITH CentroidDistances AS ( + SELECT + e.doc_id AS edoc_id, + c.id AS cdoc_id, + SUM(e.embedding * c.centroid) / + (SQRT(SUM(c.centroid * c.centroid)) * + SQRT(SUM(e.embedding * e.embedding))) AS cosine_similarity + FROM {embedding_table} e + JOIN {centroids_table} c ON e.embedding_id = c.centroid_id + GROUP BY edoc_id, cdoc_id + ), + MaxSimilarities AS ( + SELECT + edoc_id, + cdoc_id, + ROW_NUMBER() OVER ( + PARTITION BY edoc_id ORDER BY cosine_similarity DESC + ) AS rank + FROM CentroidDistances + ) + SELECT + cdoc_id + FROM MaxSimilarities + WHERE rank = 1 + """ + ).format( + ivf_index_table=ivf_index_table, + embedding_table=embedding_table, + centroids_table=centroids_table, + ) + cursor.execute(search_index_sql) + return cursor.fetchone()[0] + + +def jsonFilter2sqlWhere( + filter_dict: Dict[str, Any], metadata_column: str = "metadata" +) -> str: + """ + Convert Pinecone filter syntax to Yellowbrick SQL WHERE clause using + JSON path syntax. + + Args: + filter_dict: Pinecone-style filter dictionary + metadata_column: Name of the JSONB column containing metadata + (default: "metadata") + + Returns: + SQL WHERE clause string using Yellowbrick JSON path syntax + + Example: + filter = {"genre": {"$eq": "documentary"}, "year": {"$gte": 2020}} + result = jsonFilter2sqlWhere(filter) + # Returns: "(metadata:$.genre::TEXT = 'documentary' AND + # metadata:$.year::INTEGER >= 2020)" + """ + if not filter_dict: + return "1=1" # No filter condition + + return _process_filter_dict(filter_dict, metadata_column) + + +def _process_filter_dict(filter_dict: Dict[str, Any], metadata_column: str) -> str: + """Process a filter dictionary and return SQL WHERE clause. + + Args: + filter_dict: Dictionary containing filter conditions with operators + like $and, $or, or field-specific conditions. + metadata_column: Name of the JSONB column containing metadata. + + Returns: + SQL WHERE clause string with proper parentheses and logical operators. 
+ """ + conditions = [] + + for key, value in filter_dict.items(): + if key == "$and": + and_conditions = [] + for condition in value: + and_conditions.append(_process_filter_dict(condition, metadata_column)) + conditions.append(f"({' AND '.join(and_conditions)})") + + elif key == "$or": + or_conditions = [] + for condition in value: + or_conditions.append(_process_filter_dict(condition, metadata_column)) + conditions.append(f"({' OR '.join(or_conditions)})") + + else: + # Regular field condition + field_condition = _process_field_condition(key, value, metadata_column) + conditions.append(field_condition) + + if len(conditions) == 1: + return conditions[0] + else: + # Multiple conditions at same level are implicitly AND + return f"({' AND '.join(conditions)})" + + +def _process_field_condition( + field_name: str, condition: Any, metadata_column: str +) -> str: + """Process a single field condition. + + Args: + field_name: Name of the metadata field to filter on. + condition: Filter condition which can be a simple value (for equality) + or a dictionary with operators like $eq, $gt, etc. + metadata_column: Name of the JSONB column containing metadata. + + Returns: + SQL condition string for the specified field. + """ + + # Handle simple equality (shorthand syntax) + if not isinstance(condition, dict): + return _create_json_condition(field_name, "$eq", condition, metadata_column) + + # Handle operator-based conditions + conditions = [] + for operator, value in condition.items(): + sql_condition = _create_json_condition( + field_name, operator, value, metadata_column + ) + conditions.append(sql_condition) + + if len(conditions) == 1: + return conditions[0] + else: + # Multiple operators on same field are implicitly AND + return f"({' AND '.join(conditions)})" + + +def _create_json_condition( + field_name: str, operator: str, value: Any, metadata_column: str +) -> str: + """Create a single JSON condition using Yellowbrick JSON path syntax. + + Args: + field_name: Name of the metadata field to filter on. + operator: Filter operator such as $eq, $ne, $gt, $gte, $lt, $lte, $in, $nin, $exists. + value: Value to compare against. Type varies based on operator. + metadata_column: Name of the JSONB column containing metadata. + + Returns: + SQL condition string using Yellowbrick JSON path syntax. + + Raises: + ValueError: If operator is unsupported or value type is incompatible with operator. 
+ """ + + # Escape field name for JSON path if it contains special characters + escaped_field = _escape_json_field_name(field_name) + json_path = f"{metadata_column}:$.{escaped_field}" + + # Determine the cast type for the field + cast_type = ( + f"::{_get_cast_type(value)}" if not isinstance(value, bool) else "::BOOLEAN" + ) + + if operator == "$eq": + return f"{json_path}{cast_type} = {_format_sql_value(value)}" + + elif operator == "$ne": + return f"{json_path}{cast_type} != {_format_sql_value(value)}" + + elif operator == "$gt": + return f"{json_path}{cast_type} > {_format_sql_value(value)}" + + elif operator == "$gte": + return f"{json_path}{cast_type} >= {_format_sql_value(value)}" + + elif operator == "$lt": + return f"{json_path}{cast_type} < {_format_sql_value(value)}" + + elif operator == "$lte": + return f"{json_path}{cast_type} <= {_format_sql_value(value)}" + + elif operator == "$in": + if not isinstance(value, list): + raise ValueError(f"$in operator requires a list, got {type(value)}") + + # Use Yellowbrick's supported IN syntax + formatted_values = ", ".join(_format_sql_value(v) for v in value) + return f"{json_path}{cast_type} IN ({formatted_values})" + + elif operator == "$nin": + if not isinstance(value, list): + raise ValueError(f"$nin operator requires a list, got {type(value)}") + + # For NOT IN operations, convert to AND of != conditions + nin_conditions = [ + f"{json_path}{cast_type} != {_format_sql_value(v)}" for v in value + ] + return f"({' AND '.join(nin_conditions)})" + + elif operator == "$exists": + base_json_path = f"{metadata_column}:$.{escaped_field}" + if value: + return f"{base_json_path} NULL ON ERROR IS NOT NULL" + else: + return f"{base_json_path} NULL ON ERROR IS NULL" + + else: + raise ValueError(f"Unsupported operator: {operator}") + + +def _escape_json_field_name(field_name: str) -> str: + """Escape field names for JSON path expressions in Yellowbrick. + + Uses bracket notation for fields with special characters to ensure + proper JSON path parsing. + + Args: + field_name: The raw field name that may contain special characters. + + Returns: + Escaped field name suitable for use in JSON path expressions. + Uses bracket notation with quotes for fields containing special characters, + or dot notation for simple field names. + """ + # Check if field name contains special characters that need bracket notation + special_chars = [ + ".", + " ", + "'", + '"', + "[", + "]", + ":", + "-", + "+", + "*", + "/", + "\\", + "(", + ")", + "{", + "}", + ] + + if any(char in field_name for char in special_chars): + # Use bracket notation and escape quotes + escaped = field_name.replace("'", "''") + return f"['{escaped}']" + else: + # Use dot notation for simple field names + return field_name + + +def _get_cast_type(value: Any) -> str: + """Determine the appropriate SQL cast type based on Python value type. + + Args: + value: Python value whose type will determine the SQL cast type. + + Returns: + SQL cast type string such as 'INTEGER', 'DOUBLE PRECISION', 'BOOLEAN', or 'TEXT'. + """ + if isinstance(value, int): + return "INTEGER" + elif isinstance(value, float): + return "DOUBLE PRECISION" + elif isinstance(value, bool): + return "BOOLEAN" + elif isinstance(value, str): + return "TEXT" + else: + return "TEXT" # Default to TEXT for other types + + +def _format_sql_value(value: Any) -> str: + """Format a Python value for SQL. + + Args: + value: Python value to be formatted for SQL query inclusion. + Can be None, bool, int, float, str, or other types. 
+ + Returns: + Properly formatted and escaped SQL value string. + None becomes 'NULL', booleans become 'true'/'false', + numbers are converted to strings, and strings are quoted and escaped. + """ + if value is None: + return "NULL" + elif isinstance(value, bool): + return "true" if value else "false" + elif isinstance(value, (int, float)): + return str(value) + elif isinstance(value, str): + # Escape single quotes by doubling them + escaped = value.replace("'", "''") + return f"'{escaped}'" + else: + # For other types, convert to JSON string + return f"'{json.dumps(value)}'" diff --git a/libs/community/tests/integration_tests/vectorstores/test_yellowbrick.py b/libs/community/tests/integration_tests/vectorstores/test_yellowbrick.py index bff446885..1130e8307 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_yellowbrick.py +++ b/libs/community/tests/integration_tests/vectorstores/test_yellowbrick.py @@ -1,3 +1,4 @@ +import json import logging from typing import List, Optional @@ -10,10 +11,11 @@ fake_texts, ) -YELLOWBRICK_URL = "postgres://username:password@host:port/database" -YELLOWBRICK_TABLE = "test_table" -YELLOWBRICK_CONTENT = "test_table_content" -YELLOWBRICK_SCHEMA = "test_schema" +YELLOWBRICK_URL = "postgresql://[USERNAME]:[PASSWORD]@[HOSTNAME]:5432/[DATABASE]" +YELLOWBRICK_SCHEMA = "[SCHEMA]" + +YELLOWBRICK_TABLE = "my_embeddings" +YELLOWBRICK_CONTENT = "my_embeddings_content" def _yellowbrick_vector_from_texts( @@ -245,6 +247,107 @@ def test_yellowbrick_with_score() -> None: assert distances[0] > distances[1] > distances[2] +@pytest.mark.requires("yb-vss") +def test_yellowbrick_ivf_search() -> None: + """Test end to end construction and search for IVF index.""" + docsearches = [ + _yellowbrick_vector_from_texts(), + _yellowbrick_vector_from_texts_no_schema(), + ] + for docsearch in docsearches: + index_params = Yellowbrick.IndexParams( + Yellowbrick.IndexType.IVF, {"num_centroids": 5, "quantization": False} + ) + docsearch.drop_index(index_params) + docsearch.create_index(index_params) + output = docsearch.similarity_search("foo", k=1, index_params=index_params) + assert output == [Document(page_content="foo", metadata={})] + docsearch.drop(table=YELLOWBRICK_TABLE, schema=docsearch._schema) + docsearch.drop(table=YELLOWBRICK_CONTENT, schema=docsearch._schema) + # Ensure index tables are cleaned up + docsearch.drop_index(index_params=index_params) + + +@pytest.mark.requires("yb-vss") +def test_yellowbrick_ivf_search_update() -> None: + """Test end to end construction and search with updates for IVF index.""" + docsearches = [ + _yellowbrick_vector_from_texts(), + _yellowbrick_vector_from_texts_no_schema(), + ] + for docsearch in docsearches: + index_params = Yellowbrick.IndexParams( + Yellowbrick.IndexType.IVF, {"num_centroids": 5, "quantization": False} + ) + docsearch.drop_index(index_params) + docsearch.create_index(index_params) + output = docsearch.similarity_search("foo", k=1, index_params=index_params) + assert output == [Document(page_content="foo", metadata={})] + texts = ["oof"] + docsearch.add_texts(texts, index_params=index_params) + output = docsearch.similarity_search("oof", k=1, index_params=index_params) + assert output == [Document(page_content="oof", metadata={})] + docsearch.drop(table=YELLOWBRICK_TABLE, schema=docsearch._schema) + docsearch.drop(table=YELLOWBRICK_CONTENT, schema=docsearch._schema) + docsearch.drop_index(index_params=index_params) + + +@pytest.mark.requires("yb-vss") +def test_yellowbrick_ivf_delete() -> None: + """Test end 
to end construction and delete for IVF index.""" + docsearches = [ + _yellowbrick_vector_from_texts(), + _yellowbrick_vector_from_texts_no_schema(), + ] + for docsearch in docsearches: + index_params = Yellowbrick.IndexParams( + Yellowbrick.IndexType.IVF, {"num_centroids": 5, "quantization": False} + ) + docsearch.drop_index(index_params) + docsearch.create_index(index_params) + output = docsearch.similarity_search("foo", k=1, index_params=index_params) + assert output == [Document(page_content="foo", metadata={})] + texts = ["oof"] + added_docs = docsearch.add_texts(texts, index_params=index_params) + output = docsearch.similarity_search("oof", k=1, index_params=index_params) + assert output == [Document(page_content="oof", metadata={})] + docsearch.delete(added_docs) + output = docsearch.similarity_search("oof", k=1, index_params=index_params) + assert output != [Document(page_content="oof", metadata={})] + docsearch.drop(table=YELLOWBRICK_TABLE, schema=docsearch._schema) + docsearch.drop(table=YELLOWBRICK_CONTENT, schema=docsearch._schema) + docsearch.drop_index(index_params=index_params) + + +@pytest.mark.requires("yb-vss") +def test_yellowbrick_ivf_delete_all() -> None: + """Test end to end construction and delete_all for IVF index.""" + docsearches = [ + _yellowbrick_vector_from_texts(), + _yellowbrick_vector_from_texts_no_schema(), + ] + for docsearch in docsearches: + index_params = Yellowbrick.IndexParams( + Yellowbrick.IndexType.IVF, {"num_centroids": 5, "quantization": False} + ) + docsearch.drop_index(index_params) + docsearch.create_index(index_params) + output = docsearch.similarity_search("foo", k=1, index_params=index_params) + assert output == [Document(page_content="foo", metadata={})] + texts = ["oof"] + docsearch.add_texts(texts, index_params=index_params) + output = docsearch.similarity_search("oof", k=1, index_params=index_params) + assert output == [Document(page_content="oof", metadata={})] + docsearch.delete(delete_all=True) + output = docsearch.similarity_search("oof", k=1, index_params=index_params) + assert output != [Document(page_content="oof", metadata={})] + output = docsearch.similarity_search("foo", k=1, index_params=index_params) + assert output != [Document(page_content="foo", metadata={})] + docsearch.drop(table=YELLOWBRICK_TABLE, schema=docsearch._schema) + docsearch.drop(table=YELLOWBRICK_CONTENT, schema=docsearch._schema) + docsearch.drop_index(index_params=index_params) + + @pytest.mark.requires("yb-vss") def test_yellowbrick_add_extra() -> None: """Test end to end construction and MRR search.""" @@ -259,3 +362,140 @@ def test_yellowbrick_add_extra() -> None: docsearch.add_texts(texts, metadatas) output = docsearch.similarity_search("foo", k=10) assert len(output) == 6 + + +@pytest.mark.requires("yb-vss") +def test_yellowbrick_add_text_filter() -> None: + """Test adding texts with metadata and filtering via similarity_search filter + argument.""" + docsearches = [ + _yellowbrick_vector_from_texts(), + _yellowbrick_vector_from_texts_no_schema(), + ] + for docsearch in docsearches: + # Add texts with various metadata for testing different filters + texts = [ + "unique-filter-text-1", + "unique-filter-text-2", + "unique-filter-text-3", + "unique-filter-text-4", + "unique-filter-text-5", + ] + metadatas = [ + {"category": "special", "priority": 1, "tags": ["important", "urgent"]}, + {"category": "normal", "priority": 2, "tags": ["important"]}, + {"category": "special", "priority": 3, "active": True}, + {"category": "normal", "priority": 4, "active": 
False}, + {"category": "archived", "tags": ["old"]}, + ] + added_ids = docsearch.add_texts(texts, metadatas) + + # Basic search without filter + output = docsearch.similarity_search("unique-filter-text", k=5) + assert len(output) == 5 + + # Test $eq operator + eq_filter = {"category": {"$eq": "special"}} + output_eq = docsearch.similarity_search( + "unique-filter-text", k=5, filter=json.dumps(eq_filter) + ) + assert len(output_eq) == 2 + assert all(d.page_content.startswith("unique-filter-text") for d in output_eq) + assert all( + i in [1, 3] for i in [int(d.page_content[-1]) for d in output_eq] + ) + + # Test $ne operator + ne_filter = {"category": {"$ne": "special"}} + output_ne = docsearch.similarity_search( + "unique-filter-text", k=5, filter=json.dumps(ne_filter) + ) + assert len(output_ne) == 3 + assert all( + i in [2, 4, 5] for i in [int(d.page_content[-1]) for d in output_ne] + ) + + # Test $gt operator + gt_filter = {"priority": {"$gt": 2}} + output_gt = docsearch.similarity_search( + "unique-filter-text", k=5, filter=json.dumps(gt_filter) + ) + assert len(output_gt) == 2 + assert all( + i in [3, 4] for i in [int(d.page_content[-1]) for d in output_gt] + ) + + # Test $in operator + in_filter = {"category": {"$in": ["special", "archived"]}} + output_in = docsearch.similarity_search( + "unique-filter-text", k=5, filter=json.dumps(in_filter) + ) + assert len(output_in) == 3 + assert all( + i in [1, 3, 5] for i in [int(d.page_content[-1]) for d in output_in] + ) + + # Test $nin operator + nin_filter = {"category": {"$nin": ["normal"]}} + output_nin = docsearch.similarity_search( + "unique-filter-text", k=5, filter=json.dumps(nin_filter) + ) + assert len(output_nin) == 3 + assert all( + i in [1, 3, 5] for i in [int(d.page_content[-1]) for d in output_nin] + ) + + # Test $exists operator + exists_filter = {"active": {"$exists": True}} + output_exists = docsearch.similarity_search( + "unique-filter-text", k=5, filter=json.dumps(exists_filter) + ) + assert len(output_exists) == 2 + assert all( + i in [3, 4] for i in [int(d.page_content[-1]) for d in output_exists] + ) + + # Test $and operator + and_filter = {"$and": [{"category": "special"}, {"priority": {"$lt": 3}}]} + output_and = docsearch.similarity_search( + "unique-filter-text", k=5, filter=json.dumps(and_filter) + ) + assert len(output_and) == 1 + assert output_and[0].page_content == "unique-filter-text-1" + + # Test $or operator + or_filter = {"$or": [{"category": "archived"}, {"priority": 1}]} + output_or = docsearch.similarity_search( + "unique-filter-text", k=5, filter=json.dumps(or_filter) + ) + assert len(output_or) == 2 + assert all( + i in [1, 5] for i in [int(d.page_content[-1]) for d in output_or] + ) + + # Test nested complex filter + complex_filter = { + "$and": [ + {"$or": [{"category": "special"}, {"category": "normal"}]}, + {"$or": [{"priority": {"$lt": 3}}, {"active": True}]} + ] + } + output_complex = docsearch.similarity_search( + "unique-filter-text", k=5, filter=json.dumps(complex_filter) + ) + assert len(output_complex) == 3 + assert all( + i in [1, 2, 3] for i in [int(d.page_content[-1]) for d in output_complex] + ) + + # Test empty filter (should be equivalent to no filter) + empty_filter = {} + output_empty = docsearch.similarity_search( + "unique-filter-text", k=5, filter=json.dumps(empty_filter) + ) + assert len(output_empty) == 5 + + # Clean up + docsearch.delete(added_ids) + docsearch.drop(table=YELLOWBRICK_TABLE, schema=docsearch._schema) + docsearch.drop(table=YELLOWBRICK_CONTENT, 
schema=docsearch._schema) \ No newline at end of file
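

For reference, a minimal end-to-end sketch of the new IVF index and JSON metadata filter introduced by this change, modeled on the integration tests above. It is an illustrative assumption, not part of the patch: the connection string and table name are placeholders, `FakeEmbeddings` stands in for a real embedding model, and the `from_texts` keyword names are assumed to match the existing `Yellowbrick` vectorstore signature.

```python
import json

from langchain_community.embeddings import FakeEmbeddings  # stand-in embedding model
from langchain_community.vectorstores import Yellowbrick

# Placeholder connection details -- substitute real credentials.
CONNECTION_STRING = "postgresql://[USERNAME]:[PASSWORD]@[HOSTNAME]:5432/[DATABASE]"

# Load a few documents with metadata into a (hypothetical) embeddings table.
store = Yellowbrick.from_texts(
    texts=["foo", "bar", "baz"],
    embedding=FakeEmbeddings(size=10),
    metadatas=[
        {"genre": "documentary", "year": 2021},
        {"genre": "drama", "year": 2019},
        {"genre": "documentary", "year": 2018},
    ],
    connection_string=CONNECTION_STRING,
    table="my_embeddings",
)

# Build an IVF index over the existing vectors; parameters mirror the tests.
index_params = Yellowbrick.IndexParams(
    Yellowbrick.IndexType.IVF, {"num_centroids": 5, "quantization": False}
)
store.drop_index(index_params)   # drop any stale index tables first
store.create_index(index_params)

# Pinecone-style filter, serialized to JSON; it is converted server-side into a
# Yellowbrick JSON-path WHERE clause by jsonFilter2sqlWhere().
metadata_filter = {"genre": {"$eq": "documentary"}, "year": {"$gte": 2020}}
docs = store.similarity_search(
    "foo",
    k=3,
    index_params=index_params,
    filter=json.dumps(metadata_filter),
)
for doc in docs:
    print(doc.page_content, doc.metadata)
```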