Skip to content

Commit f1358f2

Browse files
authored
Run black tool on Agents (#42851)
1 parent 5e540a4 commit f1358f2

File tree

8 files changed

+353
-269
lines changed

8 files changed

+353
-269
lines changed

sdk/ai/azure-ai-agents/azure/ai/agents/telemetry/_ai_agents_instrumentor.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -458,19 +458,15 @@ def _process_tool_calls(self, step: RunStep) -> List[Dict[str, Any]]:
458458
t.type: t.bing_grounding,
459459
}
460460
elif isinstance(t, RunStepOpenAPIToolCall):
461-
tool_call = {
462-
"id": t.id,
463-
"type": t.type,
464-
'function': t.as_dict().get('function', {})
465-
}
461+
tool_call = {"id": t.id, "type": t.type, "function": t.as_dict().get("function", {})}
466462
elif isinstance(t, RunStepMcpToolCall):
467463
tool_call = {
468464
"id": t.id,
469465
"type": t.type,
470466
"arguments": t.arguments,
471467
"name": t.name,
472468
"output": t.output,
473-
"server_label": t.server_label or ""
469+
"server_label": t.server_label or "",
474470
}
475471
else:
476472
# Works for Deep research
@@ -589,7 +585,7 @@ def _add_message_event(
589585
) -> None:
590586
# TODO document new fields
591587

592-
event_body: dict[str, Any]= {}
588+
event_body: dict[str, Any] = {}
593589
if _trace_agents_content:
594590
if isinstance(content, List):
595591
for block in content:
@@ -2101,7 +2097,9 @@ def on_thread_message(self, message: "ThreadMessage") -> None: # type: ignore[f
21012097
# See work item 4636616 and 4636299 for details.
21022098
# When the work item is resolved, change this code back to:
21032099
# if message.status in {MessageStatus.COMPLETED, MessageStatus.INCOMPLETE}
2104-
if message.status in {MessageStatus.COMPLETED, MessageStatus.INCOMPLETE} or (message.status == MessageStatus.IN_PROGRESS and message.content):
2100+
if message.status in {MessageStatus.COMPLETED, MessageStatus.INCOMPLETE} or (
2101+
message.status == MessageStatus.IN_PROGRESS and message.content
2102+
):
21052103
self.last_message = message
21062104

21072105
return retval # type: ignore
@@ -2242,7 +2240,9 @@ async def on_thread_message(self, message: "ThreadMessage") -> None: # type: ig
22422240
# See work item 4636616 and 4636299 for details.
22432241
# When the work item is resolved, change this code back to:
22442242
# if message.status in {MessageStatus.COMPLETED, MessageStatus.INCOMPLETE}
2245-
if message.status in {MessageStatus.COMPLETED, MessageStatus.INCOMPLETE} or (message.status == MessageStatus.IN_PROGRESS and message.content):
2243+
if message.status in {MessageStatus.COMPLETED, MessageStatus.INCOMPLETE} or (
2244+
message.status == MessageStatus.IN_PROGRESS and message.content
2245+
):
22462246
self.last_message = message
22472247

22482248
return retval # type: ignore

sdk/ai/azure-ai-agents/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,4 @@
7777
"typing-extensions>=4.6.0",
7878
],
7979
python_requires=">=3.9",
80-
)
80+
)

sdk/ai/azure-ai-agents/tests/gen_ai_trace_verifier.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def check_span_attributes(self, span, attributes):
4646
# Check if the attribute value matches the provided value
4747
if attribute_value == "+":
4848
if not isinstance(span.attributes[attribute_name], numbers.Number):
49-
raise AssertionError("Attribute value " + str(span.attributes[attribute_name]) + " is not a number")
49+
raise AssertionError(
50+
"Attribute value " + str(span.attributes[attribute_name]) + " is not a number"
51+
)
5052
if span.attributes[attribute_name] < 0:
5153
raise AssertionError("Attribute value " + str(span.attributes[attribute_name]) + " is negative")
5254
elif attribute_value != "" and span.attributes[attribute_name] != attribute_value:
@@ -88,7 +90,9 @@ def check_decorator_span_attributes(self, span: Span, attributes: List[tuple]) -
8890
elif isinstance(attribute_value, dict):
8991
# Check if both are dictionaries and compare them
9092
if not isinstance(span_value, dict) or span_value != attribute_value:
91-
raise AssertionError("Attribute value dict " + str(span_value) + " does not match with " + str(attribute_value))
93+
raise AssertionError(
94+
"Attribute value dict " + str(span_value) + " does not match with " + str(attribute_value)
95+
)
9296
else:
9397
# Check if the attribute value matches the provided value
9498
if attribute_value == "+":
@@ -97,7 +101,9 @@ def check_decorator_span_attributes(self, span: Span, attributes: List[tuple]) -
97101
if span_value < 0:
98102
raise AssertionError("Attribute value " + str(span_value) + " is negative")
99103
elif attribute_value != "" and span_value != attribute_value:
100-
raise AssertionError("Attribute value " + str(span_value) + " does not match with " + str(attribute_value))
104+
raise AssertionError(
105+
"Attribute value " + str(span_value) + " does not match with " + str(attribute_value)
106+
)
101107
# Check if the attribute value in the span is not empty when the provided value is ""
102108
elif attribute_value == "" and not span_value:
103109
raise AssertionError("Expected non-empty attribute value")
@@ -127,7 +133,9 @@ def check_event_attributes(self, expected_dict, actual_dict):
127133
actual_val = json.dumps(actual_dict)
128134
else:
129135
actual_val = actual_dict
130-
raise AssertionError(f"check_event_attributes: keys do not match: {set(expected_dict.keys())} != {set(actual_dict.keys())}. The actual dictionaries: {expected_val} != {actual_val}")
136+
raise AssertionError(
137+
f"check_event_attributes: keys do not match: {set(expected_dict.keys())} != {set(actual_dict.keys())}. The actual dictionaries: {expected_val} != {actual_val}"
138+
)
131139
for key, expected_val in expected_dict.items():
132140
if key not in actual_dict:
133141
raise AssertionError(f"check_event_attributes: key {key} not found in actuals")
@@ -145,21 +153,25 @@ def check_event_attributes(self, expected_dict, actual_dict):
145153
if not isinstance(actual_val, list):
146154
raise AssertionError(f"check_event_attributes: actual_val for {key} is not list")
147155
if len(expected_val) != len(actual_val):
148-
raise AssertionError(f"check_event_attributes: list lengths do not match for key {key}: expected {len(expected_val)}, actual {len(actual_val)}")
156+
raise AssertionError(
157+
f"check_event_attributes: list lengths do not match for key {key}: expected {len(expected_val)}, actual {len(actual_val)}"
158+
)
149159
for expected_list, actual_list in zip(expected_val, actual_val):
150160
self.check_event_attributes(expected_list, actual_list)
151161
elif isinstance(expected_val, str) and expected_val == "*":
152162
if actual_val == "":
153163
raise AssertionError(f"check_event_attributes: actual_val for {key} is empty")
154164
elif isinstance(expected_val, str) and expected_val == "+":
155165
assert isinstance(actual_val, numbers.Number), f"The {key} is not a number."
156-
assert actual_val > 0, f"The {key} is <0 {actual_val}"
166+
assert actual_val > 0, f"The {key} is <0 {actual_val}"
157167
elif expected_val != actual_val:
158168
if isinstance(expected_val, dict):
159169
expected_val = json.dumps(expected_val)
160170
if isinstance(actual_val, dict):
161171
actual_val = json.dumps(actual_val)
162-
raise AssertionError(f"check_event_attributes: values do not match for key {key}: {expected_val} != {actual_val}")
172+
raise AssertionError(
173+
f"check_event_attributes: values do not match for key {key}: {expected_val} != {actual_val}"
174+
)
163175

164176
def check_span_events(self, span, expected_events):
165177
print("Checking span: " + span.name)

sdk/ai/azure-ai-agents/tests/test_agents_client.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3907,7 +3907,7 @@ def _get_mcp_tool(self):
39073907
server_url="https://gitmcp.io/Azure/azure-rest-api-specs",
39083908
allowed_tools=[], # Optional: specify allowed tools
39093909
)
3910-
3910+
39113911
@agentClientPreparer()
39123912
@recorded_by_proxy
39133913
def test_mcp_tool(self, **kwargs):
@@ -3928,16 +3928,20 @@ def test_mcp_tool(self, **kwargs):
39283928
content="Please summarize the Azure REST API specifications Readme",
39293929
)
39303930
mcp_tool.update_headers("SuperSecret", "123456")
3931-
run = agents_client.runs.create(thread_id=thread.id, agent_id=agent.id, tool_resources=mcp_tool.resources)
3931+
run = agents_client.runs.create(
3932+
thread_id=thread.id, agent_id=agent.id, tool_resources=mcp_tool.resources
3933+
)
39323934
was_approved = False
39333935
while run.status in [RunStatus.QUEUED, RunStatus.IN_PROGRESS, RunStatus.REQUIRES_ACTION]:
39343936
time.sleep(self._sleep_time())
39353937
run = agents_client.runs.get(thread_id=thread.id, run_id=run.id)
3936-
3937-
if run.status == RunStatus.REQUIRES_ACTION and isinstance(run.required_action, SubmitToolApprovalAction):
3938+
3939+
if run.status == RunStatus.REQUIRES_ACTION and isinstance(
3940+
run.required_action, SubmitToolApprovalAction
3941+
):
39383942
tool_calls = run.required_action.submit_tool_approval.tool_calls
39393943
assert tool_calls, "No tool calls to approve."
3940-
3944+
39413945
tool_approvals = []
39423946
for tool_call in tool_calls:
39433947
if isinstance(tool_call, RequiredMcpToolCall):
@@ -3948,15 +3952,15 @@ def test_mcp_tool(self, **kwargs):
39483952
headers=mcp_tool.headers,
39493953
)
39503954
)
3951-
3955+
39523956
if tool_approvals:
39533957
was_approved = True
39543958
agents_client.runs.submit_tool_outputs(
39553959
thread_id=thread.id, run_id=run.id, tool_approvals=tool_approvals
39563960
)
39573961
assert was_approved, "The run was never approved."
39583962
assert run.status != RunStatus.FAILED, run.last_error
3959-
3963+
39603964
is_activity_step_found = False
39613965
is_tool_call_step_found = False
39623966
for run_step in agents_client.run_steps.list(thread_id=thread.id, run_id=run.id):
@@ -3996,7 +4000,9 @@ def test_mcp_tool_streaming(self, **kwargs):
39964000
mcp_tool.update_headers("SuperSecret", "123456")
39974001

39984002
try:
3999-
with agents_client.runs.stream(thread_id=thread.id, agent_id=agent.id, tool_resources=mcp_tool.resources) as stream:
4003+
with agents_client.runs.stream(
4004+
thread_id=thread.id, agent_id=agent.id, tool_resources=mcp_tool.resources
4005+
) as stream:
40004006
is_started = False
40014007
received_message = False
40024008
got_expected_delta = False
@@ -4005,29 +4011,29 @@ def test_mcp_tool_streaming(self, **kwargs):
40054011
found_activity_details = False
40064012
found_tool_call_step = False
40074013
for event_type, event_data, _ in stream:
4008-
4014+
40094015
if isinstance(event_data, MessageDeltaChunk):
40104016
received_message = True
4011-
4017+
40124018
elif isinstance(event_data, RunStepDeltaChunk):
40134019
tool_calls_details = getattr(event_data.delta.step_details, "tool_calls")
40144020
if isinstance(tool_calls_details, list):
40154021
for tool_call in tool_calls_details:
40164022
if isinstance(tool_call, RunStepDeltaMcpToolCall):
40174023
got_expected_delta = True
4018-
4024+
40194025
elif isinstance(event_data, ThreadRun):
40204026
if event_type == AgentStreamEvent.THREAD_RUN_CREATED:
40214027
is_started = True
40224028
if event_data.status == RunStatus.FAILED:
40234029
raise AssertionError(event_data.last_error)
4024-
4030+
40254031
if event_data.status == RunStatus.REQUIRES_ACTION and isinstance(
40264032
event_data.required_action, SubmitToolApprovalAction
40274033
):
40284034
tool_calls = event_data.required_action.submit_tool_approval.tool_calls
40294035
assert tool_calls, "No tool calls to approve."
4030-
4036+
40314037
tool_approvals = []
40324038
for tool_call in tool_calls:
40334039
if isinstance(tool_call, RequiredMcpToolCall):
@@ -4038,7 +4044,7 @@ def test_mcp_tool_streaming(self, **kwargs):
40384044
headers=mcp_tool.headers,
40394045
)
40404046
)
4041-
4047+
40424048
if tool_approvals:
40434049
# Once we receive 'requires_action' status, the next event will be DONE.
40444050
# Here we associate our existing event handler to the next stream.
@@ -4048,7 +4054,7 @@ def test_mcp_tool_streaming(self, **kwargs):
40484054
tool_approvals=tool_approvals,
40494055
event_handler=stream,
40504056
)
4051-
4057+
40524058
elif isinstance(event_data, RunStep):
40534059
if event_type == AgentStreamEvent.THREAD_RUN_STEP_CREATED:
40544060
is_run_step_created = True
@@ -4059,11 +4065,10 @@ def test_mcp_tool_streaming(self, **kwargs):
40594065
for tool_call in step_details.tool_calls:
40604066
if isinstance(tool_call, RunStepMcpToolCall):
40614067
found_tool_call_step = True
4062-
4063-
4068+
40644069
elif event_type == AgentStreamEvent.ERROR:
40654070
raise AssertionError(event_data)
4066-
4071+
40674072
elif event_type == AgentStreamEvent.DONE:
40684073
is_completed = True
40694074

0 commit comments

Comments
 (0)