Skip to content

Commit 8a0135b

Browse files
first pass at separate converters for legacy vs fdp (Azure#41017)
* first pass at separate converters for legacy vs fdp * clean up imports * more cleanup * combine into one class * update for change in function tool call details object * update test notebook * don't import the whole module * remove factory, keep public objects the same * update notebook * address PR comments --------- Co-authored-by: spon <[email protected]>
1 parent f32244d commit 8a0135b

File tree

5 files changed

+378
-126
lines changed

5 files changed

+378
-126
lines changed

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_converters/_ai_services.py

Lines changed: 159 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
11
import json
2+
from abc import abstractmethod
23
from concurrent.futures import ThreadPoolExecutor, as_completed
34

5+
from azure.ai.projects import __version__ as projects_version
46
from azure.ai.projects import AIProjectClient
5-
from azure.ai.projects.models import (
6-
ThreadRun,
7-
RunStep,
8-
RunStepToolCallDetails,
9-
FunctionDefinition,
10-
ListSortOrder,
11-
)
127

138
from typing import List, Union
149

1510
from azure.ai.evaluation._common._experimental import experimental
11+
from packaging.version import Version
1612

1713
# Constants.
1814
from ._models import _USER, _AGENT, _TOOL, _TOOL_CALL, _TOOL_CALLS, _FUNCTION, _BUILT_IN_DESCRIPTIONS, _BUILT_IN_PARAMS
@@ -26,21 +22,20 @@
2622
# Utilities.
2723
from ._models import break_tool_call_into_messages, convert_message
2824

29-
# Maximum items to fetch in a single AI Services API call (imposed by the service).
30-
_AI_SERVICES_API_MAX_LIMIT = 100
31-
32-
# Maximum number of workers allowed to make API calls at the same time.
33-
_MAX_WORKERS = 10
3425

3526
@experimental
3627
class AIAgentConverter:
3728
"""
38-
A converter for AI agent data.
29+
A converter for AI agent data. Data retrieval classes handle getting agent data depending on
30+
agent version.
3931
4032
:param project_client: The AI project client used for API interactions.
4133
:type project_client: AIProjectClient
4234
"""
4335

36+
# Maximum number of workers allowed to make API calls at the same time.
37+
_MAX_WORKERS = 10
38+
4439
def __init__(self, project_client: AIProjectClient):
4540
"""
4641
Initializes the AIAgentConverter with the given AI project client.
@@ -49,30 +44,16 @@ def __init__(self, project_client: AIProjectClient):
4944
:type project_client: AIProjectClient
5045
"""
5146
self.project_client = project_client
47+
self._data_retriever = AIAgentConverter._get_data_retriever(project_client=project_client)
5248

53-
def _list_messages_chronological(self, thread_id: str):
54-
"""
55-
Lists messages in chronological order for a given thread.
56-
57-
:param thread_id: The ID of the thread.
58-
:type thread_id: str
59-
:return: A list of messages in chronological order.
60-
"""
61-
to_return = []
62-
63-
has_more = True
64-
after = None
65-
while has_more:
66-
messages = self.project_client.agents.list_messages(
67-
thread_id=thread_id, limit=_AI_SERVICES_API_MAX_LIMIT, order=ListSortOrder.ASCENDING, after=after
68-
)
69-
has_more = messages.has_more
70-
after = messages.last_id
71-
if messages.data:
72-
# We need to add the messages to the accumulator.
73-
to_return.extend(messages.data)
74-
75-
return to_return
49+
@staticmethod
50+
def _get_data_retriever(project_client: AIProjectClient):
51+
if project_client is None:
52+
return None
53+
if Version(projects_version) > Version("1.0.0b10"):
54+
return FDPAgentDataRetriever(project_client=project_client)
55+
else:
56+
return LegacyAgentDataRetriever(project_client=project_client)
7657

7758
def _list_tool_calls_chronological(self, thread_id: str, run_id: str) -> List[ToolCall]:
7859
"""
@@ -87,29 +68,14 @@ def _list_tool_calls_chronological(self, thread_id: str, run_id: str) -> List[To
8768
"""
8869
# This is the other API request that we need to make to AI service, such that we can get the details about
8970
# the tool calls and results. Since the list is given in reverse chronological order, we need to reverse it.
90-
run_steps_chronological: List[RunStep] = []
91-
has_more = True
92-
after = None
93-
while has_more:
94-
run_steps = self.project_client.agents.list_run_steps(
95-
thread_id=thread_id,
96-
run_id=run_id,
97-
limit=_AI_SERVICES_API_MAX_LIMIT,
98-
order=ListSortOrder.ASCENDING,
99-
after=after,
100-
)
101-
has_more = run_steps.has_more
102-
after = run_steps.last_id
103-
if run_steps.data:
104-
# We need to add the run steps to the accumulator.
105-
run_steps_chronological.extend(run_steps.data)
71+
run_steps_chronological = self._data_retriever._list_run_steps_chronological(thread_id=thread_id, run_id=run_id)
10672

10773
# Let's accumulate the function calls in chronological order. Function calls
10874
tool_calls_chronological: List[ToolCall] = []
10975
for run_step_chronological in run_steps_chronological:
11076
if run_step_chronological.type != _TOOL_CALLS:
11177
continue
112-
step_details: RunStepToolCallDetails = run_step_chronological.step_details
78+
step_details: object = run_step_chronological.step_details
11379
if step_details.type != _TOOL_CALLS:
11480
continue
11581
if len(step_details.tool_calls) < 1:
@@ -126,26 +92,13 @@ def _list_tool_calls_chronological(self, thread_id: str, run_id: str) -> List[To
12692

12793
return tool_calls_chronological
12894

129-
def _list_run_ids_chronological(self, thread_id: str) -> List[str]:
130-
"""
131-
Lists run IDs in chronological order for a given thread.
132-
133-
:param thread_id: The ID of the thread.
134-
:type thread_id: str
135-
:return: A list of run IDs in chronological order.
136-
:rtype: List[str]
137-
"""
138-
runs = self.project_client.agents.list_runs(thread_id=thread_id, order=ListSortOrder.ASCENDING)
139-
run_ids = [run["id"] for run in runs["data"]]
140-
return run_ids
141-
14295
@staticmethod
143-
def _extract_function_tool_definitions(thread_run: ThreadRun) -> List[ToolDefinition]:
96+
def _extract_function_tool_definitions(thread_run: object) -> List[ToolDefinition]:
14497
"""
14598
Extracts tool definitions from a thread run.
14699
147100
:param thread_run: The thread run containing tool definitions.
148-
:type thread_run: ThreadRun
101+
:type thread_run: object
149102
:return: A list of tool definitions extracted from the thread run.
150103
:rtype: List[ToolDefinition]
151104
"""
@@ -368,12 +321,12 @@ def _retrieve_tool_calls_up_to_including_run_id(
368321
# We set the include_run_id to False, since we don't want to include the current run's tool calls, which
369322
# are already included in the previous step.
370323
run_ids_up_to_run_id = AIAgentConverter._filter_run_ids_up_to_run_id(
371-
self._list_run_ids_chronological(thread_id), run_id, include_run_id=False
324+
self._data_retriever._list_run_ids_chronological(thread_id), run_id, include_run_id=False
372325
)
373326

374327
# Since each _list_tool_calls_chronological call is expensive, we can use a thread pool to speed
375328
# up the process by parallelizing the AI Services API requests.
376-
with ThreadPoolExecutor(max_workers=_MAX_WORKERS) as executor:
329+
with ThreadPoolExecutor(max_workers=self._MAX_WORKERS) as executor:
377330
futures = {
378331
executor.submit(self._fetch_tool_calls, thread_id, run_id): run_id
379332
for run_id in run_ids_up_to_run_id
@@ -399,7 +352,7 @@ def _retrieve_all_tool_calls(self, thread_id: str, run_ids: List[str]) -> List[M
399352
"""
400353
to_return: List[Message] = []
401354

402-
with ThreadPoolExecutor(max_workers=_MAX_WORKERS) as executor:
355+
with ThreadPoolExecutor(max_workers=self._MAX_WORKERS) as executor:
403356
futures = {executor.submit(self._fetch_tool_calls, thread_id, run_id): run_id for run_id in run_ids}
404357
for future in as_completed(futures):
405358
to_return.extend(future.result())
@@ -460,10 +413,10 @@ def convert(self, thread_id: str, run_id: str, exclude_tool_calls_previous_runs:
460413
:rtype: dict
461414
"""
462415
# Make the API call once and reuse the result.
463-
thread_run: ThreadRun = self.project_client.agents.get_run(thread_id=thread_id, run_id=run_id)
416+
thread_run: object = self._data_retriever._get_run(thread_id=thread_id, run_id=run_id)
464417

465418
# Walk through the "user-facing" conversation history and start adding messages.
466-
chronological_conversation = self._list_messages_chronological(thread_id)
419+
chronological_conversation = self._data_retriever._list_messages_chronological(thread_id)
467420

468421
# Since this is Xth run of out possibly N runs, we are only interested is messages that are before the run X.
469422
chrono_until_run_id = AIAgentConverter._filter_messages_up_to_run_id(chronological_conversation, run_id)
@@ -519,14 +472,14 @@ def _prepare_single_thread_evaluation_data(self, thread_id: str, filename: str =
519472
list_of_run_evaluations: List[dict] = []
520473

521474
# These are all the run IDs.
522-
run_ids = self._list_run_ids_chronological(thread_id)
475+
run_ids = self._data_retriever._list_run_ids_chronological(thread_id)
523476

524477
# If there were no messages in the thread, we can return an empty list.
525478
if len(run_ids) < 1:
526479
return list_of_run_evaluations
527480

528481
# These are all the messages.
529-
chronological_conversation = self._list_messages_chronological(thread_id)
482+
chronological_conversation = self._data_retriever._list_messages_chronological(thread_id)
530483

531484
# If there are no messages in the thread, we can return an empty list.
532485
if len(chronological_conversation) < 1:
@@ -536,7 +489,7 @@ def _prepare_single_thread_evaluation_data(self, thread_id: str, filename: str =
536489
all_sorted_tool_calls = AIAgentConverter._sort_messages(self._retrieve_all_tool_calls(thread_id, run_ids))
537490

538491
# The last run should have all the tool definitions.
539-
thread_run = self.project_client.agents.get_run(thread_id=thread_id, run_id=run_ids[-1])
492+
thread_run = self._data_retriever._get_run(thread_id=thread_id, run_id=run_ids[-1])
540493
instructions = thread_run.instructions
541494

542495
# So then we can get the tool definitions.
@@ -609,7 +562,7 @@ def prepare_evaluation_data(self, thread_ids=Union[str, List[str]], filename: st
609562
return self._prepare_single_thread_evaluation_data(thread_id=thread_ids, filename=filename)
610563

611564
evaluations = []
612-
with ThreadPoolExecutor(max_workers=_MAX_WORKERS) as executor:
565+
with ThreadPoolExecutor(max_workers=self._MAX_WORKERS) as executor:
613566
# We override the filename, because we don't want to write the file for each thread, having to handle
614567
# threading issues and file being opened from multiple threads, instead, we just want to write it once
615568
# at the end.
@@ -764,3 +717,132 @@ def _convert_from_file(filename: str, run_id: str) -> dict:
764717
data = json.load(file)
765718

766719
return AIAgentConverter._convert_from_conversation(data, run_id)
720+
721+
@experimental
722+
class AIAgentDataRetriever:
723+
# Maximum items to fetch in a single AI Services API call (imposed by the service).
724+
_AI_SERVICES_API_MAX_LIMIT = 100
725+
726+
def __init__(self, project_client: AIProjectClient):
727+
"""
728+
Initializes the AIAgentDataRetriever with the given AI project client.
729+
730+
:param project_client: The AI project client used for API interactions.
731+
:type project_client: AIProjectClient
732+
"""
733+
self.project_client = project_client
734+
735+
@abstractmethod
736+
def _get_run(self, thread_id: str, run_id: str):
737+
pass
738+
739+
@abstractmethod
740+
def _list_messages_chronological(self, thread_id: str):
741+
pass
742+
743+
@abstractmethod
744+
def _list_run_steps_chronological(self, thread_id: str, run_id: str):
745+
pass
746+
747+
@abstractmethod
748+
def _list_run_ids_chronological(self, thread_id: str) -> List[str]:
749+
pass
750+
751+
@experimental
752+
class LegacyAgentDataRetriever(AIAgentDataRetriever):
753+
754+
def __init__(self, **kwargs):
755+
super(LegacyAgentDataRetriever, self).__init__(**kwargs)
756+
757+
def _list_messages_chronological(self, thread_id: str):
758+
"""
759+
Lists messages in chronological order for a given thread.
760+
761+
:param thread_id: The ID of the thread.
762+
:type thread_id: str
763+
:return: A list of messages in chronological order.
764+
"""
765+
to_return = []
766+
767+
has_more = True
768+
after = None
769+
while has_more:
770+
messages = self.project_client.agents.list_messages(
771+
thread_id=thread_id, limit=self._AI_SERVICES_API_MAX_LIMIT, order="asc", after=after)
772+
has_more = messages.has_more
773+
after = messages.last_id
774+
if messages.data:
775+
# We need to add the messages to the accumulator.
776+
to_return.extend(messages.data)
777+
778+
return to_return
779+
780+
def _list_run_steps_chronological(self, thread_id: str, run_id: str):
781+
run_steps_chronological: List[object] = []
782+
has_more = True
783+
after = None
784+
while has_more:
785+
run_steps = self.project_client.agents.list_run_steps(
786+
thread_id=thread_id,
787+
run_id=run_id,
788+
limit=self._AI_SERVICES_API_MAX_LIMIT,
789+
order="asc",
790+
after=after,
791+
)
792+
has_more = run_steps.has_more
793+
after = run_steps.last_id
794+
if run_steps.data:
795+
# We need to add the run steps to the accumulator.
796+
run_steps_chronological.extend(run_steps.data)
797+
return run_steps_chronological
798+
799+
def _list_run_ids_chronological(self, thread_id: str) -> List[str]:
800+
"""
801+
Lists run IDs in chronological order for a given thread.
802+
803+
:param thread_id: The ID of the thread.
804+
:type thread_id: str
805+
:return: A list of run IDs in chronological order.
806+
:rtype: List[str]
807+
"""
808+
runs = self.project_client.agents.list_runs(thread_id=thread_id, order="asc")
809+
run_ids = [run["id"] for run in runs["data"]]
810+
return run_ids
811+
812+
def _get_run(self, thread_id: str, run_id: str):
813+
return self.project_client.agents.get_run(thread_id=thread_id, run_id=run_id)
814+
815+
@experimental
816+
class FDPAgentDataRetriever(AIAgentDataRetriever):
817+
818+
def __init__(self, **kwargs):
819+
super(FDPAgentDataRetriever, self).__init__(**kwargs)
820+
821+
def _list_messages_chronological(self, thread_id: str):
822+
"""
823+
Lists messages in chronological order for a given thread.
824+
825+
:param thread_id: The ID of the thread.
826+
:type thread_id: str
827+
:return: A list of messages in chronological order.
828+
"""
829+
message_iter = self.project_client.agents.messages.list(
830+
thread_id=thread_id, limit=self._AI_SERVICES_API_MAX_LIMIT, order="asc"
831+
)
832+
return [message for message in message_iter]
833+
834+
def _list_run_steps_chronological(self, thread_id: str, run_id: str):
835+
836+
return self.project_client.agents.run_steps.list(
837+
thread_id=thread_id,
838+
run_id=run_id,
839+
limit=self._AI_SERVICES_API_MAX_LIMIT,
840+
order="asc"
841+
)
842+
843+
def _list_run_ids_chronological(self, thread_id: str) -> List[str]:
844+
runs = self.project_client.agents.runs.list(thread_id=thread_id, order="asc")
845+
return [run.id for run in runs]
846+
847+
def _get_run(self, thread_id: str, run_id: str):
848+
return self.project_client.agents.runs.get(thread_id=thread_id, run_id=run_id)

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_converters/_models.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,18 @@
33

44
from pydantic import BaseModel
55

6-
from azure.ai.projects.models import RunStepFunctionToolCall
7-
86
from typing import List, Optional, Union
97

8+
# Models moved in a later version of agents SDK, so try a few different locations
9+
try:
10+
from azure.ai.projects.models import RunStepFunctionToolCall
11+
except ImportError:
12+
pass
13+
try:
14+
from azure.ai.agents.models import RunStepFunctionToolCall
15+
except ImportError:
16+
pass
17+
1018
# Message roles constants.
1119
_SYSTEM = "system"
1220
_USER = "user"
@@ -269,7 +277,7 @@ def break_tool_call_into_messages(tool_call: ToolCall, run_id: str) -> List[Mess
269277
messages.append(AssistantMessage(run_id=run_id, content=[to_dict(content_tool_call)], createdAt=tool_call.created))
270278

271279
if hasattr(tool_call.details, _FUNCTION):
272-
output = safe_loads(tool_call.details.function.output)
280+
output = safe_loads(tool_call.details.function["output"])
273281
else:
274282
try:
275283
# Some built-ins may have output, others may not

0 commit comments

Comments
 (0)