Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 101 additions & 22 deletions libs/elasticsearch/poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion libs/elasticsearch/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.9,<4.0"
langchain-core = "^0.3.0"
elasticsearch = {version = "^8.15.1", extras = ["vectorstore_mmr"]}
elasticsearch = {version = "^8.19.1", extras = ["vectorstore_mmr"]}

[tool.poetry.group.test]
optional = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ def read_env() -> Dict:

if cloud_id:
return {"es_cloud_id": cloud_id, "es_api_key": api_key}
return {"es_url": url}

result = {"es_url": url}
if api_key:
result["es_api_key"] = api_key
return result


class AsyncRequestSavingTransport(AsyncTransport):
Expand Down Expand Up @@ -46,7 +50,12 @@ def create_es_client(
**es_kwargs,
)

return AsyncElasticsearch(hosts=[es_params["es_url"]], **es_kwargs)
client_kwargs: Dict[str, Any] = {"hosts": [es_params["es_url"]]}
if "es_api_key" in es_params:
client_kwargs["api_key"] = es_params["es_api_key"]
client_kwargs.update(es_kwargs)

return AsyncElasticsearch(**client_kwargs)


def requests_saving_es_client() -> AsyncElasticsearch:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import os

import pytest
from elasticsearch import AsyncElasticsearch

from langchain_elasticsearch.embeddings import AsyncElasticsearchEmbeddings

from ._test_utilities import model_is_deployed
from ._test_utilities import create_es_client, model_is_deployed

# deployed with
# https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html
Expand All @@ -20,7 +19,7 @@
@pytest.mark.asyncio
async def test_elasticsearch_embedding_documents() -> None:
"""Test Elasticsearch embedding documents."""
client = AsyncElasticsearch(hosts=[ES_URL])
client = create_es_client()
if not (await model_is_deployed(client, MODEL_ID)):
await client.close()
pytest.skip(
Expand All @@ -40,7 +39,7 @@ async def test_elasticsearch_embedding_documents() -> None:
@pytest.mark.asyncio
async def test_elasticsearch_embedding_query() -> None:
"""Test Elasticsearch embedding query."""
client = AsyncElasticsearch(hosts=[ES_URL])
client = create_es_client()
if not (await model_is_deployed(client, MODEL_ID)):
await client.close()
pytest.skip(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Test ElasticsearchRetriever functionality."""

import os
import re
import uuid
from typing import Any, Dict
Expand All @@ -11,7 +10,7 @@

from langchain_elasticsearch.retrievers import AsyncElasticsearchRetriever

from ._test_utilities import requests_saving_es_client
from ._test_utilities import read_env, requests_saving_es_client

"""
cd tests/integration_tests
Expand Down Expand Up @@ -88,13 +87,15 @@ async def test_init_url(self, index_name: str) -> None:
def body_func(query: str) -> Dict:
return {"query": {"match": {text_field: {"query": query}}}}

es_url = os.environ.get("ES_URL", "http://localhost:9200")
cloud_id = os.environ.get("ES_CLOUD_ID")
api_key = os.environ.get("ES_API_KEY")

config = (
{"cloud_id": cloud_id, "api_key": api_key} if cloud_id else {"url": es_url}
)
env_config = read_env()
# Map test utility format to retriever format
config = {}
if "es_url" in env_config:
config["url"] = env_config["es_url"]
if "es_api_key" in env_config:
config["api_key"] = env_config["es_api_key"]
if "es_cloud_id" in env_config:
config["cloud_id"] = env_config["es_cloud_id"]

retriever = AsyncElasticsearchRetriever.from_es_params(
index_name=index_name,
Expand Down
200 changes: 135 additions & 65 deletions libs/elasticsearch/tests/integration_tests/_async/test_vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,37 +594,50 @@ def assert_query(
query_body: Dict[str, Any], query: Optional[str]
) -> Dict[str, Any]:
assert query_body == {
"knn": {
"field": "vector",
"filter": [],
"k": 1,
"num_candidates": 50,
"query_vector": [
0.06,
0.07,
0.01,
0.08,
0.03,
0.07,
0.09,
0.03,
0.09,
0.09,
0.04,
0.03,
0.08,
0.07,
0.06,
0.08,
],
},
"query": {
"bool": {
"filter": [],
"must": [{"match": {"text": {"query": "foo"}}}],
"retriever": {
"rrf": {
"retrievers": [
{
"standard": {
"query": {
"bool": {
"filter": [],
"must": [
{"match": {"text": {"query": "foo"}}}
],
}
}
}
},
{
"knn": {
"field": "vector",
"filter": [],
"k": 1,
"num_candidates": 50,
"query_vector": [
0.06,
0.07,
0.01,
0.08,
0.03,
0.07,
0.09,
0.03,
0.09,
0.09,
0.04,
0.03,
0.08,
0.07,
0.06,
0.08,
],
}
},
]
}
},
"rank": {"rrf": {}},
}
}
return query_body

Expand Down Expand Up @@ -703,43 +716,100 @@ def assert_query(
query: Optional[str],
rrf: Optional[Union[dict, bool]] = True,
) -> dict:
cmp_query_body = {
"knn": {
"field": "vector",
"filter": [],
"k": 3,
"num_candidates": 50,
"query_vector": [
0.06,
0.07,
0.01,
0.08,
0.03,
0.07,
0.09,
0.03,
0.09,
0.09,
0.04,
0.03,
0.08,
0.07,
0.06,
0.08,
],
},
"query": {
"bool": {
if rrf is False:
# When rrf=False, uses old format
cmp_query_body = {
"knn": {
"field": "vector",
"filter": [],
"must": [{"match": {"text": {"query": "foo"}}}],
"k": 3,
"num_candidates": 50,
"query_vector": [
0.06,
0.07,
0.01,
0.08,
0.03,
0.07,
0.09,
0.03,
0.09,
0.09,
0.04,
0.03,
0.08,
0.07,
0.06,
0.08,
],
},
"query": {
"bool": {
"filter": [],
"must": [{"match": {"text": {"query": "foo"}}}],
}
},
}
else:
# When rrf=True or rrf=dict, uses new retriever format
rrf_config = {}
if isinstance(rrf, dict):
rrf_config = rrf

cmp_query_body = {
"retriever": {
"rrf": {
# Dictionary unpacking: spreads rrf_config into dict
# If rrf=True: rrf_config={} adds nothing
# If rrf=dict: rrf_config adds custom RRF parameters
**rrf_config,
"retrievers": [
{
"standard": {
"query": {
"bool": {
"filter": [],
"must": [
{
"match": {
"text": {"query": "foo"}
}
}
],
}
}
}
},
{
"knn": {
"field": "vector",
"filter": [],
"k": 3,
"num_candidates": 50,
"query_vector": [
0.06,
0.07,
0.01,
0.08,
0.03,
0.07,
0.09,
0.03,
0.09,
0.09,
0.04,
0.03,
0.08,
0.07,
0.06,
0.08,
],
}
},
],
}
}
},
}

if isinstance(rrf, dict):
cmp_query_body["rank"] = {"rrf": rrf}
elif isinstance(rrf, bool) and rrf is True:
cmp_query_body["rank"] = {"rrf": {}}
}

assert query_body == cmp_query_body

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ def read_env() -> Dict:

if cloud_id:
return {"es_cloud_id": cloud_id, "es_api_key": api_key}
return {"es_url": url}

result = {"es_url": url}
if api_key:
result["es_api_key"] = api_key
return result


class RequestSavingTransport(Transport):
Expand Down Expand Up @@ -46,7 +50,12 @@ def create_es_client(
**es_kwargs,
)

return Elasticsearch(hosts=[es_params["es_url"]], **es_kwargs)
client_kwargs: Dict[str, Any] = {"hosts": [es_params["es_url"]]}
if "es_api_key" in es_params:
client_kwargs["api_key"] = es_params["es_api_key"]
client_kwargs.update(es_kwargs)

return Elasticsearch(**client_kwargs)


def requests_saving_es_client() -> Elasticsearch:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import os

import pytest
from elasticsearch import Elasticsearch

from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings

from ._test_utilities import model_is_deployed
from ._test_utilities import create_es_client, model_is_deployed

# deployed with
# https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html
Expand All @@ -20,7 +19,7 @@
@pytest.mark.sync
def test_elasticsearch_embedding_documents() -> None:
"""Test Elasticsearch embedding documents."""
client = Elasticsearch(hosts=[ES_URL])
client = create_es_client()
if not (model_is_deployed(client, MODEL_ID)):
client.close()
pytest.skip(
Expand All @@ -40,7 +39,7 @@ def test_elasticsearch_embedding_documents() -> None:
@pytest.mark.sync
def test_elasticsearch_embedding_query() -> None:
"""Test Elasticsearch embedding query."""
client = Elasticsearch(hosts=[ES_URL])
client = create_es_client()
if not (model_is_deployed(client, MODEL_ID)):
client.close()
pytest.skip(
Expand Down
Loading