Skip to content

Commit fa6826e

Browse files
authored
Fix sqlalchemy warnings when running tests (#733)
This has been bugging me when running my own tests that call langchain methods :P
1 parent bd0bf4e commit fa6826e

File tree

5 files changed

+6
-7
lines changed

5 files changed

+6
-7
lines changed

langchain/cache.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
from sqlalchemy import Column, Integer, String, create_engine, select
66
from sqlalchemy.engine.base import Engine
7-
from sqlalchemy.ext.declarative import declarative_base
8-
from sqlalchemy.orm import Session
7+
from sqlalchemy.orm import Session, declarative_base
98

109
from langchain.schema import Generation
1110

langchain/sql_database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def run(self, command: str) -> str:
8686
If the statement returns rows, a string of the results is returned.
8787
If the statement returns no rows, an empty string is returned.
8888
"""
89-
with self._engine.connect() as connection:
89+
with self._engine.begin() as connection:
9090
if self._schema is not None:
9191
connection.exec_driver_sql(f"SET search_path TO {self._schema}")
9292
cursor = connection.exec_driver_sql(command)

tests/unit_tests/llms/test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Test base LLM functionality."""
22
from sqlalchemy import Column, Integer, Sequence, String, create_engine
3-
from sqlalchemy.ext.declarative import declarative_base
3+
from sqlalchemy.orm import declarative_base
44

55
import langchain
66
from langchain.cache import InMemoryCache, SQLAlchemyCache

tests/unit_tests/test_sql_database.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_sql_database_run() -> None:
4040
engine = create_engine("sqlite:///:memory:")
4141
metadata_obj.create_all(engine)
4242
stmt = insert(user).values(user_id=13, user_name="Harrison")
43-
with engine.connect() as conn:
43+
with engine.begin() as conn:
4444
conn.execute(stmt)
4545
db = SQLDatabase(engine)
4646
command = "select user_name from user where user_id = 13"
@@ -54,7 +54,7 @@ def test_sql_database_run_update() -> None:
5454
engine = create_engine("sqlite:///:memory:")
5555
metadata_obj.create_all(engine)
5656
stmt = insert(user).values(user_id=13, user_name="Harrison")
57-
with engine.connect() as conn:
57+
with engine.begin() as conn:
5858
conn.execute(stmt)
5959
db = SQLDatabase(engine)
6060
command = "update user set user_name='Updated' where user_id = 13"

tests/unit_tests/test_sql_database_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_sql_database_run() -> None:
5757
engine = create_engine("duckdb:///:memory:")
5858
metadata_obj.create_all(engine)
5959
stmt = insert(user).values(user_id=13, user_name="Harrison")
60-
with engine.connect() as conn:
60+
with engine.begin() as conn:
6161
conn.execute(stmt)
6262
db = SQLDatabase(engine, schema="schema_a")
6363
command = 'select user_name from "user" where user_id = 13'

0 commit comments

Comments
 (0)