@@ -35,7 +35,6 @@ def __init__(
35
35
model : str = None ,
36
36
):
37
37
self .document_service = document_service
38
- self .sources = {}
39
38
# Load settings
40
39
self .settings = get_settings ()
41
40
self .model = model or self .settings .AGENT_MODEL
@@ -100,20 +99,22 @@ def __init__(
100
99
when citing different sources. Use markdown formatting for text content to improve readability.
101
100
""" .strip ()
102
101
103
- async def _execute_tool (self , name : str , args : dict , auth : AuthContext ):
102
+ async def _execute_tool (self , name : str , args : dict , auth : AuthContext , source_map : dict ):
104
103
"""Dispatch tool calls, injecting document_service and auth."""
105
104
match name :
106
105
case "retrieve_chunks" :
107
- content , sources = await retrieve_chunks (document_service = self .document_service , auth = auth , ** args )
108
- self .sources .update (sources )
106
+ content , found_sources = await retrieve_chunks (
107
+ document_service = self .document_service , auth = auth , ** args
108
+ )
109
+ source_map .update (found_sources )
109
110
return content
110
111
case "retrieve_document" :
111
112
result = await retrieve_document (document_service = self .document_service , auth = auth , ** args )
112
113
# Add document as a source if it's a successful retrieval
113
114
if isinstance (result , str ) and not result .startswith ("Document" ) and not result .startswith ("Error" ):
114
115
doc_id = args .get ("document_id" , "unknown" )
115
116
source_id = f"doc{ doc_id } -full"
116
- self . sources [source_id ] = {
117
+ source_map [source_id ] = {
117
118
"document_id" : doc_id ,
118
119
"document_name" : f"Full Document { doc_id } " ,
119
120
"chunk_number" : "full" ,
@@ -126,7 +127,7 @@ async def _execute_tool(self, name: str, args: dict, auth: AuthContext):
126
127
doc_id = args .get ("document_id" )
127
128
analysis_type = args .get ("analysis_type" , "analysis" )
128
129
source_id = f"doc{ doc_id } -{ analysis_type } "
129
- self . sources [source_id ] = {
130
+ source_map [source_id ] = {
130
131
"document_id" : doc_id ,
131
132
"document_name" : f"Document { doc_id } ({ analysis_type } )" ,
132
133
"analysis_type" : analysis_type ,
@@ -148,6 +149,8 @@ async def _execute_tool(self, name: str, args: dict, auth: AuthContext):
148
149
149
150
async def run (self , query : str , auth : AuthContext ) -> str :
150
151
"""Synchronously run the agent and return the final answer."""
152
+ # Per-run state to avoid cross-request leakage
153
+ source_map : dict = {}
151
154
messages = [
152
155
{"role" : "system" , "content" : self .system_prompt },
153
156
{"role" : "user" , "content" : query },
@@ -212,8 +215,8 @@ async def run(self, query: str, auth: AuthContext) -> str:
212
215
}
213
216
if "caption" in item and item ["type" ] == "image" :
214
217
display_obj ["caption" ] = item ["caption" ]
215
- if item ["type" ] == "image" :
216
- display_obj ["content" ] = self . sources [item ["source" ]]["content" ]
218
+ if item ["type" ] == "image" and item . get ( "source" ) in source_map :
219
+ display_obj ["content" ] = source_map [item ["source" ]]["content" ]
217
220
display_objects .append (display_obj )
218
221
elif (
219
222
isinstance (parsed_content , dict )
@@ -228,8 +231,8 @@ async def run(self, query: str, auth: AuthContext) -> str:
228
231
}
229
232
if "caption" in parsed_content and parsed_content ["type" ] == "image" :
230
233
display_obj ["caption" ] = parsed_content ["caption" ]
231
- if item [ "type" ] == "image" :
232
- display_obj ["content" ] = self . sources [ item ["source" ]]["content" ]
234
+ if parsed_content . get ( "type" ) == "image" and parsed_content . get ( "source" ) in source_map :
235
+ display_obj ["content" ] = source_map [ parsed_content ["source" ]]["content" ]
233
236
display_objects .append (display_obj )
234
237
235
238
# If no display objects were created, treat the entire content as text
@@ -260,7 +263,7 @@ async def run(self, query: str, auth: AuthContext) -> str:
260
263
"sourceId" : source_id ,
261
264
"documentName" : f"Document { doc_id } " ,
262
265
"documentId" : doc_id ,
263
- "content" : self . sources . get (source_id , {"content" : "" })[ "content" ] ,
266
+ "content" : source_map . get (source_id , {"content" : "" }). get ( "content" , "" ) ,
264
267
}
265
268
)
266
269
else :
@@ -269,7 +272,7 @@ async def run(self, query: str, auth: AuthContext) -> str:
269
272
"sourceId" : source_id ,
270
273
"documentName" : "Referenced Source" ,
271
274
"documentId" : "unknown" ,
272
- "content" : self . sources . get (source_id , {"content" : "" })[ "content" ] ,
275
+ "content" : source_map . get (source_id , {"content" : "" }). get ( "content" , "" ) ,
273
276
}
274
277
)
275
278
@@ -285,7 +288,7 @@ async def run(self, query: str, auth: AuthContext) -> str:
285
288
)
286
289
287
290
# Add sources from document chunks used during the session
288
- for source_id , source_info in self . sources .items ():
291
+ for source_id , source_info in source_map .items ():
289
292
if source_id not in seen_source_ids :
290
293
sources .append (
291
294
{
@@ -314,7 +317,7 @@ async def run(self, query: str, auth: AuthContext) -> str:
314
317
# messages.append({'role': 'assistant', 'content': msg.content})
315
318
messages .append (msg .to_dict (exclude_none = True ))
316
319
logger .info (f"Executing tool: { name } " )
317
- result = await self ._execute_tool (name , args , auth )
320
+ result = await self ._execute_tool (name , args , auth , source_map )
318
321
logger .info (f"Tool execution result: { result } " )
319
322
320
323
# Add tool call and result to history
0 commit comments