1212from datetime import datetime , timezone
1313
1414
15- async def run_ai_search_query (
16- query ,
17- vector_fields : list [str ],
18- retrieval_fields : list [str ],
19- index_name : str ,
20- semantic_config : str ,
21- top = 5 ,
22- include_scores = False ,
23- minimum_score : float = None ,
24- ):
25- """Run the AI search query."""
26- identity_type = get_identity_type ()
27-
28- async with AsyncAzureOpenAI (
29- # This is the default and can be omitted
30- api_key = os .environ ["OpenAI__ApiKey" ],
31- azure_endpoint = os .environ ["OpenAI__Endpoint" ],
32- api_version = os .environ ["OpenAI__ApiVersion" ],
33- ) as open_ai_client :
34- embeddings = await open_ai_client .embeddings .create (
35- model = os .environ ["OpenAI__EmbeddingModel" ], input = query
36- )
37-
38- # Extract the embedding vector
39- embedding_vector = embeddings .data [0 ].embedding
40-
41- vector_query = VectorizedQuery (
42- vector = embedding_vector ,
43- k_nearest_neighbors = 7 ,
44- fields = "," .join (vector_fields ),
45- )
46-
47- if identity_type == IdentityType .SYSTEM_ASSIGNED :
48- credential = DefaultAzureCredential ()
49- elif identity_type == IdentityType .USER_ASSIGNED :
50- credential = DefaultAzureCredential (
51- managed_identity_client_id = os .environ ["ClientID" ]
52- )
53- else :
54- credential = AzureKeyCredential (
55- os .environ ["AIService__AzureSearchOptions__Key" ]
56- )
57- async with SearchClient (
58- endpoint = os .environ ["AIService__AzureSearchOptions__Endpoint" ],
59- index_name = index_name ,
60- credential = credential ,
61- ) as search_client :
62- results = await search_client .search (
63- top = top ,
64- semantic_configuration_name = semantic_config ,
65- search_text = query ,
66- select = "," .join (retrieval_fields ),
67- vector_queries = [vector_query ],
68- query_type = "semantic" ,
69- query_language = "en-GB" ,
70- )
71-
72- combined_results = []
73-
74- async for result in results .by_page ():
75- async for item in result :
76- if (
77- minimum_score is not None
78- and item ["@search.reranker_score" ] < minimum_score
79- ):
80- continue
81-
82- if include_scores is False :
83- del item ["@search.reranker_score" ]
84- del item ["@search.score" ]
85- del item ["@search.highlights" ]
86- del item ["@search.captions" ]
87-
88- logging .info ("Item: %s" , item )
89- combined_results .append (item )
15+ class AISearchHelper :
16+ @staticmethod
17+ async def run_ai_search_query (
18+ query ,
19+ vector_fields : list [str ],
20+ retrieval_fields : list [str ],
21+ index_name : str ,
22+ semantic_config : str ,
23+ top = 5 ,
24+ include_scores = False ,
25+ minimum_score : float = None ,
26+ ):
27+ """Run the AI search query."""
28+ identity_type = get_identity_type ()
9029
91- logging .info ("Results: %s" , combined_results )
92-
93- return combined_results
94-
95-
96- async def add_entry_to_index (document : dict , vector_fields : dict , index_name : str ):
97- """Add an entry to the search index."""
98-
99- logging .info ("Document: %s" , document )
100- logging .info ("Vector Fields: %s" , vector_fields )
101-
102- for field in vector_fields .keys ():
103- if field not in document .keys ():
104- logging .error (f"Field { field } is not in the document." )
105-
106- identity_type = get_identity_type ()
107-
108- fields_to_embed = {field : document [field ] for field in vector_fields }
109-
110- document ["DateLastModified" ] = datetime .now (timezone .utc )
111-
112- try :
11330 async with AsyncAzureOpenAI (
11431 # This is the default and can be omitted
11532 api_key = os .environ ["OpenAI__ApiKey" ],
11633 azure_endpoint = os .environ ["OpenAI__Endpoint" ],
11734 api_version = os .environ ["OpenAI__ApiVersion" ],
11835 ) as open_ai_client :
11936 embeddings = await open_ai_client .embeddings .create (
120- model = os .environ ["OpenAI__EmbeddingModel" ],
121- input = fields_to_embed .values (),
37+ model = os .environ ["OpenAI__EmbeddingModel" ], input = query
12238 )
12339
12440 # Extract the embedding vector
125- for i , field in enumerate (vector_fields .values ()):
126- document [field ] = embeddings .data [i ].embedding
41+ embedding_vector = embeddings .data [0 ].embedding
12742
128- document ["Id" ] = base64 .urlsafe_b64encode (document ["Question" ].encode ()).decode (
129- "utf-8"
43+ vector_query = VectorizedQuery (
44+ vector = embedding_vector ,
45+ k_nearest_neighbors = 7 ,
46+ fields = "," .join (vector_fields ),
13047 )
13148
13249 if identity_type == IdentityType .SYSTEM_ASSIGNED :
@@ -144,7 +61,92 @@ async def add_entry_to_index(document: dict, vector_fields: dict, index_name: st
14461 index_name = index_name ,
14562 credential = credential ,
14663 ) as search_client :
147- await search_client .upload_documents (documents = [document ])
148- except Exception as e :
149- logging .error ("Failed to add item to index." )
150- logging .error ("Error: %s" , e )
64+ results = await search_client .search (
65+ top = top ,
66+ semantic_configuration_name = semantic_config ,
67+ search_text = query ,
68+ select = "," .join (retrieval_fields ),
69+ vector_queries = [vector_query ],
70+ query_type = "semantic" ,
71+ query_language = "en-GB" ,
72+ )
73+
74+ combined_results = []
75+
76+ async for result in results .by_page ():
77+ async for item in result :
78+ if (
79+ minimum_score is not None
80+ and item ["@search.reranker_score" ] < minimum_score
81+ ):
82+ continue
83+
84+ if include_scores is False :
85+ del item ["@search.reranker_score" ]
86+ del item ["@search.score" ]
87+ del item ["@search.highlights" ]
88+ del item ["@search.captions" ]
89+
90+ logging .info ("Item: %s" , item )
91+ combined_results .append (item )
92+
93+ logging .info ("Results: %s" , combined_results )
94+
95+ return combined_results
96+
97+ @staticmethod
98+ async def add_entry_to_index (document : dict , vector_fields : dict , index_name : str ):
99+ """Add an entry to the search index."""
100+
101+ logging .info ("Document: %s" , document )
102+ logging .info ("Vector Fields: %s" , vector_fields )
103+
104+ for field in vector_fields .keys ():
105+ if field not in document .keys ():
106+ logging .error (f"Field { field } is not in the document." )
107+
108+ identity_type = get_identity_type ()
109+
110+ fields_to_embed = {field : document [field ] for field in vector_fields }
111+
112+ document ["DateLastModified" ] = datetime .now (timezone .utc )
113+
114+ try :
115+ async with AsyncAzureOpenAI (
116+ # This is the default and can be omitted
117+ api_key = os .environ ["OpenAI__ApiKey" ],
118+ azure_endpoint = os .environ ["OpenAI__Endpoint" ],
119+ api_version = os .environ ["OpenAI__ApiVersion" ],
120+ ) as open_ai_client :
121+ embeddings = await open_ai_client .embeddings .create (
122+ model = os .environ ["OpenAI__EmbeddingModel" ],
123+ input = fields_to_embed .values (),
124+ )
125+
126+ # Extract the embedding vector
127+ for i , field in enumerate (vector_fields .values ()):
128+ document [field ] = embeddings .data [i ].embedding
129+
130+ document ["Id" ] = base64 .urlsafe_b64encode (
131+ document ["Question" ].encode ()
132+ ).decode ("utf-8" )
133+
134+ if identity_type == IdentityType .SYSTEM_ASSIGNED :
135+ credential = DefaultAzureCredential ()
136+ elif identity_type == IdentityType .USER_ASSIGNED :
137+ credential = DefaultAzureCredential (
138+ managed_identity_client_id = os .environ ["ClientID" ]
139+ )
140+ else :
141+ credential = AzureKeyCredential (
142+ os .environ ["AIService__AzureSearchOptions__Key" ]
143+ )
144+ async with SearchClient (
145+ endpoint = os .environ ["AIService__AzureSearchOptions__Endpoint" ],
146+ index_name = index_name ,
147+ credential = credential ,
148+ ) as search_client :
149+ await search_client .upload_documents (documents = [document ])
150+ except Exception as e :
151+ logging .error ("Failed to add item to index." )
152+ logging .error ("Error: %s" , e )
0 commit comments