Skip to content

Commit 58b9325

Browse files
authored (author name not captured)
Reduce coupling (#613)
In attachment retrieval and setup for summarizers
1 parent 8051502 commit 58b9325

File tree

11 files changed

+117
-73
lines changed

11 files changed

+117
-73
lines changed

assistants/codespace-assistant/assistant/response/request_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ async def build_request(
148148
model=request_config.model,
149149
)
150150

151-
logging.info(
151+
logger.info(
152152
"chat message params budgeted; message count: %d, total token count: %d",
153153
len(chat_message_params),
154154
total_token_count,

assistants/codespace-assistant/assistant/response/response.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22
from contextlib import AsyncExitStack
33
from typing import Any
44

5-
from assistant_extensions.chat_context_toolkit.archive import ArchiveTaskQueues
5+
from assistant_extensions.attachments import get_attachments
6+
from assistant_extensions.chat_context_toolkit.archive import (
7+
ArchiveTaskQueues,
8+
construct_archive_summarizer,
9+
)
10+
from assistant_extensions.chat_context_toolkit.message_history import (
11+
construct_attachment_summarizer,
12+
)
613
from assistant_extensions.mcp import (
714
MCPClientSettings,
815
MCPServerConnectionError,
@@ -166,8 +173,19 @@ async def message_handler(message) -> None:
166173
# enqueue an archive task for this conversation
167174
await archive_task_queues.enqueue_run(
168175
context=context,
169-
service_config=service_config,
170-
request_config=request_config,
176+
attachments=list(
177+
await get_attachments(
178+
context,
179+
summarizer=construct_attachment_summarizer(
180+
service_config=service_config,
181+
request_config=request_config,
182+
),
183+
)
184+
),
185+
archive_summarizer=construct_archive_summarizer(
186+
service_config=service_config,
187+
request_config=request_config,
188+
),
171189
archive_task_config=ArchiveTaskConfig(
172190
chunk_token_count_threshold=config.chat_context_config.archive_token_threshold
173191
),

assistants/codespace-assistant/assistant/response/step_handler.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
from typing import Any, List
55

66
import deepmerge
7-
from assistant_extensions.chat_context_toolkit.message_history import chat_context_toolkit_message_provider_for
7+
from assistant_extensions.attachments import get_attachments
8+
from assistant_extensions.chat_context_toolkit.message_history import (
9+
chat_context_toolkit_message_provider_for,
10+
construct_attachment_summarizer,
11+
)
812
from assistant_extensions.chat_context_toolkit.virtual_filesystem import (
913
archive_file_source_mount,
1014
attachments_file_source_mount,
@@ -100,8 +104,15 @@ async def handle_error(error_message: str, error_debug: dict[str, Any] | None =
100104
history_message_provider = chat_context_toolkit_message_provider_for(
101105
context=context,
102106
tool_abbreviations=abbreviations.tool_abbreviations,
103-
service_config=service_config,
104-
request_config=request_config,
107+
attachments=list(
108+
await get_attachments(
109+
context,
110+
summarizer=construct_attachment_summarizer(
111+
service_config=service_config,
112+
request_config=request_config,
113+
),
114+
)
115+
),
105116
)
106117

107118
build_request_result = await build_request(

assistants/document-assistant/assistant/response/responder.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99

1010
import deepmerge
1111
import pendulum
12-
from assistant_extensions.chat_context_toolkit.message_history import chat_context_toolkit_message_provider_for
12+
from assistant_extensions.attachments import get_attachments
13+
from assistant_extensions.chat_context_toolkit.message_history import (
14+
chat_context_toolkit_message_provider_for,
15+
construct_attachment_summarizer,
16+
)
1317
from assistant_extensions.chat_context_toolkit.virtual_filesystem import (
1418
archive_file_source_mount,
1519
)
@@ -400,9 +404,15 @@ async def _construct_prompt(self) -> tuple[list, list[ChatCompletionMessageParam
400404
message_provider = chat_context_toolkit_message_provider_for(
401405
context=self.context,
402406
tool_abbreviations=tool_abbreviations,
403-
# use the fast client config for the attachment summarization that the message provider does
404-
service_config=self.config.generative_ai_fast_client_config.service_config,
405-
request_config=self.config.generative_ai_fast_client_config.request_config,
407+
attachments=list(
408+
await get_attachments(
409+
self.context,
410+
summarizer=construct_attachment_summarizer(
411+
service_config=self.config.generative_ai_fast_client_config.service_config,
412+
request_config=self.config.generative_ai_fast_client_config.request_config,
413+
),
414+
)
415+
),
406416
)
407417
system_prompt_token_count = num_tokens_from_message(main_system_prompt, model="gpt-4o")
408418
tool_token_count = num_tokens_from_tools(tools, model="gpt-4o")
@@ -421,7 +431,7 @@ async def _construct_prompt(self) -> tuple[list, list[ChatCompletionMessageParam
421431
chat_history: list[ChatCompletionMessageParam] = list(budgeted_messages_result.messages)
422432
chat_history.insert(0, main_system_prompt)
423433

424-
logging.info("The system prompt has been constructed.")
434+
logger.info("The system prompt has been constructed.")
425435
# Update telemetry for inspector
426436
self.latest_telemetry.system_prompt_tokens = system_prompt_token_count
427437
self.latest_telemetry.tool_tokens = tool_token_count

libraries/python/assistant-extensions/assistant_extensions/attachments/_attachments.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ async def get_completion_messages_for_attachments(
118118
self,
119119
context: ConversationContext,
120120
config: AttachmentsConfigModel,
121-
include_filenames: list[str] | None = None,
121+
include_filenames: list[str] = [],
122122
exclude_filenames: list[str] = [],
123123
summarizer: Summarizer | None = None,
124124
) -> Sequence[CompletionMessage]:
@@ -143,6 +143,7 @@ async def get_completion_messages_for_attachments(
143143
error_handler=self._error_handler,
144144
include_filenames=include_filenames,
145145
exclude_filenames=exclude_filenames,
146+
summarizer=summarizer,
146147
)
147148

148149
if not attachments:
@@ -159,14 +160,14 @@ async def get_completion_messages_for_attachments(
159160
async def get_attachment_filenames(
160161
self,
161162
context: ConversationContext,
162-
include_filenames: list[str] | None = None,
163+
include_filenames: list[str] = [],
163164
exclude_filenames: list[str] = [],
164165
) -> list[str]:
165166
files_response = await context.list_files()
166167

167168
# for all files, get the attachment
168169
for file in files_response.files:
169-
if include_filenames is not None and file.filename not in include_filenames:
170+
if include_filenames and file.filename not in include_filenames:
170171
continue
171172
if file.filename in exclude_filenames:
172173
continue
@@ -226,33 +227,33 @@ async def default_error_handler(context: ConversationContext, filename: str, e:
226227

227228
async def get_attachments(
228229
context: ConversationContext,
229-
include_filenames: list[str] | None,
230-
exclude_filenames: list[str],
230+
exclude_filenames: list[str] = [],
231+
include_filenames: list[str] = [],
231232
error_handler: AttachmentProcessingErrorHandler = default_error_handler,
232233
summarizer: Summarizer | None = None,
233-
) -> Sequence[Attachment]:
234+
) -> list[Attachment]:
234235
"""
235236
Gets all attachments for the current state of the conversation, updating the cache as needed.
236237
"""
237238

238239
# get all files in the conversation
239240
files_response = await context.list_files()
240241

242+
# delete cached attachments that are no longer in the conversation
243+
filenames = {file.filename for file in files_response.files}
244+
asyncio.create_task(_delete_attachments_not_in(context, filenames))
245+
241246
attachments = []
242247
# for all files, get the attachment
243248
for file in files_response.files:
244-
if include_filenames is not None and file.filename not in include_filenames:
249+
if include_filenames and file.filename not in include_filenames:
245250
continue
246251
if file.filename in exclude_filenames:
247252
continue
248253

249254
attachment = await _get_attachment_for_file(context, file, {}, error_handler, summarizer=summarizer)
250255
attachments.append(attachment)
251256

252-
# delete cached attachments that are no longer in the conversation
253-
filenames = {file.filename for file in files_response.files}
254-
await _delete_attachments_not_in(context, filenames)
255-
256257
return attachments
257258

258259

libraries/python/assistant-extensions/assistant_extensions/chat_context_toolkit/archive/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
Provides the ArchiveTaskQueues class, for integrating with the chat context toolkit's archiving functionality.
33
"""
44

5-
from ._archive import ArchiveTaskQueues
5+
from ._archive import ArchiveTaskQueues, construct_archive_summarizer
66

77
__all__ = [
88
"ArchiveTaskQueues",
9+
"construct_archive_summarizer",
910
]

libraries/python/assistant-extensions/assistant_extensions/chat_context_toolkit/archive/_archive.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
from chat_context_toolkit.archive import ArchiveReader, ArchiveTaskConfig, ArchiveTaskQueue, StorageProvider
44
from chat_context_toolkit.archive import MessageProvider as ArchiveMessageProvider
55
from chat_context_toolkit.archive.summarization import LLMArchiveSummarizer, LLMArchiveSummarizerConfig
6-
from chat_context_toolkit.history.tool_abbreviations import ToolAbbreviations
76
from openai_client import OpenAIRequestConfig, ServiceConfig, create_client
87
from openai_client.tokens import num_tokens_from_messages
98
from semantic_workbench_assistant.assistant_app import ConversationContext, storage_directory_for_context
109

10+
from assistant_extensions.attachments._model import Attachment
11+
1112
from ..message_history import chat_context_toolkit_message_provider_for
1213

1314

@@ -46,21 +47,30 @@ async def list_files(self, relative_directory_path: PurePath) -> list[PurePath]:
4647

4748

4849
def archive_message_provider_for(
49-
context: ConversationContext, service_config: ServiceConfig, request_config: OpenAIRequestConfig
50+
context: ConversationContext,
51+
attachments: list[Attachment],
5052
) -> ArchiveMessageProvider:
5153
"""Create an archive message provider for the provided context."""
5254
return chat_context_toolkit_message_provider_for(
5355
context=context,
54-
tool_abbreviations=ToolAbbreviations(),
55-
service_config=service_config,
56-
request_config=request_config,
56+
attachments=attachments,
5757
)
5858

5959

60-
def _archive_task_queue_for(
61-
context: ConversationContext,
60+
def construct_archive_summarizer(
6261
service_config: ServiceConfig,
6362
request_config: OpenAIRequestConfig,
63+
) -> LLMArchiveSummarizer:
64+
return LLMArchiveSummarizer(
65+
client_factory=lambda: create_client(service_config),
66+
llm_config=LLMArchiveSummarizerConfig(model=request_config.model),
67+
)
68+
69+
70+
def _archive_task_queue_for(
71+
context: ConversationContext,
72+
attachments: list[Attachment],
73+
archive_summarizer: LLMArchiveSummarizer,
6474
archive_task_config: ArchiveTaskConfig = ArchiveTaskConfig(),
6575
token_counting_model: str = "gpt-4o",
6676
archive_storage_sub_directory: str = "archives",
@@ -71,13 +81,11 @@ def _archive_task_queue_for(
7181
return ArchiveTaskQueue(
7282
storage_provider=ArchiveStorageProvider(context=context, sub_directory=archive_storage_sub_directory),
7383
message_provider=archive_message_provider_for(
74-
context=context, service_config=service_config, request_config=request_config
84+
context=context,
85+
attachments=attachments,
7586
),
7687
token_counter=lambda messages: num_tokens_from_messages(messages=messages, model=token_counting_model),
77-
summarizer=LLMArchiveSummarizer(
78-
client_factory=lambda: create_client(service_config),
79-
llm_config=LLMArchiveSummarizerConfig(model=request_config.model),
80-
),
88+
summarizer=archive_summarizer,
8189
config=archive_task_config,
8290
)
8391

@@ -93,17 +101,17 @@ def __init__(self) -> None:
93101
async def enqueue_run(
94102
self,
95103
context: ConversationContext,
96-
service_config: ServiceConfig,
97-
request_config: OpenAIRequestConfig,
104+
attachments: list[Attachment],
105+
archive_summarizer: LLMArchiveSummarizer,
98106
archive_task_config: ArchiveTaskConfig = ArchiveTaskConfig(),
99107
) -> None:
100108
"""Get the archive task queue for the given context, creating it if it does not exist."""
101109
context_id = context.id
102110
if context_id not in self._queues:
103111
self._queues[context_id] = _archive_task_queue_for(
104112
context=context,
105-
service_config=service_config,
106-
request_config=request_config,
113+
attachments=attachments,
114+
archive_summarizer=archive_summarizer,
107115
archive_task_config=archive_task_config,
108116
)
109117
await self._queues[context_id].enqueue_run()

libraries/python/assistant-extensions/assistant_extensions/chat_context_toolkit/message_history/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
Provides a message history provider for the chat context toolkit's history management.
33
"""
44

5-
from ._history import chat_context_toolkit_message_provider_for
5+
from ._history import chat_context_toolkit_message_provider_for, construct_attachment_summarizer
66

77
__all__ = [
88
"chat_context_toolkit_message_provider_for",
9+
"construct_attachment_summarizer",
910
]

libraries/python/assistant-extensions/assistant_extensions/chat_context_toolkit/message_history/_history.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
)
2323
from semantic_workbench_assistant.assistant_app import ConversationContext
2424

25-
from assistant_extensions.attachments._attachments import get_attachments
25+
from assistant_extensions.attachments._model import Attachment
2626
from assistant_extensions.attachments._summarizer import LLMConfig, LLMFileSummarizer
2727

2828
from ._message import conversation_message_to_chat_message_param
@@ -132,11 +132,23 @@ class CompositeMessageProtocol(HistoryMessageProtocol, ArchiveMessageProtocol, P
132132
...
133133

134134

135-
def chat_context_toolkit_message_provider_for(
136-
context: ConversationContext,
137-
tool_abbreviations: ToolAbbreviations,
135+
def construct_attachment_summarizer(
138136
service_config: ServiceConfig,
139137
request_config: OpenAIRequestConfig,
138+
) -> LLMFileSummarizer:
139+
return LLMFileSummarizer(
140+
llm_config=LLMConfig(
141+
client_factory=lambda: create_client(service_config),
142+
model=request_config.model,
143+
max_response_tokens=request_config.response_tokens,
144+
)
145+
)
146+
147+
148+
def chat_context_toolkit_message_provider_for(
149+
context: ConversationContext,
150+
attachments: list[Attachment],
151+
tool_abbreviations: ToolAbbreviations = ToolAbbreviations(),
140152
) -> CompositeMessageProvider:
141153
"""
142154
Create a composite message provider for the given workbench conversation context.
@@ -146,9 +158,8 @@ async def provider(after_id: str | None = None) -> Sequence[CompositeMessageProt
146158
history = await _get_history_manager_messages(
147159
context,
148160
tool_abbreviations=tool_abbreviations,
149-
service_config=service_config,
150-
request_config=request_config,
151161
after_id=after_id,
162+
attachments=attachments,
152163
)
153164

154165
return history
@@ -159,8 +170,7 @@ async def provider(after_id: str | None = None) -> Sequence[CompositeMessageProt
159170
async def _get_history_manager_messages(
160171
context: ConversationContext,
161172
tool_abbreviations: ToolAbbreviations,
162-
service_config: ServiceConfig,
163-
request_config: OpenAIRequestConfig,
173+
attachments: list[Attachment],
164174
after_id: str | None = None,
165175
) -> list[HistoryMessageWithAbbreviation]:
166176
"""
@@ -175,21 +185,6 @@ async def _get_history_manager_messages(
175185
batch_size = 100
176186
before_message_id = None
177187

178-
attachments = list(
179-
await get_attachments(
180-
context=context,
181-
include_filenames=None,
182-
exclude_filenames=[],
183-
summarizer=LLMFileSummarizer(
184-
llm_config=LLMConfig(
185-
client_factory=lambda: create_client(service_config),
186-
model=request_config.model,
187-
max_response_tokens=request_config.response_tokens,
188-
)
189-
),
190-
)
191-
)
192-
193188
# each call to get_messages will return a maximum of `batch_size` messages
194189
# so we need to loop until all messages are retrieved
195190
while True:

0 commit comments

Comments (0)