-
Notifications
You must be signed in to change notification settings - Fork 9
enh: enable insertion of documents using map merge sql approach #69
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -540,6 +540,7 @@ def add_texts( # type: ignore[override] | |
| texts: Iterable[str], | ||
| metadatas: Optional[list[dict]] = None, | ||
| embeddings: Optional[list[list[float]]] = None, | ||
| use_map_merge: bool = False, | ||
| **kwargs: Any, | ||
| ) -> list[str]: | ||
| """Add texts to the vectorstore. | ||
|
|
@@ -558,10 +559,16 @@ def add_texts( # type: ignore[override] | |
| # decide how to add texts | ||
| # using external embedding instance or internal embedding function of HanaDB | ||
| if self.use_internal_embeddings: | ||
| if use_map_merge: | ||
| return self._add_texts_with_map_merge_using_internal_embedding( | ||
| texts, metadatas, embeddings | ||
| ) | ||
| return self._add_texts_using_internal_embedding( | ||
| texts, metadatas, embeddings | ||
| ) | ||
| else: | ||
| if use_map_merge: | ||
| raise ValueError("map merge cannot be used with external embeddings") | ||
| return self._add_texts_using_external_embedding( | ||
| texts, metadatas, embeddings | ||
| ) | ||
|
|
@@ -661,6 +668,140 @@ def _add_texts_using_internal_embedding( | |
| cur.close() | ||
| return [] | ||
|
|
||
| def _add_texts_with_map_merge_using_internal_embedding( | ||
| self, | ||
| texts: Iterable[str], | ||
| metadatas: Optional[list[dict]] = None, | ||
| embeddings: Optional[list[list[float]]] = None, | ||
| **kwargs: Any, | ||
| ) -> list[str]: | ||
| """Add texts with map merge insertion using internal embedding function""" | ||
|
|
||
| cur = self.connection.cursor() | ||
|
|
||
| create_temp_table_sql = f''' | ||
| CREATE TABLE {self.table_name}_TEMP ( | ||
| ID INT PRIMARY KEY, | ||
| "VEC_TEXT" NCLOB, | ||
| "VEC_VECTOR" {self.vector_column_type} | ||
| ) | ||
| ''' | ||
|
|
||
| try: | ||
| cur.execute(create_temp_table_sql) | ||
| except Exception as e: | ||
| raise Exception(f"Error while creating table for map merge :{e}") | ||
|
|
||
| insert_temp_table_sql = f''' | ||
| INSERT INTO {self.table_name}_TEMP (ID, "VEC_TEXT", "VEC_VECTOR") | ||
| VALUES (?,?,NULL) | ||
| ''' | ||
|
|
||
| try: | ||
| cur.executemany(insert_temp_table_sql, [(i, text) for i,text in enumerate(texts)]) | ||
| except Exception as e: | ||
| raise Exception(f"Error while inserting rows for map merge :{e}") | ||
|
|
||
| if self.internal_embedding_remote_source: | ||
| vector_embedding_sql = f"""VECTOR_EMBEDDING(:i_text, 'DOCUMENT', '{self.internal_embedding_model_id}', "{self.internal_embedding_remote_source}")""" | ||
| else: | ||
| vector_embedding_sql = f"""VECTOR_EMBEDDING(:i_text, 'DOCUMENT', '{self.internal_embedding_model_id}')""" | ||
| vector_embedding_sql = self._convert_vector_embedding_to_column_type( | ||
| vector_embedding_sql | ||
| ) | ||
|
|
||
| create_map_merge_function_sql = f""" | ||
| CREATE OR REPLACE FUNCTION F_VECTOR_EMBEDDING( | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Append universal unique id to prevent parallel creation conflicts.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Create suffix -> hash of all params for parallel sessions with diff embedding models) |
||
| IN i_id INT, | ||
| IN i_text NCLOB | ||
| ) | ||
| RETURNS TABLE("ID" INT, "PAL_EMBEDDING" {self.vector_column_type}) | ||
| LANGUAGE SQLSCRIPT READS SQL DATA AS | ||
| BEGIN | ||
| RETURN | ||
| SELECT :i_id AS "ID", | ||
| {vector_embedding_sql} AS "PAL_EMBEDDING" | ||
| FROM DUMMY; | ||
| END; | ||
| """ | ||
|
|
||
| try: | ||
| cur.execute(create_map_merge_function_sql) | ||
| except Exception as e: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here (and below). |
||
| raise Exception(f"Error while creating map merge function :{e}") | ||
|
|
||
| call_map_merge_sql = f""" | ||
| DO() | ||
| BEGIN | ||
| dat = SELECT "ID", "VEC_TEXT", "VEC_VECTOR" FROM "{self.table_name}_TEMP"; | ||
| o_res = MAP_MERGE(:dat, "F_VECTOR_EMBEDDING"(:dat."ID", :dat."VEC_TEXT")); | ||
| MERGE INTO "{self.table_name}_TEMP" AS dat | ||
| USING :o_res AS upd | ||
| ON dat."ID" = upd."ID" | ||
| WHEN MATCHED THEN | ||
| UPDATE SET dat."VEC_VECTOR" = upd."PAL_EMBEDDING"; | ||
| END; | ||
| """ | ||
|
|
||
| try: | ||
| cur.execute(call_map_merge_sql) | ||
| except Exception as e: | ||
| raise Exception(f"Error while calling map merge function :{e}") | ||
|
|
||
|
|
||
| fetch_embeddings_sql = f""" | ||
| SELECT VEC_VECTOR FROM {self.table_name}_TEMP | ||
| """ | ||
| try: | ||
| cur.execute(fetch_embeddings_sql) | ||
| rows = cur.fetchall() | ||
| embeddings = [row[0] for row in rows] | ||
| except Exception as e: | ||
| raise Exception(f"Error while fetching embeddings :{e}") | ||
|
|
||
|
|
||
| try: | ||
| cur.execute(f"DROP FUNCTION F_VECTOR_EMBEDDING") | ||
| cur.execute(f"DROP TABLE {self.table_name}_TEMP") | ||
| except Exception as e: | ||
| raise Exception(f"Error while dropping temp table/function :{e}") | ||
|
|
||
|
|
||
| sql_params = [] | ||
| for i, text in enumerate(texts): | ||
| metadata = metadatas[i] if metadatas else {} | ||
| metadata, extracted_special_metadata = self._split_off_special_metadata( | ||
| metadata | ||
| ) | ||
| sql_params.append( | ||
| ( | ||
| text, | ||
| json.dumps( | ||
| HanaDB._sanitize_metadata_keys(metadata) | ||
| ), | ||
| embeddings[i], | ||
| *extracted_special_metadata | ||
| ) | ||
| ) | ||
|
|
||
| specific_metadata_columns_string = self._get_specific_metadata_columns_string() | ||
|
|
||
| sql_str = ( | ||
| f'INSERT INTO "{self.table_name}" ("{self.content_column}", ' | ||
| f'"{self.metadata_column}", ' | ||
| f'"{self.vector_column}"{specific_metadata_columns_string}) ' | ||
| f"VALUES (?, ?, ? " | ||
| f"{(', ?'* len(self.specific_metadata_columns))});" | ||
| ) | ||
|
|
||
| # Insert data into the table | ||
| cur = self.connection.cursor() | ||
| try: | ||
| cur.executemany(sql_str, sql_params) | ||
| finally: | ||
| cur.close() | ||
| return [] | ||
|
|
||
| def _get_specific_metadata_columns_string(self) -> str: | ||
| """ | ||
| Helper function to generate the specific metadata columns as a SQL string. | ||
|
|
@@ -685,6 +826,7 @@ def from_texts( # type: ignore[no-untyped-def, override] | |
| vector_column: str = default_vector_column, | ||
| vector_column_length: int = default_vector_column_length, | ||
| vector_column_type: str = default_vector_column_type, | ||
| use_map_merge: bool = False, | ||
| *, | ||
| specific_metadata_columns: Optional[list[str]] = None, | ||
| ): | ||
|
|
@@ -708,7 +850,7 @@ def from_texts( # type: ignore[no-untyped-def, override] | |
| vector_column_type=vector_column_type, | ||
| specific_metadata_columns=specific_metadata_columns, | ||
| ) | ||
| instance.add_texts(texts, metadatas) | ||
| instance.add_texts(texts, metadatas, use_map_merge=use_map_merge) | ||
| return instance | ||
|
|
||
| def similarity_search( # type: ignore[override] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -120,6 +120,11 @@ def vectorDB(request): | |
|
|
||
| HanaTestUtils.drop_table(config.conn, HanaTestConstants.TABLE_NAME) | ||
|
|
||
| @pytest.fixture | ||
| def table_name_with_cleanup(): | ||
| yield HanaTestConstants.TABLE_NAME_CUSTOM_DB | ||
| HanaTestUtils.drop_table(config.conn, HanaTestConstants.TABLE_NAME_CUSTOM_DB) | ||
|
|
||
|
|
||
| def test_hanavector_add_texts(vectorDB) -> None: | ||
| vectorDB.add_texts( | ||
|
|
@@ -138,6 +143,72 @@ def test_hanavector_add_texts(vectorDB) -> None: | |
| assert number_of_rows == number_of_texts | ||
|
|
||
|
|
||
| def test_hanavector_add_texts_with_map_merge(vectorDB) -> None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it make sense to parameterise these tests? |
||
| vectorDB.add_texts( | ||
| texts=HanaTestConstants.TEXTS, metadatas=HanaTestConstants.METADATAS, use_map_merge=True | ||
| ) | ||
|
|
||
| # check that embeddings have been created in the table | ||
| number_of_texts = len(HanaTestConstants.TEXTS) | ||
| number_of_rows = -1 | ||
| sql_str = f"SELECT COUNT(*) FROM {HanaTestConstants.TABLE_NAME}" | ||
| cur = config.conn.cursor() | ||
| cur.execute(sql_str) | ||
| if cur.has_result_set(): | ||
| rows = cur.fetchall() | ||
| number_of_rows = rows[0][0] | ||
| assert number_of_rows == number_of_texts | ||
|
|
||
|
|
||
| def test_hanavector_from_texts(table_name_with_cleanup) -> None: | ||
| table_name = table_name_with_cleanup | ||
| vectorDB = HanaDB.from_texts( | ||
| connection=config.conn, | ||
| texts=HanaTestConstants.TEXTS, | ||
| embedding=config.embedding, | ||
| table_name=table_name | ||
| ) | ||
|
|
||
| # test if vectorDB is instance of HanaDB | ||
| assert isinstance(vectorDB, HanaDB) | ||
|
|
||
| # check that embeddings have been created in the table | ||
| number_of_texts = len(HanaTestConstants.TEXTS) | ||
| number_of_rows = -1 | ||
| sql_str = f"SELECT COUNT(*) FROM {table_name}" | ||
| cur = config.conn.cursor() | ||
| cur.execute(sql_str) | ||
| if cur.has_result_set(): | ||
| rows = cur.fetchall() | ||
| number_of_rows = rows[0][0] | ||
| assert number_of_rows == number_of_texts | ||
|
|
||
|
|
||
| def test_hanavector_from_texts_with_map_merge(table_name_with_cleanup) -> None: | ||
| table_name = table_name_with_cleanup | ||
| vectorDB = HanaDB.from_texts( | ||
| connection=config.conn, | ||
| texts=HanaTestConstants.TEXTS, | ||
| embedding=config.embedding, | ||
| table_name=table_name, | ||
| use_map_merge=True | ||
| ) | ||
|
|
||
| # test if vectorDB is instance of HanaDB | ||
| assert isinstance(vectorDB, HanaDB) | ||
|
|
||
| # check that embeddings have been created in the table | ||
| number_of_texts = len(HanaTestConstants.TEXTS) | ||
| number_of_rows = -1 | ||
| sql_str = f"SELECT COUNT(*) FROM {table_name}" | ||
| cur = config.conn.cursor() | ||
| cur.execute(sql_str) | ||
| if cur.has_result_set(): | ||
| rows = cur.fetchall() | ||
| number_of_rows = rows[0][0] | ||
| assert number_of_rows == number_of_texts | ||
|
|
||
|
|
||
| def test_hanavector_similarity_search_with_metadata_filter( | ||
| vectorDB, | ||
| ) -> None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we fail here or an exception occurs, the temp table above isn't cleaned up.