Skip to content

Commit 14db57b

Browse files
feat[langchain-db2]: Added get_pks method, added clean_table function, update patch version (#84)
* feat: Added get_pks method, added clean_table function, update patch version * Use TRUNCATE command instead of DELETE
1 parent 6980087 commit 14db57b

File tree

4 files changed

+816
-614
lines changed

4 files changed

+816
-614
lines changed

libs/langchain-db2/langchain_db2/db2vs.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,33 @@ def drop_table(client: Connection, table_name: str) -> None:
150150
return
151151

152152

153+
@_handle_exceptions
154+
def clear_table(client: Connection, table_name: str) -> None:
155+
"""Remove all records from the table using TRUNCATE.
156+
157+
Args:
158+
client: The ibm_db_dbi connection object.
159+
table_name: The name of the table to clear.
160+
"""
161+
if not _table_exists(client, table_name):
162+
logger.info(f"Table {table_name} not found…")
163+
return
164+
165+
cursor = client.cursor()
166+
ddl = f"TRUNCATE TABLE {table_name} IMMEDIATE"
167+
try:
168+
client.commit()
169+
cursor.execute(ddl)
170+
client.commit()
171+
logger.info(f"Table {table_name} cleared successfully.")
172+
except Exception:
173+
client.rollback()
174+
logger.exception(f"Failed to clear table {table_name}. Rolled back.")
175+
raise
176+
finally:
177+
cursor.close()
178+
179+
153180
class DB2VS(VectorStore):
154181
"""`DB2VS` vector store.
155182
@@ -711,3 +738,28 @@ def from_texts(
711738
)
712739
vss.add_texts(texts=list(texts), metadatas=metadatas)
713740
return vss
741+
742+
@_handle_exceptions
743+
def get_pks(self, expr: Optional[str] = None) -> List[str]:
744+
"""Get primary keys, optionally filtered by expr.
745+
746+
Args:
747+
expr: SQL boolean expression to filter rows, e.g.:
748+
"id IN ('ABC123','DEF456')" or "title LIKE 'Abc%'".
749+
If None, returns all rows.
750+
Returns:
751+
List[str]: List of matching primary-key values.
752+
"""
753+
sql = f"SELECT id FROM {self.table_name}"
754+
755+
if expr:
756+
sql += f" WHERE {expr}"
757+
758+
cursor = self.client.cursor()
759+
try:
760+
cursor.execute(sql)
761+
rows = cursor.fetchall()
762+
finally:
763+
cursor.close()
764+
765+
return [row[0] for row in rows]

0 commit comments

Comments
 (0)