Skip to content

Commit 350e569

Browse files
committed
Added usage tracking
1 parent 895f7e2 commit 350e569

File tree

4 files changed

+82
-31
lines changed

4 files changed

+82
-31
lines changed

examples/bedrock_agents.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from llmstudio_core.agents import AgentManagerCore
3-
from llmstudio_core.agents.bedrock.data_models import BedrockCreateAgentRequest, BedrockRunAgentRequest, BedrockToolOutput, BedrockToolCall
4-
from llmstudio_core.agents.data_models import ResultBase, ToolCall, ToolOutput
3+
from llmstudio_core.agents.bedrock.data_models import BedrockCreateAgentRequest, BedrockToolCall
4+
from llmstudio_core.agents.data_models import ResultBase, ToolCall, ToolOutput, RunAgentRequest
55
import boto3
66
import uuid
77

@@ -67,10 +67,9 @@
6767

6868
runs = []
6969

70-
for i in range(1,3):
71-
run_agent_request = BedrockRunAgentRequest(
70+
for i in range(1,2):
71+
run_agent_request = RunAgentRequest(
7272
agent_id = agent.agent_id,
73-
thread_id=f"111{i}",#remove this
7473
alias_id=agent.agent_alias_id,#make this optional
7574
messages=[
7675
{"role": "user", "content": "What is the weather like in Lisbon, PT?"},
@@ -87,15 +86,15 @@
8786
else:
8887
tool_calls : list[BedrockToolCall] = result.messages[-1].required_action.submit_tools_outputs
8988

90-
submit_outputs_request = BedrockRunAgentRequest(
89+
submit_outputs_request = RunAgentRequest(
9190
agent_id=agent.agent_id,
9291
thread_id=result.thread_id,
9392
alias_id=agent.agent_alias_id,
9493
tool_outputs=[]
9594
)
9695

9796
for tool_call in tool_calls:
98-
submit_outputs_request.tool_outputs.append(BedrockToolOutput(tool_call_id=tool_call.id, output="10", action_group=tool_call.action_group, function_name=tool_call.function.name))
97+
submit_outputs_request.tool_outputs.append(ToolOutput(tool_call_id=tool_call.id, output="10", action_group=tool_call.action_group, function_name=tool_call.function.name))
9998

10099
outputs_request = submit_outputs_request.model_dump()
101100
run = bedrock_agent_manager.submit_tool_outputs(submit_outputs_request.model_dump())

libs/core/llmstudio_core/agents/bedrock/data_models.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
from typing import Awaitable, List, Literal, Optional
1+
from typing import Awaitable, Literal
22

33
from llmstudio_core.agents.data_models import (
44
AgentBase,
55
CreateAgentRequest,
6-
RunAgentRequest,
76
RunBase,
87
Tool,
98
ToolCall,
10-
ToolOutput,
119
)
1210
from pydantic import BaseModel
1311

@@ -67,13 +65,3 @@ def from_tool(cls, tool: Tool) -> "BedrockTool":
6765

6866
class BedrockToolCall(ToolCall):
6967
action_group: str
70-
71-
72-
class BedrockToolOutput(ToolOutput, extra="allow"):
73-
action_group: str
74-
function_name: str
75-
76-
77-
class BedrockRunAgentRequest(RunAgentRequest, extra="allow"):
78-
alias_id: str
79-
tool_outputs: Optional[List[BedrockToolOutput]] = None

libs/core/llmstudio_core/agents/bedrock/manager.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import json
3+
import uuid
34
from concurrent.futures import ThreadPoolExecutor
45
from functools import partial
56

@@ -8,10 +9,8 @@
89
BedrockAgent,
910
BedrockCreateAgentRequest,
1011
BedrockRun,
11-
BedrockRunAgentRequest,
1212
BedrockTool,
1313
BedrockToolCall,
14-
BedrockToolOutput,
1514
)
1615
from llmstudio_core.agents.data_models import (
1716
Attachment,
@@ -20,9 +19,12 @@
2019
Message,
2120
RequiredAction,
2221
ResultBase,
22+
RunAgentRequest,
2323
TextContent,
2424
TextObject,
25+
ToolCall,
2526
ToolCallFunction,
27+
ToolOutput,
2628
)
2729
from llmstudio_core.agents.manager import AgentManager, agent_manager
2830
from llmstudio_core.exceptions import AgentError
@@ -48,7 +50,7 @@ def _validate_create_request(self, request):
4850
return BedrockCreateAgentRequest(**request)
4951

5052
def _validate_run_request(self, request):
51-
return BedrockRunAgentRequest(**request)
53+
return RunAgentRequest(**request)
5254

5355
def _validate_result_request(self, request):
5456
if isinstance(request, BedrockRun):
@@ -206,6 +208,9 @@ def run_agent(self, params: dict = None) -> BedrockRun:
206208
except ValidationError as e:
207209
raise AgentError(str(e))
208210

211+
if not run_request.thread_id:
212+
run_request.thread_id = str(uuid.uuid4())
213+
209214
sessionState = {"files": [], "conversationHistory": {"messages": []}}
210215

211216
if isinstance(run_request.messages, Message):
@@ -266,6 +271,7 @@ def run_agent(self, params: dict = None) -> BedrockRun:
266271
sessionId=run_request.thread_id,
267272
inputText=input_text,
268273
sessionState=sessionState,
274+
enableTrace=True,
269275
),
270276
)
271277

@@ -302,8 +308,59 @@ async def aretrieve_result(self, run: BedrockRun) -> ResultBase:
302308

303309
content = []
304310
attachments = []
311+
messages = []
312+
usage = None
305313
event_stream = run.response.get("completion")
306314
for event in event_stream:
315+
if "trace" in event:
316+
trace = event["trace"]["trace"]["orchestrationTrace"]
317+
318+
if "modelInvocationInput" in trace:
319+
invocation_in = trace["modelInvocationInput"]
320+
text = json.loads(invocation_in["text"])
321+
new_messages = [
322+
Message(content=message["content"], role=message["role"])
323+
for message in text["messages"]
324+
]
325+
messages += new_messages
326+
327+
if "modelInvocationOutput" in trace:
328+
invocation_out = trace["modelInvocationOutput"]["rawResponse"][
329+
"content"
330+
]
331+
invocation_out = json.loads(invocation_out)
332+
if "metadata" in invocation_out:
333+
usage = invocation_out["metadata"]["usage"]
334+
elif "usage" in invocation_out:
335+
usage = invocation_out["usage"]
336+
337+
messages = invocation_out["content"]
338+
new_messages = []
339+
for message in messages:
340+
if message["type"] == "text":
341+
new_messages.append(
342+
Message(content=message["text"], role="assistant")
343+
)
344+
345+
elif message["type"] == "tool_use":
346+
tool_name = message["name"]
347+
tool_arguments = str(message["input"])
348+
tool_call_id = message["id"]
349+
350+
tool_call_func = ToolCallFunction(
351+
arguments=tool_arguments, name=tool_name
352+
)
353+
tool_call = ToolCall(
354+
id=tool_call_id,
355+
function=tool_call_func,
356+
type=message["type"],
357+
)
358+
required_action = RequiredAction(
359+
submit_tools_outputs=[tool_call]
360+
)
361+
new_message = Message(required_action=required_action)
362+
new_messages.append(new_message)
363+
307364
if "chunk" in event:
308365
chunk = event["chunk"]
309366
if "bytes" in chunk:
@@ -363,6 +420,7 @@ async def aretrieve_result(self, run: BedrockRun) -> ResultBase:
363420
function=tool_call_function,
364421
type=invocation_type,
365422
action_group=action_group,
423+
usage=usage,
366424
)
367425

368426
required_action.submit_tools_outputs.append(tool_call)
@@ -375,9 +433,10 @@ async def aretrieve_result(self, run: BedrockRun) -> ResultBase:
375433
)
376434
],
377435
thread_id=run.thread_id,
436+
usage=usage,
378437
)
379438

380-
messages = [
439+
messages = new_messages + [
381440
Message(
382441
thread_id=run.thread_id,
383442
role="assistant",
@@ -386,7 +445,7 @@ async def aretrieve_result(self, run: BedrockRun) -> ResultBase:
386445
)
387446
]
388447

389-
return ResultBase(messages=messages, thread_id=run.thread_id)
448+
return ResultBase(messages=messages, thread_id=run.thread_id, usage=usage)
390449

391450
def submit_tool_outputs(self, params: dict = None) -> ResultBase:
392451
try:
@@ -397,7 +456,7 @@ def submit_tool_outputs(self, params: dict = None) -> ResultBase:
397456
if not run_request.tool_outputs:
398457
raise AgentError("No tool outputs found")
399458

400-
tool_outputs: list[BedrockToolOutput] = run_request.tool_outputs
459+
tool_outputs: list[ToolOutput] = run_request.tool_outputs
401460

402461
invocation_results = [
403462
{
@@ -424,6 +483,7 @@ def submit_tool_outputs(self, params: dict = None) -> ResultBase:
424483
agentAliasId=run_request.alias_id,
425484
sessionId=run_request.thread_id,
426485
sessionState=sessionState,
486+
enableTrace=True,
427487
),
428488
)
429489

libs/core/llmstudio_core/agents/data_models.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,11 @@ class RequiredAction(BaseModel):
105105
type: Literal["submit_tool_outputs"] = "submit_tool_outputs"
106106

107107

108-
class ToolOutput(BaseModel, extra="allow"):
109-
tool_call_id: str
110-
output: str
108+
class ToolOutput(BaseModel):
109+
tool_call_id: Optional[str]
110+
output: Optional[str]
111+
action_group: Optional[str]
112+
function_name: Optional[str]
111113

112114

113115
class Message(BaseModel):
@@ -147,6 +149,7 @@ class RunBase(BaseModel):
147149
class ResultBase(BaseModel):
148150
thread_id: str
149151
messages: List[Message]
152+
usage: Optional[dict] = None
150153

151154

152155
class CreateAgentRequest(BaseModel):
@@ -157,8 +160,9 @@ class CreateAgentRequest(BaseModel):
157160
name: Optional[str] = None
158161

159162

160-
class RunAgentRequest(BaseModel, extra="allow"):
161-
thread_id: Optional[str] = None
163+
class RunAgentRequest(BaseModel):
162164
agent_id: str
165+
alias_id: Optional[str]
166+
thread_id: Optional[str] = None
163167
messages: Optional[List[Message]] = None
164168
tool_outputs: Optional[List[ToolOutput]] = None

0 commit comments

Comments
 (0)