|
1 | | -from langchain_elasticsearch import ElasticsearchStore, SparseVectorStrategy |
| 1 | +from langchain_elasticsearch import ElasticsearchStore, DenseVectorStrategy, SparseVectorStrategy |
2 | 2 | from llm_integrations import get_llm |
3 | 3 | from elasticsearch_client import ( |
4 | 4 | elasticsearch_client, |
|
12 | 12 | INDEX_CHAT_HISTORY = os.getenv( |
13 | 13 | "ES_INDEX_CHAT_HISTORY", "workplace-app-docs-chat-history" |
14 | 14 | ) |
15 | | -ELSER_MODEL = os.getenv("ELSER_MODEL", ".elser_model_2") |
| 15 | +MODEL_ID = os.getenv("ES_MODEL_ID", ".elser_model_2") |
| 16 | +STRATEGY_TYPE = os.getenv("ES_STRATEGY_TYPE", "sparse") |
| 17 | +VECTOR_FIELD = os.getenv("ES_VECTOR_FIELD", "vector") |
| 18 | +QUERY_FIELD = os.getenv("ES_QUERY_FIELD", "text") |
| 19 | + |
16 | 20 | SESSION_ID_TAG = "[SESSION_ID]" |
17 | 21 | SOURCE_TAG = "[SOURCE]" |
18 | 22 | DONE_TAG = "[DONE]" |
19 | 23 |
|
20 | | -store = ElasticsearchStore( |
21 | | - es_connection=elasticsearch_client, |
22 | | - index_name=INDEX, |
23 | | - strategy=SparseVectorStrategy(model_id=ELSER_MODEL), |
24 | | -) |
| 24 | +if STRATEGY_TYPE == "sparse": |
| 25 | + strategy = SparseVectorStrategy(model_id=MODEL_ID) |
| 26 | + store = ElasticsearchStore( |
| 27 | + es_connection=elasticsearch_client, |
| 28 | + index_name=INDEX, |
| 29 | + strategy=strategy, |
| 30 | + ) |
| 31 | +elif STRATEGY_TYPE == "dense": |
| 32 | + strategy = DenseVectorStrategy(model_id=MODEL_ID, hybrid=True) |
| 33 | + store = ElasticsearchStore( |
| 34 | + es_connection=elasticsearch_client, |
| 35 | + index_name=INDEX, |
| 36 | + vector_query_field=VECTOR_FIELD, |
| 37 | + query_field=QUERY_FIELD, |
| 38 | + strategy=strategy, |
| 39 | + ) |
| 40 | +else: |
| 41 | + raise ValueError(f"Invalid strategy type: {STRATEGY_TYPE}") |
25 | 42 |
|
26 | 43 |
|
27 | 44 | @stream_with_context |
@@ -50,9 +67,10 @@ def ask_question(question, session_id): |
50 | 67 | docs = store.as_retriever().invoke(condensed_question) |
51 | 68 | for doc in docs: |
52 | 69 | doc_source = {**doc.metadata, "page_content": doc.page_content} |
53 | | - current_app.logger.debug( |
54 | | - "Retrieved document passage from: %s", doc.metadata["name"] |
55 | | - ) |
| 70 | + if "name" in doc.metadata: |
| 71 | + current_app.logger.debug( |
| 72 | + "Retrieved document passage from: %s", doc.metadata["name"] |
| 73 | + ) |
56 | 74 | yield f"data: {SOURCE_TAG} {json.dumps(doc_source)}\n\n" |
57 | 75 |
|
58 | 76 | qa_prompt = render_template( |
|
0 commit comments