Skip to content

Commit 3a074b9

Browse files
authored
Log tool execution start/end time (#304)
1 parent 1fec003 commit 3a074b9

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

src/aviary/env.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
import logging
88
import os
99
import random
10+
import time
1011
from abc import ABC, abstractmethod
1112
from collections.abc import Awaitable, Iterator
1213
from copy import deepcopy
@@ -216,6 +217,7 @@ async def exec_tool_calls(
216217
"""
217218

218219
async def _exec_tool_call(tool_call: ToolCall) -> ToolResponseMessage:
220+
start = time.monotonic()
219221
try:
220222
tool = next(
221223
t for t in self.tools if t.info.name == tool_call.function.name
@@ -283,7 +285,14 @@ async def _exec_tool_call(tool_call: ToolCall) -> ToolResponseMessage:
283285
s_content = content.model_dump_json(exclude_none=True, by_alias=True)
284286
else: # Fallback when content is another type, or None
285287
s_content = json.dumps(content)
286-
return ToolResponseMessage.from_call(tool_call, content=s_content)
288+
return ToolResponseMessage.from_call(
289+
tool_call,
290+
content=s_content,
291+
info={
292+
"start_ts": start,
293+
"end_ts": time.monotonic(),
294+
},
295+
)
287296

288297
invalid_responses = []
289298
valid_action = message

src/aviary/tools/base.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -165,8 +165,10 @@ class ToolResponseMessage(Message):
165165
)
166166

167167
@classmethod
168-
def from_call(cls, call: ToolCall, content: str) -> Self:
169-
return cls(content=content, name=call.function.name, tool_call_id=call.id)
168+
def from_call(cls, call: ToolCall, content: str, **kwargs) -> Self:
169+
return cls(
170+
content=content, name=call.function.name, tool_call_id=call.id, **kwargs
171+
)
170172

171173
@classmethod
172174
def from_request(

tests/test_tools.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import json
23
import os
34
import pickle
@@ -689,6 +690,30 @@ def get_todo_list_no_args():
689690
new_messages = await dummy_env.exec_tool_calls(action)
690691
assert new_messages[0].content == "Go for a walk"
691692

693+
@pytest.mark.asyncio
694+
async def test_tool_timing(self) -> None:
695+
sleep_time = 0.1
696+
697+
async def sleep_tool_fn() -> None:
698+
"""Zzz."""
699+
await asyncio.sleep(sleep_time)
700+
701+
tool = Tool.from_function(sleep_tool_fn)
702+
tool_calls = [ToolCall.from_tool(tool) for _ in range(3)]
703+
704+
dummy_env = DummyEnv()
705+
dummy_env.tools = [tool]
706+
707+
responses = await dummy_env.exec_tool_calls(
708+
ToolRequestMessage(tool_calls=tool_calls)
709+
)
710+
711+
for resp in responses:
712+
assert resp.info, "Expected timing info to be present"
713+
assert resp.info["end_ts"] - resp.info["start_ts"] >= sleep_time, (
714+
"Expected non-trivial time elapsed."
715+
)
716+
692717

693718
def test_argref_by_name_basic_usage() -> None:
694719
class MyState:

0 commit comments

Comments
 (0)