
Commit 6c1f9e4

Refactor code for improved readability and maintainability
- Enhanced logging messages for better clarity in various modules.
- Reformatted code to adhere to PEP 8 style guidelines, including line breaks and indentation.
- Updated constructor definitions and method calls to improve readability.
- Adjusted list comprehensions and generator expressions for better clarity.
- Ensured consistent formatting across multiple files, including spacing and line lengths.
- Modified the `pyproject.toml` file to correct source paths for linting and testing.
1 parent 0603f84 commit 6c1f9e4

28 files changed: +1275 -432 lines
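Nearly every hunk below applies the same mechanical transformation: a long expression is wrapped in parentheses so each clause lands on its own line within the length limit. A minimal before/after sketch of the pattern (illustrative names only, not code from this repo):

# Illustrative sketch of the wrapping style applied throughout this commit.
details = None  # stand-in for an optional usage-details object

# Before: one long conditional expression on a single line.
reasoning = details.reasoning_tokens if details else None

# After: the same expression, parenthesized with one clause per line.
reasoning = (
    details.reasoning_tokens
    if details
    else None
)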

app/hrchatbot/backend/app.py

Lines changed: 245 additions & 94 deletions
Large diffs are not rendered by default.

app/hrchatbot/backend/approaches/approach.py

Lines changed: 76 additions & 24 deletions
@@ -116,7 +116,9 @@ def from_completion_usage(cls, usage: CompletionUsage) -> "TokenUsageProps":
             prompt_tokens=usage.prompt_tokens,
             completion_tokens=usage.completion_tokens,
             reasoning_tokens=(
-                usage.completion_tokens_details.reasoning_tokens if usage.completion_tokens_details else None
+                usage.completion_tokens_details.reasoning_tokens
+                if usage.completion_tokens_details
+                else None
             ),
             total_tokens=usage.total_tokens,
         )
@@ -148,7 +150,9 @@ def __init__(
         auth_helper: AuthenticationHelper,
         query_language: Optional[str],
         query_speller: Optional[str],
-        embedding_deployment: Optional[str],  # Not needed for non-Azure OpenAI or for retrieval_mode="text"
+        embedding_deployment: Optional[
+            str
+        ],  # Not needed for non-Azure OpenAI or for retrieval_mode="text"
         embedding_model: str,
         embedding_dimensions: int,
         embedding_field: str,
@@ -174,15 +178,23 @@ def __init__(
         self.reasoning_effort = reasoning_effort
         self.include_token_usage = True

-    def build_filter(self, overrides: dict[str, Any], auth_claims: dict[str, Any]) -> Optional[str]:
+    def build_filter(
+        self, overrides: dict[str, Any], auth_claims: dict[str, Any]
+    ) -> Optional[str]:
         include_category = overrides.get("include_category")
         exclude_category = overrides.get("exclude_category")
-        security_filter = self.auth_helper.build_security_filters(overrides, auth_claims)
+        security_filter = self.auth_helper.build_security_filters(
+            overrides, auth_claims
+        )
         filters = []
         if include_category:
-            filters.append("category eq '{}'".format(include_category.replace("'", "''")))
+            filters.append(
+                "category eq '{}'".format(include_category.replace("'", "''"))
+            )
         if exclude_category:
-            filters.append("category ne '{}'".format(exclude_category.replace("'", "''")))
+            filters.append(
+                "category ne '{}'".format(exclude_category.replace("'", "''"))
+            )
         if security_filter:
             filters.append(security_filter)
         return None if len(filters) == 0 else " and ".join(filters)
@@ -208,7 +220,9 @@ async def search(
             search_text=search_text,
             filter=filter,
             top=top,
-            query_caption="extractive|highlight-false" if use_semantic_captions else None,
+            query_caption="extractive|highlight-false"
+            if use_semantic_captions
+            else None,
             query_rewrites="generative" if use_query_rewriting else None,
             vector_queries=search_vectors,
             query_type=QueryType.SEMANTIC,
@@ -237,7 +251,9 @@ async def search(
                 sourcefile=document.get("sourcefile"),
                 oids=document.get("oids"),
                 groups=document.get("groups"),
-                captions=cast(list[QueryCaptionResult], document.get("@search.captions")),
+                captions=cast(
+                    list[QueryCaptionResult], document.get("@search.captions")
+                ),
                 score=document.get("@search.score"),
                 reranker_score=document.get("@search.reranker_score"),
             )
@@ -270,7 +286,10 @@ async def run_agentic_retrieval(
            retrieval_request=KnowledgeAgentRetrievalRequest(
                messages=[
                    KnowledgeAgentMessage(
-                        role=str(msg["role"]), content=[KnowledgeAgentMessageTextContent(text=str(msg["content"]))]
+                        role=str(msg["role"]),
+                        content=[
+                            KnowledgeAgentMessageTextContent(text=str(msg["content"]))
+                        ],
                    )
                    for msg in messages
                    if msg["role"] != "system"
@@ -303,18 +322,25 @@ async def run_agentic_retrieval(
         if response and response.references:
             if results_merge_strategy == "interleaved":
                 # Use interleaved reference order
-                references = sorted(response.references, key=lambda reference: int(reference.id))
+                references = sorted(
+                    response.references, key=lambda reference: int(reference.id)
+                )
             else:
                 # Default to descending strategy
                 references = response.references
             for reference in references:
-                if isinstance(reference, KnowledgeAgentAzureSearchDocReference) and reference.source_data:
+                if (
+                    isinstance(reference, KnowledgeAgentAzureSearchDocReference)
+                    and reference.source_data
+                ):
                     results.append(
                         Document(
                             id=reference.doc_key,
                             content=reference.source_data["content"],
                             sourcepage=reference.source_data["sourcepage"],
-                            search_agent_query=activity_mapping[reference.activity_source],
+                            search_agent_query=activity_mapping[
+                                reference.activity_source
+                            ],
                         )
                     )
                     if top and len(results) == top:
@@ -323,22 +349,28 @@ async def run_agentic_retrieval(
         return response, results

     def get_sources_content(
-        self, results: list[Document], use_semantic_captions: bool, use_image_citation: bool
+        self,
+        results: list[Document],
+        use_semantic_captions: bool,
+        use_image_citation: bool,
     ) -> list[str]:
-
         def nonewlines(s: str) -> str:
             return s.replace("\n", " ").replace("\r", " ")

         if use_semantic_captions:
             return [
                 (self.get_citation((doc.sourcepage or ""), use_image_citation))
                 + ": "
-                + nonewlines(" . ".join([cast(str, c.text) for c in (doc.captions or [])]))
+                + nonewlines(
+                    " . ".join([cast(str, c.text) for c in (doc.captions or [])])
+                )
                 for doc in results
             ]
         else:
             return [
-                (self.get_citation((doc.sourcepage or ""), use_image_citation)) + ": " + nonewlines(doc.content or "")
+                (self.get_citation((doc.sourcepage or ""), use_image_citation))
+                + ": "
+                + nonewlines(doc.content or "")
                 for doc in results
             ]

@@ -365,21 +397,29 @@ class ExtraArgs(TypedDict, total=False):
             dimensions: int

         dimensions_args: ExtraArgs = (
-            {"dimensions": self.embedding_dimensions} if SUPPORTED_DIMENSIONS_MODEL[self.embedding_model] else {}
+            {"dimensions": self.embedding_dimensions}
+            if SUPPORTED_DIMENSIONS_MODEL[self.embedding_model]
+            else {}
         )
         embedding = await self.openai_client.embeddings.create(
             # Azure OpenAI takes the deployment name as the model name
-            model=self.embedding_deployment if self.embedding_deployment else self.embedding_model,
+            model=self.embedding_deployment
+            if self.embedding_deployment
+            else self.embedding_model,
             input=q,
             **dimensions_args,
         )
         query_vector = embedding.data[0].embedding
         # This performs an oversampling due to how the search index was setup,
         # so we do not need to explicitly pass in an oversampling parameter here
-        return VectorizedQuery(vector=query_vector, k_nearest_neighbors=50, fields=self.embedding_field)
+        return VectorizedQuery(
+            vector=query_vector, k_nearest_neighbors=50, fields=self.embedding_field
+        )

     async def compute_image_embedding(self, q: str):
-        endpoint = urljoin(self.vision_endpoint, "computervision/retrieval:vectorizeText")
+        endpoint = urljoin(
+            self.vision_endpoint, "computervision/retrieval:vectorizeText"
+        )
         headers = {"Content-Type": "application/json"}
         params = {"api-version": "2024-02-01", "model-version": "2023-04-15"}
         data = {"text": q}
@@ -388,13 +428,21 @@ async def compute_image_embedding(self, q: str):

         async with aiohttp.ClientSession() as session:
             async with session.post(
-                url=endpoint, params=params, headers=headers, json=data, raise_for_status=True
+                url=endpoint,
+                params=params,
+                headers=headers,
+                json=data,
+                raise_for_status=True,
             ) as response:
                 json = await response.json()
                 image_query_vector = json["vector"]
-        return VectorizedQuery(vector=image_query_vector, k_nearest_neighbors=50, fields="imageEmbedding")
+        return VectorizedQuery(
+            vector=image_query_vector, k_nearest_neighbors=50, fields="imageEmbedding"
+        )

-    def get_system_prompt_variables(self, override_prompt: Optional[str]) -> dict[str, str]:
+    def get_system_prompt_variables(
+        self, override_prompt: Optional[str]
+    ) -> dict[str, str]:
         # Allows client to replace the entire prompt, or to inject into the existing prompt using >>>
         if override_prompt is None:
             return {}
@@ -433,7 +481,11 @@ def create_chat_completion(
             if supported_features.streaming and should_stream:
                 params["stream"] = True
                 params["stream_options"] = {"include_usage": True}
-            params["reasoning_effort"] = reasoning_effort or overrides.get("reasoning_effort") or self.reasoning_effort
+            params["reasoning_effort"] = (
+                reasoning_effort
+                or overrides.get("reasoning_effort")
+                or self.reasoning_effort
+            )

         else:
             # Include parameters that may not be supported for reasoning models
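
The quote doubling in build_filter above is the standard OData escape for single quotes inside string literals. A runnable sketch of the same construction outside the class (build_category_filter is an illustrative name, not part of the repo):

# Sketch of the OData filter pattern used by Approach.build_filter.
# Single quotes inside a category value are escaped by doubling them.
def build_category_filter(include_category=None, exclude_category=None, security_filter=None):
    filters = []
    if include_category:
        filters.append("category eq '{}'".format(include_category.replace("'", "''")))
    if exclude_category:
        filters.append("category ne '{}'".format(exclude_category.replace("'", "''")))
    if security_filter:
        filters.append(security_filter)
    return None if len(filters) == 0 else " and ".join(filters)

print(build_category_filter(include_category="HR 'benefits'"))
# category eq 'HR ''benefits'''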

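The comment in get_system_prompt_variables describes a small client protocol: an override beginning with ">>>" is injected into the existing prompt rather than replacing it. A hedged sketch of that dispatch, assuming only what the comment states (the returned key names are assumptions, not confirmed API):

from typing import Optional

# Assumed behavior per the ">>>" comment above: a leading ">>>" injects
# into the base prompt; anything else replaces the whole prompt.
# The dictionary key names here are illustrative.
def prompt_variables(override_prompt: Optional[str]) -> dict[str, str]:
    if override_prompt is None:
        return {}  # keep the default prompt untouched
    if override_prompt.startswith(">>>"):
        return {"injected_prompt": override_prompt[3:]}
    return {"override_prompt": override_prompt}

print(prompt_variables(">>>Answer in formal English."))
# {'injected_prompt': 'Answer in formal English.'}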
app/hrchatbot/backend/approaches/chatapproach.py

Lines changed: 40 additions & 11 deletions
@@ -18,13 +18,15 @@


 class ChatApproach(Approach, ABC):
-
     NO_RESPONSE = "0"

     @abstractmethod
     async def run_until_final_call(
         self, messages, overrides, auth_claims, should_stream
-    ) -> tuple[ExtraInfo, Union[Awaitable[ChatCompletion], Awaitable[AsyncStream[ChatCompletionChunk]]]]:
+    ) -> tuple[
+        ExtraInfo,
+        Union[Awaitable[ChatCompletion], Awaitable[AsyncStream[ChatCompletionChunk]]],
+    ]:
         pass

     def get_search_query(self, chat_completion: ChatCompletion, user_query: str):
@@ -60,14 +62,20 @@ async def run_without_streaming(
         extra_info, chat_coroutine = await self.run_until_final_call(
             messages, overrides, auth_claims, should_stream=False
         )
-        chat_completion_response: ChatCompletion = await cast(Awaitable[ChatCompletion], chat_coroutine)
+        chat_completion_response: ChatCompletion = await cast(
+            Awaitable[ChatCompletion], chat_coroutine
+        )
         content = chat_completion_response.choices[0].message.content
         role = chat_completion_response.choices[0].message.role
         if overrides.get("suggest_followup_questions"):
             content, followup_questions = self.extract_followup_questions(content)
             extra_info.followup_questions = followup_questions
         # Assume last thought is for generating answer
-        if self.include_token_usage and extra_info.thoughts and chat_completion_response.usage:
+        if (
+            self.include_token_usage
+            and extra_info.thoughts
+            and chat_completion_response.usage
+        ):
             extra_info.thoughts[-1].update_token_usage(chat_completion_response.usage)
         chat_app_response = {
             "message": {"content": content, "role": role},
@@ -86,8 +94,14 @@ async def run_with_streaming(
         extra_info, chat_coroutine = await self.run_until_final_call(
             messages, overrides, auth_claims, should_stream=True
         )
-        chat_coroutine = cast(Awaitable[AsyncStream[ChatCompletionChunk]], chat_coroutine)
-        yield {"delta": {"role": "assistant"}, "context": extra_info, "session_state": session_state}
+        chat_coroutine = cast(
+            Awaitable[AsyncStream[ChatCompletionChunk]], chat_coroutine
+        )
+        yield {
+            "delta": {"role": "assistant"},
+            "context": extra_info,
+            "session_state": session_state,
+        }

         followup_questions_started = False
         followup_content = ""
@@ -104,7 +118,9 @@ async def run_with_streaming(
                 }
                 # if event contains << and not >>, it is start of follow-up question, truncate
                 content = completion["delta"].get("content")
-                content = content or ""  # content may either not exist in delta, or explicitly be None
+                content = (
+                    content or ""
+                )  # content may either not exist in delta, or explicitly be None
                 if overrides.get("suggest_followup_questions") and "<<" in content:
                     followup_questions_started = True
                     earlier_content = content[: content.index("<<")]
@@ -119,15 +135,26 @@ async def run_with_streaming(
             else:
                 # Final chunk at end of streaming should contain usage
                 # https://cookbook.openai.com/examples/how_to_stream_completions#4-how-to-get-token-usage-data-for-streamed-chat-completion-response
-                if event_chunk.usage and extra_info.thoughts and self.include_token_usage:
+                if (
+                    event_chunk.usage
+                    and extra_info.thoughts
+                    and self.include_token_usage
+                ):
                     extra_info.thoughts[-1].update_token_usage(event_chunk.usage)
-                    yield {"delta": {"role": "assistant"}, "context": extra_info, "session_state": session_state}
+                    yield {
+                        "delta": {"role": "assistant"},
+                        "context": extra_info,
+                        "session_state": session_state,
+                    }

         if followup_content:
             _, followup_questions = self.extract_followup_questions(followup_content)
             yield {
                 "delta": {"role": "assistant"},
-                "context": {"context": extra_info, "followup_questions": followup_questions},
+                "context": {
+                    "context": extra_info,
+                    "followup_questions": followup_questions,
+                },
             }

     async def run(
@@ -138,7 +165,9 @@ async def run(
     ) -> dict[str, Any]:
         overrides = context.get("overrides", {})
         auth_claims = context.get("auth_claims", {})
-        return await self.run_without_streaming(messages, overrides, auth_claims, session_state)
+        return await self.run_without_streaming(
+            messages, overrides, auth_claims, session_state
+        )

     async def run_stream(
         self,
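
The streaming loop above treats "<<" as the start of suggested follow-up questions: text before the marker streams to the client, while the marker onward is buffered into followup_content. A standalone sketch of that truncation step (split_followup is an illustrative helper, not in the repo):

# Sketch of the "<<" truncation in run_with_streaming: text before the
# marker is streamed; the marker onward is buffered for later parsing.
def split_followup(content: str) -> tuple[str, str]:
    if "<<" in content:
        idx = content.index("<<")
        return content[:idx], content[idx:]
    return content, ""

earlier, followup = split_followup("Answer text. <<How do I request leave?>>")
print(earlier)   # Answer text.
print(followup)  # <<How do I request leave?>>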
