|
1 | 1 | import json |
2 | 2 | import os |
3 | 3 | import time |
| 4 | +from warnings import warn |
| 5 | + |
| 6 | +from elasticsearch import ( |
| 7 | + ApiError, |
| 8 | + Elasticsearch, |
| 9 | + NotFoundError, |
| 10 | + BadRequestError, |
| 11 | +) |
| 12 | +from elastic_transport._exceptions import ConnectionTimeout |
4 | 13 |
|
5 | | -from elasticsearch import Elasticsearch, NotFoundError |
6 | 14 | from langchain.docstore.document import Document |
7 | 15 | from langchain.text_splitter import RecursiveCharacterTextSplitter |
8 | 16 | from langchain_elasticsearch import ElasticsearchStore |
|
ELSER_MODEL = os.getenv("ELSER_MODEL", ".elser_model_2")

# Select exactly one auth method for the Elasticsearch client.
# Basic auth takes precedence over an API key; at least one must be set.
if ELASTICSEARCH_USER:
    _auth_kwargs = {"basic_auth": (ELASTICSEARCH_USER, ELASTICSEARCH_PASSWORD)}
elif ELASTICSEARCH_API_KEY:
    _auth_kwargs = {"api_key": ELASTICSEARCH_API_KEY}
else:
    raise ValueError(
        "Please provide either ELASTICSEARCH_USER or ELASTICSEARCH_API_KEY"
    )

es = Elasticsearch(hosts=[ELASTICSEARCH_URL], **_auth_kwargs)
33 | 39 |
|
34 | 40 |
|
def install_elser():
    """Ensure the ELSER model is downloaded and deployed in Elasticsearch.

    Idempotent: safe to call when the model is already installed and/or
    already deployed.
    """
    # Step 1: Ensure ELSER_MODEL is defined (its definition is downloaded)
    try:
        es.ml.get_trained_models(model_id=ELSER_MODEL)
    except NotFoundError:
        print(f'"{ELSER_MODEL}" model not available, downloading it now')
        es.ml.put_trained_model(
            model_id=ELSER_MODEL, input={"field_names": ["text_field"]}
        )
        # Poll until the cluster reports the model definition is complete.
        while True:
            status = es.ml.get_trained_models(
                model_id=ELSER_MODEL, include="definition_status"
            )
            if status["trained_model_configs"][0]["fully_defined"]:
                break
            time.sleep(1)

    # Step 2: Ensure ELSER_MODEL is deployed
    try:
        es.ml.start_trained_model_deployment(
            model_id=ELSER_MODEL, wait_for="fully_allocated"
        )
        print(f'"{ELSER_MODEL}" model is deployed')
    except BadRequestError:
        # This error means the deployment already exists
        pass

    print(f'"{ELSER_MODEL}" model is ready')
57 | 69 |
|
58 | 70 |
|
59 | 71 | def main(): |
@@ -84,19 +96,69 @@ def main(): |
84 | 96 |
|
85 | 97 | print(f"Creating Elasticsearch sparse vector store in {ELASTICSEARCH_URL}") |
86 | 98 |
|
87 | | - elasticsearch_client.indices.delete(index=INDEX, ignore_unavailable=True) |
88 | | - |
89 | | - ElasticsearchStore.from_documents( |
90 | | - docs, |
91 | | - es_connection=elasticsearch_client, |
| 99 | + store = ElasticsearchStore( |
| 100 | + es_connection=es, |
92 | 101 | index_name=INDEX, |
93 | 102 | strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(model_id=ELSER_MODEL), |
94 | | - bulk_kwargs={ |
95 | | - "request_timeout": 60, |
96 | | - }, |
97 | 103 | ) |
98 | 104 |
|
| 105 | + # The first call creates ML tasks to support the index, and typically fails |
| 106 | + # with the default 10-second timeout, at least when Elasticsearch is a |
| 107 | + # container running on Apple Silicon. |
| 108 | + # |
| 109 | + # Once elastic/elasticsearch#107077 is fixed, we can use bulk_kwargs to |
| 110 | + # adjust the timeout. |
| 111 | + try: |
| 112 | + es.indices.delete(index=INDEX, ignore_unavailable=True) |
| 113 | + store.add_documents(list(docs)) |
| 114 | + except BadRequestError: |
| 115 | + # This error means the index already exists |
| 116 | + pass |
| 117 | + except (ConnectionTimeout, ApiError) as e: |
| 118 | + if isinstance(e, ApiError) and e.status_code != 408: |
| 119 | + raise |
| 120 | + warn(f"Error occurred, will retry after ML jobs complete: {e}") |
| 121 | + await_ml_tasks() |
| 122 | + es.indices.delete(index=INDEX, ignore_unavailable=True) |
| 123 | + store.add_documents(list(docs)) |
| 124 | + |
| 125 | + |
def await_ml_tasks(max_timeout=600, interval=5):
    """
    Waits for all machine learning tasks to complete within a specified timeout period.

    Parameters:
        max_timeout (int): Maximum time to wait for tasks to complete, in seconds.
        interval (int): Time to wait between status checks, in seconds.

    Raises:
        TimeoutError: If the timeout is reached and machine learning tasks are still running.
    """
    deadline = time.time() + max_timeout

    tasks = []  # action names of in-flight ML tasks from the last poll
    previous_task_count = 0  # only warn when the count changes
    while time.time() < deadline:
        resp = es.tasks.list(detailed=True, actions=["cluster:monitor/xpack/ml/*"])
        # Flatten the per-node task map into a list of task action names.
        tasks = [
            task_info["action"]
            for node_info in resp["nodes"].values()
            for task_info in node_info.get("tasks", {}).values()
        ]
        if not tasks:
            break
        if len(tasks) != previous_task_count:
            warn(f"Awaiting {len(tasks)} ML tasks")
            previous_task_count = len(tasks)
        time.sleep(interval)

    if tasks:
        raise TimeoutError(
            f"Timeout reached. ML tasks are still running: {', '.join(tasks)}"
        )
| 160 | + |
99 | 161 |
|
# Unless we run through flask, we can miss critical settings or telemetry signals.
# NOTE(review): this entry point now calls main() directly anyway; confirm the
# flask-only settings/telemetry mentioned above are optional for this path.
if __name__ == "__main__":
    main()
0 commit comments