Skip to content

Commit c7d016c

Browse files
committed
linter fix and add vector index tests
1 parent a54e624 commit c7d016c

12 files changed

+73
-36
lines changed

samples/index_tuning_sample/index_search.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,12 @@
3232
vector_table_name,
3333
)
3434
from langchain_google_vertexai import VertexAIEmbeddings
35+
36+
from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore
3537
from langchain_google_alloydb_pg.indexes import (
3638
HNSWIndex,
3739
HNSWQueryOptions,
3840
IVFFlatIndex,
39-
)
40-
41-
from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore
42-
from langchain_google_alloydb_pg.indexes import (
4341
IVFIndex,
4442
ScaNNIndex,
4543
)

src/langchain_google_alloydb_pg/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from langchain_postgres import Column
16+
1517
from .chat_message_history import AlloyDBChatMessageHistory
1618
from .checkpoint import AlloyDBSaver
1719
from .embeddings import AlloyDBEmbeddings
@@ -20,7 +22,6 @@
2022
from .model_manager import AlloyDBModel, AlloyDBModelManager
2123
from .vectorstore import AlloyDBVectorStore
2224
from .version import __version__
23-
from langchain_postgres import Column
2425

2526
__all__ = [
2627
"AlloyDBEngine",

src/langchain_google_alloydb_pg/async_vectorstore.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
from langchain_core.documents import Document
2929
from langchain_core.embeddings import Embeddings
3030
from langchain_core.vectorstores import VectorStore, utils
31+
from sqlalchemy import RowMapping, text
32+
from sqlalchemy.ext.asyncio import AsyncEngine
33+
3134
from langchain_google_alloydb_pg.indexes import (
3235
DEFAULT_DISTANCE_STRATEGY,
3336
DEFAULT_INDEX_NAME_SUFFIX,
@@ -36,8 +39,6 @@
3639
ExactNearestNeighbor,
3740
QueryOptions,
3841
)
39-
from sqlalchemy import RowMapping, text
40-
from sqlalchemy.ext.asyncio import AsyncEngine
4142

4243
from .engine import AlloyDBEngine
4344

src/langchain_google_alloydb_pg/indexes.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,17 @@
1616
from dataclasses import dataclass, field
1717

1818
from langchain_postgres.v2.indexes import (
19-
BaseIndex,
20-
DistanceStrategy,
21-
QueryOptions,
22-
StrategyMixin,
23-
DEFAULT_DISTANCE_STRATEGY,
24-
DEFAULT_INDEX_NAME_SUFFIX,
25-
ExactNearestNeighbor,
26-
HNSWIndex,
27-
HNSWQueryOptions,
28-
IVFFlatIndex,
29-
IVFFlatQueryOptions,
19+
DEFAULT_DISTANCE_STRATEGY,
20+
DEFAULT_INDEX_NAME_SUFFIX,
21+
BaseIndex,
22+
DistanceStrategy,
23+
ExactNearestNeighbor,
24+
HNSWIndex,
25+
HNSWQueryOptions,
26+
IVFFlatIndex,
27+
IVFFlatQueryOptions,
28+
QueryOptions,
29+
StrategyMixin,
3030
)
3131

3232

src/langchain_google_alloydb_pg/vectorstore.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from langchain_core.documents import Document
2121
from langchain_core.embeddings import Embeddings
2222
from langchain_core.vectorstores import VectorStore
23+
2324
from langchain_google_alloydb_pg.indexes import (
2425
DEFAULT_DISTANCE_STRATEGY,
2526
BaseIndex,

tests/test_async_vectorstore_index.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@
2121
import pytest_asyncio
2222
from langchain_core.documents import Document
2323
from langchain_core.embeddings import DeterministicFakeEmbedding
24+
from sqlalchemy import text
25+
26+
from langchain_google_alloydb_pg import AlloyDBEngine
27+
from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore
2428
from langchain_google_alloydb_pg.indexes import (
2529
DEFAULT_INDEX_NAME_SUFFIX,
2630
DistanceStrategy,
2731
HNSWIndex,
2832
IVFFlatIndex,
2933
)
30-
from sqlalchemy import text
31-
32-
from langchain_google_alloydb_pg import AlloyDBEngine
33-
from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore
3434

3535
DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
3636
DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX

tests/test_async_vectorstore_search.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,15 @@
1919
import pytest_asyncio
2020
from langchain_core.documents import Document
2121
from langchain_core.embeddings import DeterministicFakeEmbedding
22-
from langchain_google_alloydb_pg.indexes import (
23-
DistanceStrategy,
24-
HNSWQueryOptions,
25-
)
2622
from metadata_filtering_data import FILTERING_TEST_CASES, METADATAS
2723
from PIL import Image
2824
from sqlalchemy import text
2925

3026
from langchain_google_alloydb_pg import AlloyDBEngine, Column
3127
from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore
3228
from langchain_google_alloydb_pg.indexes import (
29+
DistanceStrategy,
30+
HNSWQueryOptions,
3331
ScaNNQueryOptions,
3432
)
3533

tests/test_indexes.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414

1515
import warnings
1616

17-
from langchain_google_alloydb_pg.indexes import (
17+
from langchain_google_alloydb_pg.indexes import ( # type: ignore
1818
DistanceStrategy,
19+
HNSWIndex,
20+
HNSWQueryOptions,
21+
IVFFlatIndex,
22+
IVFFlatQueryOptions,
1923
IVFIndex,
2024
IVFQueryOptions,
2125
ScaNNIndex,
@@ -44,6 +48,42 @@ def test_distance_strategy(self):
4448
scann_index = ScaNNIndex(distance_strategy=DistanceStrategy.INNER_PRODUCT)
4549
assert scann_index.get_index_function() == "dot_prod"
4650

51+
def test_ivfflat_index(self):
52+
index = IVFFlatIndex(name="test_index", lists=200)
53+
assert index.index_type == "ivfflat"
54+
assert index.lists == 200
55+
assert index.index_options() == "(lists = 200)"
56+
57+
def test_ivfflat_query_options(self):
58+
options = IVFFlatQueryOptions(probes=2)
59+
assert options.to_parameter() == ["ivfflat.probes = 2"]
60+
61+
with warnings.catch_warnings(record=True) as w:
62+
options.to_string()
63+
assert len(w) == 1
64+
assert "to_string is deprecated, use to_parameter instead." in str(
65+
w[-1].message
66+
)
67+
68+
def test_hnsw_index(self):
69+
index = HNSWIndex(name="test_index", m=32, ef_construction=128)
70+
assert index.index_type == "hnsw"
71+
assert index.m == 32
72+
assert index.ef_construction == 128
73+
assert index.index_options() == "(m = 32, ef_construction = 128)"
74+
75+
def test_hnsw_query_options(self):
76+
options = HNSWQueryOptions(ef_search=80)
77+
assert options.to_parameter() == ["hnsw.ef_search = 80"]
78+
79+
with warnings.catch_warnings(record=True) as w:
80+
options.to_string()
81+
82+
assert len(w) == 1
83+
assert "to_string is deprecated, use to_parameter instead." in str(
84+
w[-1].message
85+
)
86+
4787
def test_ivf_index(self):
4888
index = IVFIndex(name="test_index", lists=200)
4989
assert index.index_type == "ivf"

tests/test_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
AlloyDBDocumentSaver,
2626
AlloyDBEngine,
2727
AlloyDBLoader,
28-
Column
28+
Column,
2929
)
3030

3131
project_id = os.environ["PROJECT_ID"]

tests/test_vectorstore_embeddings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,16 @@
1818
import pytest
1919
import pytest_asyncio
2020
from langchain_core.documents import Document
21-
from langchain_google_alloydb_pg.indexes import DistanceStrategy, HNSWQueryOptions
2221
from sqlalchemy import text
2322

2423
from langchain_google_alloydb_pg import (
2524
AlloyDBEmbeddings,
2625
AlloyDBEngine,
2726
AlloyDBModelManager,
2827
AlloyDBVectorStore,
29-
Column
28+
Column,
3029
)
30+
from langchain_google_alloydb_pg.indexes import DistanceStrategy, HNSWQueryOptions
3131

3232
DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
3333
DEFAULT_TABLE_SYNC = "test_table" + str(uuid.uuid4()).replace("-", "_")

0 commit comments

Comments
 (0)