11import asyncio
22import json
3+ import uuid
34from concurrent .futures import ThreadPoolExecutor
45from functools import partial
56
89 BedrockAgent ,
910 BedrockCreateAgentRequest ,
1011 BedrockRun ,
11- BedrockRunAgentRequest ,
1212 BedrockTool ,
1313 BedrockToolCall ,
14- BedrockToolOutput ,
1514)
1615from llmstudio_core .agents .data_models import (
1716 Attachment ,
2019 Message ,
2120 RequiredAction ,
2221 ResultBase ,
22+ RunAgentRequest ,
2323 TextContent ,
2424 TextObject ,
25+ ToolCall ,
2526 ToolCallFunction ,
27+ ToolOutput ,
2628)
2729from llmstudio_core .agents .manager import AgentManager , agent_manager
2830from llmstudio_core .exceptions import AgentError
@@ -48,7 +50,7 @@ def _validate_create_request(self, request):
4850 return BedrockCreateAgentRequest (** request )
4951
5052 def _validate_run_request (self , request ):
51- return BedrockRunAgentRequest (** request )
53+ return RunAgentRequest (** request )
5254
5355 def _validate_result_request (self , request ):
5456 if isinstance (request , BedrockRun ):
@@ -206,6 +208,9 @@ def run_agent(self, params: dict = None) -> BedrockRun:
206208 except ValidationError as e :
207209 raise AgentError (str (e ))
208210
211+ if not run_request .thread_id :
212+ run_request .thread_id = str (uuid .uuid4 ())
213+
209214 sessionState = {"files" : [], "conversationHistory" : {"messages" : []}}
210215
211216 if isinstance (run_request .messages , Message ):
@@ -266,6 +271,7 @@ def run_agent(self, params: dict = None) -> BedrockRun:
266271 sessionId = run_request .thread_id ,
267272 inputText = input_text ,
268273 sessionState = sessionState ,
274+ enableTrace = True ,
269275 ),
270276 )
271277
@@ -302,8 +308,59 @@ async def aretrieve_result(self, run: BedrockRun) -> ResultBase:
302308
303309 content = []
304310 attachments = []
311+ messages = []
312+ usage = None
305313 event_stream = run .response .get ("completion" )
306314 for event in event_stream :
315+ if "trace" in event :
316+ trace = event ["trace" ]["trace" ]["orchestrationTrace" ]
317+
318+ if "modelInvocationInput" in trace :
319+ invocation_in = trace ["modelInvocationInput" ]
320+ text = json .loads (invocation_in ["text" ])
321+ new_messages = [
322+ Message (content = message ["content" ], role = message ["role" ])
323+ for message in text ["messages" ]
324+ ]
325+ messages += new_messages
326+
327+ if "modelInvocationOutput" in trace :
328+ invocation_out = trace ["modelInvocationOutput" ]["rawResponse" ][
329+ "content"
330+ ]
331+ invocation_out = json .loads (invocation_out )
332+ if "metadata" in invocation_out :
333+ usage = invocation_out ["metadata" ]["usage" ]
334+ elif "usage" in invocation_out :
335+ usage = invocation_out ["usage" ]
336+
337+ messages = invocation_out ["content" ]
338+ new_messages = []
339+ for message in messages :
340+ if message ["type" ] == "text" :
341+ new_messages .append (
342+ Message (content = message ["text" ], role = "assistant" )
343+ )
344+
345+ elif message ["type" ] == "tool_use" :
346+ tool_name = message ["name" ]
347+ tool_arguments = str (message ["input" ])
348+ tool_call_id = message ["id" ]
349+
350+ tool_call_func = ToolCallFunction (
351+ arguments = tool_arguments , name = tool_name
352+ )
353+ tool_call = ToolCall (
354+ id = tool_call_id ,
355+ function = tool_call_func ,
356+ type = message ["type" ],
357+ )
358+ required_action = RequiredAction (
359+ submit_tools_outputs = [tool_call ]
360+ )
361+ new_message = Message (required_action = required_action )
362+ new_messages .append (new_message )
363+
307364 if "chunk" in event :
308365 chunk = event ["chunk" ]
309366 if "bytes" in chunk :
@@ -363,6 +420,7 @@ async def aretrieve_result(self, run: BedrockRun) -> ResultBase:
363420 function = tool_call_function ,
364421 type = invocation_type ,
365422 action_group = action_group ,
423+ usage = usage ,
366424 )
367425
368426 required_action .submit_tools_outputs .append (tool_call )
@@ -375,9 +433,10 @@ async def aretrieve_result(self, run: BedrockRun) -> ResultBase:
375433 )
376434 ],
377435 thread_id = run .thread_id ,
436+ usage = usage ,
378437 )
379438
380- messages = [
439+ messages = new_messages + [
381440 Message (
382441 thread_id = run .thread_id ,
383442 role = "assistant" ,
@@ -386,7 +445,7 @@ async def aretrieve_result(self, run: BedrockRun) -> ResultBase:
386445 )
387446 ]
388447
389- return ResultBase (messages = messages , thread_id = run .thread_id )
448+ return ResultBase (messages = messages , thread_id = run .thread_id , usage = usage )
390449
391450 def submit_tool_outputs (self , params : dict = None ) -> ResultBase :
392451 try :
@@ -397,7 +456,7 @@ def submit_tool_outputs(self, params: dict = None) -> ResultBase:
397456 if not run_request .tool_outputs :
398457 raise AgentError ("No tool outputs found" )
399458
400- tool_outputs : list [BedrockToolOutput ] = run_request .tool_outputs
459+ tool_outputs : list [ToolOutput ] = run_request .tool_outputs
401460
402461 invocation_results = [
403462 {
@@ -424,6 +483,7 @@ def submit_tool_outputs(self, params: dict = None) -> ResultBase:
424483 agentAliasId = run_request .alias_id ,
425484 sessionId = run_request .thread_id ,
426485 sessionState = sessionState ,
486+ enableTrace = True ,
427487 ),
428488 )
429489
0 commit comments