Skip to content

Commit 66d8689

Browse files
authored
feat: add langchain cassandra tool (#440)
* feat: add langchain cassandra tool * add
1 parent 9d28345 commit 66d8689

File tree

7 files changed

+92
-18
lines changed

7 files changed

+92
-18
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import pytest
2+
3+
from e2e_tests.conftest import (
4+
get_vector_store_handler,
5+
)
6+
7+
from e2e_tests.test_utils.vector_store_handler import (
8+
VectorStoreImplementation,
9+
)
10+
11+
12+
@pytest.fixture
13+
def astra_db():
14+
handler = get_vector_store_handler(VectorStoreImplementation.ASTRADB)
15+
context = handler.before_test()
16+
yield context
17+
handler.after_test()
18+
19+
20+
@pytest.fixture
21+
def cassandra():
22+
handler = get_vector_store_handler(VectorStoreImplementation.CASSANDRA)
23+
context = handler.before_test()
24+
yield context
25+
handler.after_test()
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import uuid
2+
3+
import cassio
4+
5+
from langchain.agents import AgentExecutor, create_openai_tools_agent
6+
from langchain import hub
7+
from langchain_community.tools.cassandra_database.tool import (
8+
GetSchemaCassandraDatabaseTool,
9+
GetTableDataCassandraDatabaseTool,
10+
QueryCassandraDatabaseTool,
11+
)
12+
from langchain_community.utilities.cassandra_database import CassandraDatabase
13+
from langchain_openai import ChatOpenAI
14+
15+
16+
def test_tool_with_openai_tool(cassandra):
17+
session = cassio.config.resolve_session()
18+
19+
session.execute(
20+
"""
21+
CREATE TABLE IF NOT EXISTS default_keyspace.tool_table_users (
22+
user_id UUID PRIMARY KEY ,
23+
user_name TEXT ,
24+
password TEXT
25+
);
26+
"""
27+
)
28+
session.execute(
29+
"""
30+
CREATE INDEX user_name
31+
ON default_keyspace.tool_table_users (user_name);
32+
"""
33+
)
34+
35+
user_id = uuid.uuid4()
36+
session.execute(
37+
f"""
38+
INSERT INTO default_keyspace.tool_table_users (user_id, user_name)
39+
VALUES ({user_id}, 'my_user');
40+
"""
41+
)
42+
db = CassandraDatabase()
43+
44+
query_tool = QueryCassandraDatabaseTool(db=db)
45+
schema_tool = GetSchemaCassandraDatabaseTool(db=db)
46+
select_data_tool = GetTableDataCassandraDatabaseTool(db=db)
47+
48+
tools = [schema_tool, select_data_tool, query_tool]
49+
50+
model = ChatOpenAI(model="gpt-4o")
51+
52+
prompt = hub.pull("hwchase17/openai-tools-agent")
53+
54+
agent = create_openai_tools_agent(model, tools, prompt)
55+
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
56+
response = agent_executor.invoke(
57+
{
58+
"input": "What is the user_id of the user named 'my_user' in table default_keyspace.tool_table_users?"
59+
}
60+
)
61+
print(response)
62+
assert response is not None
63+
assert str(user_id) in str(response)

libs/e2e-tests/e2e_tests/langchain/test_unstructured.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,13 @@
1717
from e2e_tests.conftest import (
1818
set_current_test_info,
1919
get_required_env,
20-
get_vector_store_handler,
2120
)
2221

2322
from e2e_tests.test_utils.vector_store_handler import (
24-
VectorStoreImplementation,
2523
VectorStoreTestContext,
2624
)
2725

2826

29-
@pytest.fixture
30-
def astra_db():
31-
handler = get_vector_store_handler(VectorStoreImplementation.ASTRADB)
32-
context = handler.before_test()
33-
yield context
34-
handler.after_test()
35-
36-
37-
@pytest.fixture
38-
def cassandra():
39-
handler = get_vector_store_handler(VectorStoreImplementation.CASSANDRA)
40-
context = handler.before_test()
41-
yield context
42-
handler.after_test()
43-
44-
4527
@pytest.mark.parametrize("vector_store", ["cassandra", "astra_db"])
4628
@pytest.mark.parametrize("unstructured_mode", ["single", "elements"])
4729
def test_unstructured_api(vector_store, unstructured_mode, request):

libs/e2e-tests/pyproject.langchain.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ pillow = "^10.2.0"
1919
python-dotenv = "^1.0.1"
2020
trulens-eval = "^0.21.0"
2121
nemoguardrails = "^0.8.0"
22+
langchainhub = "^0.1.15"
2223

2324
# From LangChain optional deps, needed by WebBaseLoader
2425
beautifulsoup4 = "^4"

libs/e2e-tests/pyproject.llamaindex.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ pillow = "^10.2.0"
1717
python-dotenv = "^1.0.1"
1818
trulens-eval = "^0.21.0"
1919
nemoguardrails = "^0.8.0"
20+
langchainhub = "^0.1.15"
2021

2122
# From LangChain optional deps, needed by WebBaseLoader
2223
beautifulsoup4 = "^4"

libs/e2e-tests/pyproject.ragstack-ai.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pillow = "^10.2.0"
2020
python-dotenv = "^1.0.1"
2121
trulens-eval = "^0.21.0"
2222
nemoguardrails = "^0.8.0"
23+
langchainhub = "^0.1.15"
2324

2425
# From LangChain optional deps, needed by WebBaseLoader
2526
beautifulsoup4 = "^4"

libs/e2e-tests/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ pillow = "^10.2.0"
1919
python-dotenv = "^1.0.1"
2020
trulens-eval = "^0.21.0"
2121
nemoguardrails = "^0.8.0"
22+
langchainhub = "^0.1.15"
2223

2324
# From LangChain optional deps, needed by WebBaseLoader
2425
beautifulsoup4 = "^4"

0 commit comments

Comments
 (0)