@@ -116,7 +116,9 @@ def from_completion_usage(cls, usage: CompletionUsage) -> "TokenUsageProps":
116116 prompt_tokens = usage .prompt_tokens ,
117117 completion_tokens = usage .completion_tokens ,
118118 reasoning_tokens = (
119- usage .completion_tokens_details .reasoning_tokens if usage .completion_tokens_details else None
119+ usage .completion_tokens_details .reasoning_tokens
120+ if usage .completion_tokens_details
121+ else None
120122 ),
121123 total_tokens = usage .total_tokens ,
122124 )
@@ -148,7 +150,9 @@ def __init__(
148150 auth_helper : AuthenticationHelper ,
149151 query_language : Optional [str ],
150152 query_speller : Optional [str ],
151- embedding_deployment : Optional [str ], # Not needed for non-Azure OpenAI or for retrieval_mode="text"
153+ embedding_deployment : Optional [
154+ str
155+ ], # Not needed for non-Azure OpenAI or for retrieval_mode="text"
152156 embedding_model : str ,
153157 embedding_dimensions : int ,
154158 embedding_field : str ,
@@ -174,15 +178,23 @@ def __init__(
174178 self .reasoning_effort = reasoning_effort
175179 self .include_token_usage = True
176180
177- def build_filter (self , overrides : dict [str , Any ], auth_claims : dict [str , Any ]) -> Optional [str ]:
181+ def build_filter (
182+ self , overrides : dict [str , Any ], auth_claims : dict [str , Any ]
183+ ) -> Optional [str ]:
178184 include_category = overrides .get ("include_category" )
179185 exclude_category = overrides .get ("exclude_category" )
180- security_filter = self .auth_helper .build_security_filters (overrides , auth_claims )
186+ security_filter = self .auth_helper .build_security_filters (
187+ overrides , auth_claims
188+ )
181189 filters = []
182190 if include_category :
183- filters .append ("category eq '{}'" .format (include_category .replace ("'" , "''" )))
191+ filters .append (
192+ "category eq '{}'" .format (include_category .replace ("'" , "''" ))
193+ )
184194 if exclude_category :
185- filters .append ("category ne '{}'" .format (exclude_category .replace ("'" , "''" )))
195+ filters .append (
196+ "category ne '{}'" .format (exclude_category .replace ("'" , "''" ))
197+ )
186198 if security_filter :
187199 filters .append (security_filter )
188200 return None if len (filters ) == 0 else " and " .join (filters )
@@ -208,7 +220,9 @@ async def search(
208220 search_text = search_text ,
209221 filter = filter ,
210222 top = top ,
211- query_caption = "extractive|highlight-false" if use_semantic_captions else None ,
223+ query_caption = "extractive|highlight-false"
224+ if use_semantic_captions
225+ else None ,
212226 query_rewrites = "generative" if use_query_rewriting else None ,
213227 vector_queries = search_vectors ,
214228 query_type = QueryType .SEMANTIC ,
@@ -237,7 +251,9 @@ async def search(
237251 sourcefile = document .get ("sourcefile" ),
238252 oids = document .get ("oids" ),
239253 groups = document .get ("groups" ),
240- captions = cast (list [QueryCaptionResult ], document .get ("@search.captions" )),
254+ captions = cast (
255+ list [QueryCaptionResult ], document .get ("@search.captions" )
256+ ),
241257 score = document .get ("@search.score" ),
242258 reranker_score = document .get ("@search.reranker_score" ),
243259 )
@@ -270,7 +286,10 @@ async def run_agentic_retrieval(
270286 retrieval_request = KnowledgeAgentRetrievalRequest (
271287 messages = [
272288 KnowledgeAgentMessage (
273- role = str (msg ["role" ]), content = [KnowledgeAgentMessageTextContent (text = str (msg ["content" ]))]
289+ role = str (msg ["role" ]),
290+ content = [
291+ KnowledgeAgentMessageTextContent (text = str (msg ["content" ]))
292+ ],
274293 )
275294 for msg in messages
276295 if msg ["role" ] != "system"
@@ -303,18 +322,25 @@ async def run_agentic_retrieval(
303322 if response and response .references :
304323 if results_merge_strategy == "interleaved" :
305324 # Use interleaved reference order
306- references = sorted (response .references , key = lambda reference : int (reference .id ))
325+ references = sorted (
326+ response .references , key = lambda reference : int (reference .id )
327+ )
307328 else :
308329 # Default to descending strategy
309330 references = response .references
310331 for reference in references :
311- if isinstance (reference , KnowledgeAgentAzureSearchDocReference ) and reference .source_data :
332+ if (
333+ isinstance (reference , KnowledgeAgentAzureSearchDocReference )
334+ and reference .source_data
335+ ):
312336 results .append (
313337 Document (
314338 id = reference .doc_key ,
315339 content = reference .source_data ["content" ],
316340 sourcepage = reference .source_data ["sourcepage" ],
317- search_agent_query = activity_mapping [reference .activity_source ],
341+ search_agent_query = activity_mapping [
342+ reference .activity_source
343+ ],
318344 )
319345 )
320346 if top and len (results ) == top :
@@ -323,22 +349,28 @@ async def run_agentic_retrieval(
323349 return response , results
324350
325351 def get_sources_content (
326- self , results : list [Document ], use_semantic_captions : bool , use_image_citation : bool
352+ self ,
353+ results : list [Document ],
354+ use_semantic_captions : bool ,
355+ use_image_citation : bool ,
327356 ) -> list [str ]:
328-
329357 def nonewlines (s : str ) -> str :
330358 return s .replace ("\n " , " " ).replace ("\r " , " " )
331359
332360 if use_semantic_captions :
333361 return [
334362 (self .get_citation ((doc .sourcepage or "" ), use_image_citation ))
335363 + ": "
336- + nonewlines (" . " .join ([cast (str , c .text ) for c in (doc .captions or [])]))
364+ + nonewlines (
365+ " . " .join ([cast (str , c .text ) for c in (doc .captions or [])])
366+ )
337367 for doc in results
338368 ]
339369 else :
340370 return [
341- (self .get_citation ((doc .sourcepage or "" ), use_image_citation )) + ": " + nonewlines (doc .content or "" )
371+ (self .get_citation ((doc .sourcepage or "" ), use_image_citation ))
372+ + ": "
373+ + nonewlines (doc .content or "" )
342374 for doc in results
343375 ]
344376
@@ -365,21 +397,29 @@ class ExtraArgs(TypedDict, total=False):
365397 dimensions : int
366398
367399 dimensions_args : ExtraArgs = (
368- {"dimensions" : self .embedding_dimensions } if SUPPORTED_DIMENSIONS_MODEL [self .embedding_model ] else {}
400+ {"dimensions" : self .embedding_dimensions }
401+ if SUPPORTED_DIMENSIONS_MODEL [self .embedding_model ]
402+ else {}
369403 )
370404 embedding = await self .openai_client .embeddings .create (
371405 # Azure OpenAI takes the deployment name as the model name
372- model = self .embedding_deployment if self .embedding_deployment else self .embedding_model ,
406+ model = self .embedding_deployment
407+ if self .embedding_deployment
408+ else self .embedding_model ,
373409 input = q ,
374410 ** dimensions_args ,
375411 )
376412 query_vector = embedding .data [0 ].embedding
377413 # This performs an oversampling due to how the search index was setup,
378414 # so we do not need to explicitly pass in an oversampling parameter here
379- return VectorizedQuery (vector = query_vector , k_nearest_neighbors = 50 , fields = self .embedding_field )
415+ return VectorizedQuery (
416+ vector = query_vector , k_nearest_neighbors = 50 , fields = self .embedding_field
417+ )
380418
381419 async def compute_image_embedding (self , q : str ):
382- endpoint = urljoin (self .vision_endpoint , "computervision/retrieval:vectorizeText" )
420+ endpoint = urljoin (
421+ self .vision_endpoint , "computervision/retrieval:vectorizeText"
422+ )
383423 headers = {"Content-Type" : "application/json" }
384424 params = {"api-version" : "2024-02-01" , "model-version" : "2023-04-15" }
385425 data = {"text" : q }
@@ -388,13 +428,21 @@ async def compute_image_embedding(self, q: str):
388428
389429 async with aiohttp .ClientSession () as session :
390430 async with session .post (
391- url = endpoint , params = params , headers = headers , json = data , raise_for_status = True
431+ url = endpoint ,
432+ params = params ,
433+ headers = headers ,
434+ json = data ,
435+ raise_for_status = True ,
392436 ) as response :
393437 json = await response .json ()
394438 image_query_vector = json ["vector" ]
395- return VectorizedQuery (vector = image_query_vector , k_nearest_neighbors = 50 , fields = "imageEmbedding" )
439+ return VectorizedQuery (
440+ vector = image_query_vector , k_nearest_neighbors = 50 , fields = "imageEmbedding"
441+ )
396442
397- def get_system_prompt_variables (self , override_prompt : Optional [str ]) -> dict [str , str ]:
443+ def get_system_prompt_variables (
444+ self , override_prompt : Optional [str ]
445+ ) -> dict [str , str ]:
398446 # Allows client to replace the entire prompt, or to inject into the existing prompt using >>>
399447 if override_prompt is None :
400448 return {}
@@ -433,7 +481,11 @@ def create_chat_completion(
433481 if supported_features .streaming and should_stream :
434482 params ["stream" ] = True
435483 params ["stream_options" ] = {"include_usage" : True }
436- params ["reasoning_effort" ] = reasoning_effort or overrides .get ("reasoning_effort" ) or self .reasoning_effort
484+ params ["reasoning_effort" ] = (
485+ reasoning_effort
486+ or overrides .get ("reasoning_effort" )
487+ or self .reasoning_effort
488+ )
437489
438490 else :
439491 # Include parameters that may not be supported for reasoning models
0 commit comments