11#!/usr/bin/env python3
22
3- import sys
3+ from time import sleep , time
44from typing import List
55
6+ import click
67from elasticsearch import Elasticsearch
78from es_cluster_config import (
89 CLUSTER_URL ,
1314
1415from unstructured .embed .huggingface import HuggingFaceEmbeddingConfig , HuggingFaceEmbeddingEncoder
1516
16- N_ELEMENTS = 1404
17-
1817
1918def embeddings_for_text (text : str ) -> List [float ]:
2019 embedding_encoder = HuggingFaceEmbeddingEncoder (config = HuggingFaceEmbeddingConfig ())
@@ -39,32 +38,96 @@ def query(client: Elasticsearch, search_text: str):
3938 return client .search (index = INDEX_NAME , body = query )
4039
4140
42- if __name__ == "__main__" :
43- print (f"Checking contents of index" f"{ INDEX_NAME } at { CLUSTER_URL } " )
44-
45- print ("Connecting to the Elasticsearch cluster." )
46- client = Elasticsearch (CLUSTER_URL , basic_auth = (USER , PASSWORD ), request_timeout = 30 )
47- print (client .info ())
48-
def validate_count(
    client: Elasticsearch,
    num_elements: int,
    *,
    desired_consistent_count: int = 5,
    timeout: float = 60,
    sleep_interval: float = 1,
):
    """Wait for the index document count to settle, then compare it.

    The count reported by ``client.cat.count`` for ``INDEX_NAME`` is polled
    until the same value has been read ``desired_consistent_count`` times in
    a row; only that settled value is compared with ``num_elements``.

    NOTE(review): presumably the polling exists because freshly written
    documents are not reflected in the count immediately -- confirm.

    Args:
        client: Connected Elasticsearch client.
        num_elements: Number of documents the index is expected to hold.
        desired_consistent_count: How many consecutive identical reads
            (including the initial one) are required before the count is
            considered settled.
        timeout: Maximum number of seconds to wait for the count to settle.
        sleep_interval: Seconds to wait between two consecutive reads.

    Raises:
        TimeoutError: The count did not settle within ``timeout`` seconds.
        AssertionError: The settled count differs from ``num_elements``.
    """

    def current_count() -> int:
        # cat.count with format="json" returns a one-element list of dicts.
        return int(client.cat.count(index=INDEX_NAME, format="json")[0]["count"])

    print(f"Validating that the count of documents in index {INDEX_NAME} is {num_elements}")
    count = current_count()
    print(f"initial count returned: {count}")

    consistent_count = 1  # the initial read is the first observation
    start_time = time()
    while consistent_count < desired_consistent_count:
        if time() - start_time >= timeout:
            raise TimeoutError(f"failed to get consistent count after {timeout}s")
        # Sleep *before* polling so that every pair of compared reads is
        # separated by the interval and no sleep is wasted after the last one.
        sleep(sleep_interval)
        new_count = current_count()
        print(f"latest count returned: {new_count}")
        if new_count == count:
            consistent_count += 1
        else:
            # The value moved: restart the streak from the new value.
            count = new_count
            consistent_count = 1

    # Raised explicitly (not via ``assert``) so the check survives ``python -O``.
    if count != num_elements:
        raise AssertionError(
            f"Elasticsearch dest check failed: got {count} items in index, "
            f"expected {num_elements} items in index."
        )
    print(f"Elasticsearch destination test was successful with {count} items being uploaded.")
5869
70+
def get_embeddings_len(client: Elasticsearch) -> int:
    """Return the length of the ``embeddings`` field of one indexed document.

    A single document is fetched from ``INDEX_NAME`` with a ``match_all``
    query and the size of its stored ``embeddings`` value is reported.
    """
    response = client.search(index=INDEX_NAME, size=1, query={"match_all": {}})
    first_hit = response["hits"]["hits"][0]
    stored_vector = first_hit["_source"]["embeddings"]
    return len(stored_vector)
74+
75+
def validate_embeddings(client: Elasticsearch, embeddings: list[float]):
    """Check that a kNN search for ``embeddings`` returns that same vector first.

    The length of ``embeddings`` is first compared with the length of a
    vector already stored in the index; then a kNN query is run against the
    ``embeddings`` field and the top hit's stored vector must equal the
    queried one exactly.

    Args:
        client: Connected Elasticsearch client.
        embeddings: Query vector expected to be present in ``INDEX_NAME``.

    Raises:
        AssertionError: The vector length does not match the index, the query
            returned no hits, or the top hit carries a different vector.
    """
    print("Testing query to the embedded index.")

    # Checks are raised explicitly (not via ``assert``) so they survive ``python -O``.
    es_embeddings_len = get_embeddings_len(client=client)
    if len(embeddings) != es_embeddings_len:
        raise AssertionError(
            f"length of embeddings ({len(embeddings)}) doesn't "
            f"match what exists in Elasticsearch ({es_embeddings_len})"
        )

    query_string = {
        "field": "embeddings",
        "query_vector": embeddings,
        "k": 10,
        "num_candidates": 10,
    }
    query_response = client.search(index=INDEX_NAME, knn=query_string)

    hits = query_response["hits"]["hits"]
    if not hits:
        # Without this guard an empty result surfaces as a bare IndexError.
        raise AssertionError("kNN query against the embedded index returned no hits")

    response_found = hits[0]["_source"]
    if response_found["embeddings"] != embeddings:
        raise AssertionError(
            "top kNN hit does not contain the embeddings vector that was queried"
        )
    print("Query to the embedded index was successful and returned the expected result.")
95+
96+
def validate(num_elements: int, embeddings: list[float]):
    """Connect to the cluster and run the count and embeddings checks.

    Args:
        num_elements: Number of documents expected in ``INDEX_NAME``.
        embeddings: Vector expected to be retrievable from the index.

    Raises:
        Whatever ``validate_count`` or ``validate_embeddings`` raise when a
        check fails; the client is closed in either case.
    """
    # A single f-string: the previous adjacent literals joined "index" and the
    # index name without a separating space.
    print(f"Checking contents of index {INDEX_NAME} at {CLUSTER_URL}")

    print("Connecting to the Elasticsearch cluster.")
    client = Elasticsearch(CLUSTER_URL, basic_auth=(USER, PASSWORD), request_timeout=30)
    try:
        print(client.info())
        validate_count(client=client, num_elements=num_elements)
        validate_embeddings(client=client, embeddings=embeddings)
    finally:
        # Release the client's connections even when a check fails.
        client.close()
105+
106+
def parse_embeddings(embeddings_str: str) -> list[float]:
    """Parse a comma-separated list of numbers into a list of floats.

    The list may optionally be wrapped in ``[`` / ``]`` and surrounded by
    whitespace (for example a trailing newline from shell substitution).

    Args:
        embeddings_str: Text such as ``"[0.1, 0.2, 0.3]"`` or ``"0.1,0.2"``.

    Returns:
        The parsed values, in input order.

    Raises:
        ValueError: An item is not a valid float literal (this includes an
            empty item, e.g. an empty string or a trailing comma).
    """
    # Strip first so surrounding whitespace cannot hide the brackets.
    inner = embeddings_str.strip().removeprefix("[").removesuffix("]")
    # float() already ignores whitespace around each item.
    return [float(item) for item in inner.split(",")]
115+
116+
# Command-line entry point: parses the --embeddings string into floats and
# hands both options to validate(). Kept as comments rather than a docstring
# so the text click shows for --help is unchanged.
@click.command()
@click.option(
    "--num-elements", type=int, required=True, help="The expected number of elements to exist"
)
@click.option("--embeddings", type=str, required=True, help="List of embeddings to test")
def run_validation(num_elements: int, embeddings: str):
    try:
        vector = parse_embeddings(embeddings_str=embeddings)
    except ValueError as parse_error:
        # Re-raise as TypeError, chaining the original parse failure.
        message = f"failed to parse embeddings string into list of float: {embeddings}"
        raise TypeError(message) from parse_error
    validate(num_elements=num_elements, embeddings=vector)


# Run the click command only when executed as a script.
if __name__ == "__main__":
    run_validation()
0 commit comments