Skip to content

Commit 75682d3

Browse files
authored
Enhancement: initialize DB2VS with connection_args. (#87)
* Enhancement: initialize DB2VS with connection_args.
* Fix format and issues.
* Remove output.
* Improve the code.
* Resolve comments.
* Fixed some problems.
* Change exception type.
1 parent 1d26929 commit 75682d3

File tree

3 files changed

+42
-22
lines changed

3 files changed

+42
-22
lines changed

libs/langchain-db2/langchain_db2/db2vs.py

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -187,19 +187,39 @@ class DB2VS(VectorStore):
187187

188188
def __init__(
189189
self,
190-
client: Connection,
191190
embedding_function: Union[
192191
Callable[[str], List[float]],
193192
Embeddings,
194193
],
195194
table_name: str,
195+
client: Optional[Connection] = None,
196196
distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
197197
query: Optional[str] = "What is a Db2 database",
198198
params: Optional[Dict[str, Any]] = None,
199+
connection_args: Optional[Dict[str, Any]] = None,
199200
):
200-
try:
201+
if client is None:
202+
if connection_args is not None:
203+
database = connection_args.get("database")
204+
host = connection_args.get("host")
205+
port = connection_args.get("port")
206+
username = connection_args.get("username")
207+
password = connection_args.get("password")
208+
209+
conn_str = f"DATABASE={database};hostname={host};port={port};"
210+
f"uid={username};pwd={password};"
211+
212+
if "security" in connection_args:
213+
security = connection_args.get("security")
214+
conn_str += f"security={security};"
215+
216+
self.client = ibm_db_dbi.connect(conn_str, "", "")
217+
else:
218+
raise ValueError("No valid connection or connection_args is passed")
219+
else:
201220
"""Initialize with ibm_db_dbi client."""
202221
self.client = client
222+
try:
203223
"""Initialize with necessary components."""
204224
if not isinstance(embedding_function, Embeddings):
205225
logger.warning(

libs/langchain-db2/tests/integration_tests/test_db2vs.py

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -245,7 +245,7 @@ def test_add_texts_test() -> None:
245245
{"id": "101", "link": "Document Example Test 2"},
246246
]
247247
model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
248-
vs_obj = DB2VS(connection, model, "TB1", DistanceStrategy.EUCLIDEAN_DISTANCE)
248+
vs_obj = DB2VS(model, "TB1", connection, DistanceStrategy.EUCLIDEAN_DISTANCE)
249249
vs_obj.add_texts(texts, metadata)
250250
drop_table(connection, "TB1")
251251

@@ -256,7 +256,7 @@ def test_add_texts_test() -> None:
256256
{"link": "Document Example Test 2"},
257257
]
258258
model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
259-
vs_obj = DB2VS(connection, model, "TB2", DistanceStrategy.EUCLIDEAN_DISTANCE)
259+
vs_obj = DB2VS(model, "TB2", connection, DistanceStrategy.EUCLIDEAN_DISTANCE)
260260
vs_obj.add_texts(texts, metadataNoID)
261261
drop_table(connection, "TB2")
262262

@@ -268,14 +268,14 @@ def test_add_texts_test() -> None:
268268
{"link": "Document Example Test 2"},
269269
{"link": "Document Example Test 3"},
270270
]
271-
vs_obj = DB2VS(connection, model, "TB2", DistanceStrategy.EUCLIDEAN_DISTANCE)
271+
vs_obj = DB2VS(model, "TB2", connection, DistanceStrategy.EUCLIDEAN_DISTANCE)
272272
vs_obj.add_texts(texts1, metadataPartialID)
273273
drop_table(connection, "TB2")
274274

275275
# 3. Add record but neither metadata nor ids are there
276276
# Expectation: Successful, new ID will be generated
277277
model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
278-
vs_obj = DB2VS(connection, model, "TB3", DistanceStrategy.EUCLIDEAN_DISTANCE)
278+
vs_obj = DB2VS(model, "TB3", connection, DistanceStrategy.EUCLIDEAN_DISTANCE)
279279
texts2 = ["Sam", "John"]
280280
vs_obj.add_texts(texts2)
281281
drop_table(connection, "TB3")
@@ -291,17 +291,17 @@ def test_add_texts_test() -> None:
291291
# Successful
292292
# Successful
293293

294-
vs_obj = DB2VS(connection, model, "TB4", DistanceStrategy.EUCLIDEAN_DISTANCE)
294+
vs_obj = DB2VS(model, "TB4", connection, DistanceStrategy.EUCLIDEAN_DISTANCE)
295295
ids4 = ["114", "124"]
296296
vs_obj.add_texts(texts2, ids=ids4)
297297
drop_table(connection, "TB4")
298298

299-
vs_obj = DB2VS(connection, model, "TB5", DistanceStrategy.EUCLIDEAN_DISTANCE)
299+
vs_obj = DB2VS(model, "TB5", connection, DistanceStrategy.EUCLIDEAN_DISTANCE)
300300
ids5 = ["", "134"]
301301
vs_obj.add_texts(texts2, ids=ids5)
302302
drop_table(connection, "TB5")
303303

304-
vs_obj = DB2VS(connection, model, "TB6", DistanceStrategy.EUCLIDEAN_DISTANCE)
304+
vs_obj = DB2VS(model, "TB6", connection, DistanceStrategy.EUCLIDEAN_DISTANCE)
305305
ids6 = [
306306
"""Good afternoon
307307
my friends""",
@@ -310,15 +310,15 @@ def test_add_texts_test() -> None:
310310
vs_obj.add_texts(texts2, ids=ids6)
311311
drop_table(connection, "TB6")
312312

313-
vs_obj = DB2VS(connection, model, "TB7", DistanceStrategy.EUCLIDEAN_DISTANCE)
313+
vs_obj = DB2VS(model, "TB7", connection, DistanceStrategy.EUCLIDEAN_DISTANCE)
314314
ids7 = ['"Good afternoon"', '"India"']
315315
vs_obj.add_texts(texts2, ids=ids7)
316316
drop_table(connection, "TB7")
317317

318318
# 5. Add record with ids option but the id are duplicated
319319
# Expectations: SQL0803N having duplicate values for the index key
320320
try:
321-
vs_obj = DB2VS(connection, model, "TB8", DistanceStrategy.EUCLIDEAN_DISTANCE)
321+
vs_obj = DB2VS(model, "TB8", connection, DistanceStrategy.EUCLIDEAN_DISTANCE)
322322
ids8 = ["118", "118"]
323323
vs_obj.add_texts(texts2, ids=ids8)
324324
drop_table(connection, "TB8")
@@ -327,7 +327,7 @@ def test_add_texts_test() -> None:
327327

328328
# 6. Add records with both ids and metadatas
329329
# Expectation: Successful, the ID will be generated based on ids
330-
vs_obj = DB2VS(connection, model, "TB9", DistanceStrategy.EUCLIDEAN_DISTANCE)
330+
vs_obj = DB2VS(model, "TB9", connection, DistanceStrategy.EUCLIDEAN_DISTANCE)
331331
texts3 = ["Sam 6", "John 6"]
332332
ids9 = ["1", "2"]
333333
metadata = [
@@ -340,7 +340,7 @@ def test_add_texts_test() -> None:
340340
# This one may run slow before using executemany() <<<<<<<<<<
341341
# 7. Add 10000 records
342342
# Expectation:Successful
343-
vs_obj = DB2VS(connection, model, "TB10", DistanceStrategy.EUCLIDEAN_DISTANCE)
343+
vs_obj = DB2VS(model, "TB10", connection, DistanceStrategy.EUCLIDEAN_DISTANCE)
344344
texts4 = ["Sam{0}".format(i) for i in range(1, 10000)]
345345
ids10 = ["Hello{0}".format(i) for i in range(1, 10000)]
346346
vs_obj.add_texts(texts4, ids=ids10)
@@ -352,7 +352,7 @@ def add(val: str) -> None:
352352
model = HuggingFaceEmbeddings(
353353
model_name="sentence-transformers/all-mpnet-base-v2"
354354
)
355-
vs_obj = DB2VS(connection, model, "TB11", DistanceStrategy.EUCLIDEAN_DISTANCE)
355+
vs_obj = DB2VS(model, "TB11", connection, DistanceStrategy.EUCLIDEAN_DISTANCE)
356356
texts5 = [val]
357357
ids11 = texts5
358358
vs_obj.add_texts(texts5, ids=ids11)
@@ -371,7 +371,7 @@ def add1(val: str) -> None:
371371
model = HuggingFaceEmbeddings(
372372
model_name="sentence-transformers/all-mpnet-base-v2"
373373
)
374-
vs_obj = DB2VS(connection, model, "TB12", DistanceStrategy.EUCLIDEAN_DISTANCE)
374+
vs_obj = DB2VS(model, "TB12", connection, DistanceStrategy.EUCLIDEAN_DISTANCE)
375375
texts = [val]
376376
ids12 = texts
377377
vs_obj.add_texts(texts, ids=ids12)
@@ -406,7 +406,7 @@ def test_embed_documents_test() -> None:
406406
# 1. Embed String Example-'Sam'
407407
# Expectation: Successful.
408408
model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
409-
vs_obj = DB2VS(connection, model, "TB7", DistanceStrategy.EUCLIDEAN_DISTANCE)
409+
vs_obj = DB2VS(model, "TB7", connection, DistanceStrategy.EUCLIDEAN_DISTANCE)
410410
vs_obj._embed_documents(
411411
[
412412
"Sam",
@@ -438,7 +438,7 @@ def test_embed_query_test() -> None:
438438
# 1. Embed String
439439
# Expectation: Successful.
440440
model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
441-
vs_obj = DB2VS(connection, model, "TB8", DistanceStrategy.EUCLIDEAN_DISTANCE)
441+
vs_obj = DB2VS(model, "TB8", connection, DistanceStrategy.EUCLIDEAN_DISTANCE)
442442
vs_obj._embed_query("Sam")
443443

444444
# 2. Embed Empty string
@@ -466,9 +466,9 @@ def test_perform_search_test() -> None:
466466
model1 = HuggingFaceEmbeddings(
467467
model_name="sentence-transformers/paraphrase-mpnet-base-v2"
468468
)
469-
vs_1 = DB2VS(connection, model1, "TB10", DistanceStrategy.EUCLIDEAN_DISTANCE)
470-
vs_2 = DB2VS(connection, model1, "TB11", DistanceStrategy.DOT_PRODUCT)
471-
vs_3 = DB2VS(connection, model1, "TB12", DistanceStrategy.COSINE)
469+
vs_1 = DB2VS(model1, "TB10", connection, DistanceStrategy.EUCLIDEAN_DISTANCE)
470+
vs_2 = DB2VS(model1, "TB11", connection, DistanceStrategy.DOT_PRODUCT)
471+
vs_3 = DB2VS(model1, "TB12", connection, DistanceStrategy.COSINE)
472472

473473
# vector store lists:
474474
vs_list = [vs_1, vs_2, vs_3]
@@ -535,7 +535,7 @@ def test_get_pks() -> None:
535535

536536
table_name = f"Unique_table_{int(time.time())}"
537537

538-
db2vs = DB2VS(client=connection, embedding_function=model, table_name=table_name)
538+
db2vs = DB2VS(embedding_function=model, table_name=table_name, client=connection)
539539
pks = db2vs.get_pks()
540540

541541
assert isinstance(pks, list)

libs/langchain-db2/tests/unit_tests/test_db2vs.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ def test_init() -> None:
1010
client = MagicMock()
1111
embedding = DeterministicFakeEmbedding(size=100)
1212
table_name = "foo"
13-
db2vs = DB2VS(client, embedding, table_name)
13+
db2vs = DB2VS(embedding, table_name, client)
1414
assert db2vs is not None
1515
assert isinstance(db2vs, DB2VS)
1616
assert len(client.mock_calls) == 3

0 commit comments

Comments
 (0)