Skip to content

Commit 7b2e6de

Browse files
authored
[Agents] Fix ToolOutput regression, add streaming ComputerTool classes (#43119)
* [Agents] Fix ToolOutput regression, add streaming ComputerTool classes * correct patch method signature * update Changelog
1 parent b37a488 commit 7b2e6de

18 files changed

+655
-472
lines changed

sdk/ai/azure-ai-agents/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
### Bugs Fixed
1010
- Added `RunStepDeltaChunk` to `StreamEventData` model (GitHub issues [43022](https://github.com/Azure/azure-sdk-for-python/issues/43022))
11-
11+
- Fixed regression, reverted ToolOutput type signature and usage in tool_output submission.
12+
- Added `RunStepDeltaComputerUseDetails` and `RunStepDeltaComputerUseToolCall` classes for streaming computer use scenarios.
1213

1314
## 1.2.0b4 (2025-09-12)
1415

sdk/ai/azure-ai-agents/apiview-properties.json

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
"azure.ai.agents.models.CodeInterpreterToolDefinition": "Azure.AI.Agents.CodeInterpreterToolDefinition",
3636
"azure.ai.agents.models.CodeInterpreterToolResource": "Azure.AI.Agents.CodeInterpreterToolResource",
3737
"azure.ai.agents.models.ComputerScreenshot": "Azure.AI.Agents.ComputerScreenshot",
38-
"azure.ai.agents.models.ToolOutput": "Azure.AI.Agents.ToolOutput",
38+
"azure.ai.agents.models.StructuredToolOutput": "Azure.AI.Agents.StructuredToolOutput",
3939
"azure.ai.agents.models.ComputerToolOutput": "Azure.AI.Agents.ComputerToolOutput",
4040
"azure.ai.agents.models.ComputerUseToolDefinition": "Azure.AI.Agents.ComputerUseToolDefinition",
4141
"azure.ai.agents.models.ComputerUseToolParameters": "Azure.AI.Agents.ComputerUseToolParameters",
@@ -59,7 +59,6 @@
5959
"azure.ai.agents.models.FunctionDefinition": "Azure.AI.Agents.FunctionDefinition",
6060
"azure.ai.agents.models.FunctionName": "Azure.AI.Agents.FunctionName",
6161
"azure.ai.agents.models.FunctionToolDefinition": "Azure.AI.Agents.FunctionToolDefinition",
62-
"azure.ai.agents.models.FunctionToolOutput": "Azure.AI.Agents.FunctionToolOutput",
6362
"azure.ai.agents.models.IncompleteRunDetails": "Azure.AI.Agents.IncompleteRunDetails",
6463
"azure.ai.agents.models.KeyPressAction": "Azure.AI.Agents.KeyPressAction",
6564
"azure.ai.agents.models.MCPApprovalPerTool": "Azure.AI.Agents.MCPApprovalPerTool",
@@ -155,6 +154,8 @@
155154
"azure.ai.agents.models.RunStepDeltaCodeInterpreterImageOutputObject": "Azure.AI.Agents.RunStepDeltaCodeInterpreterImageOutputObject",
156155
"azure.ai.agents.models.RunStepDeltaCodeInterpreterLogOutput": "Azure.AI.Agents.RunStepDeltaCodeInterpreterLogOutput",
157156
"azure.ai.agents.models.RunStepDeltaCodeInterpreterToolCall": "Azure.AI.Agents.RunStepDeltaCodeInterpreterToolCall",
157+
"azure.ai.agents.models.RunStepDeltaComputerUseDetails": "Azure.AI.Agents.RunStepDeltaComputerUseDetails",
158+
"azure.ai.agents.models.RunStepDeltaComputerUseToolCall": "Azure.AI.Agents.RunStepDeltaComputerUseToolCall",
158159
"azure.ai.agents.models.RunStepDeltaConnectedAgentToolCall": "Azure.AI.Agents.RunStepDeltaConnectedAgentToolCall",
159160
"azure.ai.agents.models.RunStepDeltaCustomBingGroundingToolCall": "Azure.AI.Agents.RunStepDeltaCustomBingGroundingToolCall",
160161
"azure.ai.agents.models.RunStepDeltaDeepResearchToolCall": "Azure.AI.Agents.RunStepDeltaDeepResearchToolCall",
@@ -199,6 +200,7 @@
199200
"azure.ai.agents.models.ThreadRun": "Azure.AI.Agents.ThreadRun",
200201
"azure.ai.agents.models.ToolApproval": "Azure.AI.Agents.ToolApproval",
201202
"azure.ai.agents.models.ToolConnection": "Azure.AI.Agents.ToolConnection",
203+
"azure.ai.agents.models.ToolOutput": "Azure.AI.Agents.ToolOutput",
202204
"azure.ai.agents.models.ToolResources": "Azure.AI.Agents.ToolResources",
203205
"azure.ai.agents.models.TruncationObject": "Azure.AI.Agents.TruncationObject",
204206
"azure.ai.agents.models.TypeAction": "Azure.AI.Agents.TypeAction",

sdk/ai/azure-ai-agents/azure/ai/agents/_types.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# pylint: disable=line-too-long,useless-suppression
21
# coding=utf-8
32
# --------------------------------------------------------------------------
43
# Copyright (c) Microsoft Corporation. All rights reserved.
@@ -7,12 +6,12 @@
76
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
87
# --------------------------------------------------------------------------
98

10-
from typing import List, Literal, TYPE_CHECKING, Union
9+
from typing import Literal, TYPE_CHECKING, Union
1110

1211
if TYPE_CHECKING:
1312
from . import models as _models
1413
MCPRequiredApproval = Union[str, Literal["never"], Literal["always"], "_models.MCPApprovalPerTool"]
15-
MessageInputContent = Union[str, List["_models.MessageInputContentBlock"]]
14+
MessageInputContent = Union[str, list["_models.MessageInputContentBlock"]]
1615
MessageAttachmentToolDefinition = Union["_models.CodeInterpreterToolDefinition", "_models.FileSearchToolDefinition"]
1716
AgentsToolChoiceOption = Union[str, str, "_models.AgentsToolChoiceOptionMode", "_models.AgentsNamedToolChoice"]
1817
AgentsResponseFormatOption = Union[

sdk/ai/azure-ai-agents/azure/ai/agents/_utils/model_base.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def _get_model(module_name: str, model_name: str):
346346

347347

348348
class _MyMutableMapping(MutableMapping[str, typing.Any]):
349-
def __init__(self, data: typing.Dict[str, typing.Any]) -> None:
349+
def __init__(self, data: dict[str, typing.Any]) -> None:
350350
self._data = data
351351

352352
def __contains__(self, key: typing.Any) -> bool:
@@ -426,7 +426,7 @@ def pop(self, key: str, default: typing.Any = _UNSET) -> typing.Any:
426426
return self._data.pop(key)
427427
return self._data.pop(key, default)
428428

429-
def popitem(self) -> typing.Tuple[str, typing.Any]:
429+
def popitem(self) -> tuple[str, typing.Any]:
430430
"""
431431
Removes and returns some (key, value) pair
432432
:returns: The (key, value) pair.
@@ -514,9 +514,7 @@ def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-m
514514
return o
515515

516516

517-
def _get_rest_field(
518-
attr_to_rest_field: typing.Dict[str, "_RestField"], rest_name: str
519-
) -> typing.Optional["_RestField"]:
517+
def _get_rest_field(attr_to_rest_field: dict[str, "_RestField"], rest_name: str) -> typing.Optional["_RestField"]:
520518
try:
521519
return next(rf for rf in attr_to_rest_field.values() if rf._rest_name == rest_name)
522520
except StopIteration:
@@ -539,7 +537,7 @@ class Model(_MyMutableMapping):
539537
_is_model = True
540538
# label whether current class's _attr_to_rest_field has been calculated
541539
# could not see _attr_to_rest_field directly because subclass inherits it from parent class
542-
_calculated: typing.Set[str] = set()
540+
_calculated: set[str] = set()
543541

544542
def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
545543
class_name = self.__class__.__name__
@@ -624,7 +622,7 @@ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self:
624622
# we know the last nine classes in mro are going to be 'Model', '_MyMutableMapping', 'MutableMapping',
625623
# 'Mapping', 'Collection', 'Sized', 'Iterable', 'Container' and 'object'
626624
mros = cls.__mro__[:-9][::-1] # ignore parents, and reverse the mro order
627-
attr_to_rest_field: typing.Dict[str, _RestField] = { # map attribute name to rest_field property
625+
attr_to_rest_field: dict[str, _RestField] = { # map attribute name to rest_field property
628626
k: v for mro_class in mros for k, v in mro_class.__dict__.items() if k[0] != "_" and hasattr(v, "_type")
629627
}
630628
annotations = {
@@ -639,7 +637,7 @@ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self:
639637
rf._type = rf._get_deserialize_callable_from_annotation(annotations.get(attr, None))
640638
if not rf._rest_name_input:
641639
rf._rest_name_input = attr
642-
cls._attr_to_rest_field: typing.Dict[str, _RestField] = dict(attr_to_rest_field.items())
640+
cls._attr_to_rest_field: dict[str, _RestField] = dict(attr_to_rest_field.items())
643641
cls._calculated.add(f"{cls.__module__}.{cls.__qualname__}")
644642

645643
return super().__new__(cls)
@@ -681,7 +679,7 @@ def _deserialize(cls, data, exist_discriminators):
681679
mapped_cls = cls.__mapping__.get(discriminator_value, cls) # pyright: ignore # pylint: disable=no-member
682680
return mapped_cls._deserialize(data, exist_discriminators)
683681

684-
def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]:
682+
def as_dict(self, *, exclude_readonly: bool = False) -> dict[str, typing.Any]:
685683
"""Return a dict that can be turned into json using json.dump.
686684
687685
:keyword bool exclude_readonly: Whether to remove the readonly properties.
@@ -741,7 +739,7 @@ def _deserialize_with_union(deserializers, obj):
741739
def _deserialize_dict(
742740
value_deserializer: typing.Optional[typing.Callable],
743741
module: typing.Optional[str],
744-
obj: typing.Dict[typing.Any, typing.Any],
742+
obj: dict[typing.Any, typing.Any],
745743
):
746744
if obj is None:
747745
return obj
@@ -751,7 +749,7 @@ def _deserialize_dict(
751749

752750

753751
def _deserialize_multiple_sequence(
754-
entry_deserializers: typing.List[typing.Optional[typing.Callable]],
752+
entry_deserializers: list[typing.Optional[typing.Callable]],
755753
module: typing.Optional[str],
756754
obj,
757755
):
@@ -772,14 +770,14 @@ def _deserialize_sequence(
772770
return type(obj)(_deserialize(deserializer, entry, module) for entry in obj)
773771

774772

775-
def _sorted_annotations(types: typing.List[typing.Any]) -> typing.List[typing.Any]:
773+
def _sorted_annotations(types: list[typing.Any]) -> list[typing.Any]:
776774
return sorted(
777775
types,
778776
key=lambda x: hasattr(x, "__name__") and x.__name__.lower() in ("str", "float", "int", "bool"),
779777
)
780778

781779

782-
def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-branches
780+
def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements, too-many-branches
783781
annotation: typing.Any,
784782
module: typing.Optional[str],
785783
rf: typing.Optional["_RestField"] = None,
@@ -844,7 +842,10 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur
844842
return functools.partial(_deserialize_with_union, deserializers)
845843

846844
try:
847-
if annotation._name == "Dict": # pyright: ignore
845+
annotation_name = (
846+
annotation.__name__ if hasattr(annotation, "__name__") else annotation._name # pyright: ignore
847+
)
848+
if annotation_name.lower() == "dict":
848849
value_deserializer = _get_deserialize_callable_from_annotation(
849850
annotation.__args__[1], module, rf # pyright: ignore
850851
)
@@ -857,7 +858,10 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur
857858
except (AttributeError, IndexError):
858859
pass
859860
try:
860-
if annotation._name in ["List", "Set", "Tuple", "Sequence"]: # pyright: ignore
861+
annotation_name = (
862+
annotation.__name__ if hasattr(annotation, "__name__") else annotation._name # pyright: ignore
863+
)
864+
if annotation_name.lower() in ["list", "set", "tuple", "sequence"]:
861865
if len(annotation.__args__) > 1: # pyright: ignore
862866
entry_deserializers = [
863867
_get_deserialize_callable_from_annotation(dt, module, rf)
@@ -975,11 +979,11 @@ def __init__(
975979
name: typing.Optional[str] = None,
976980
type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin
977981
is_discriminator: bool = False,
978-
visibility: typing.Optional[typing.List[str]] = None,
982+
visibility: typing.Optional[list[str]] = None,
979983
default: typing.Any = _UNSET,
980984
format: typing.Optional[str] = None,
981985
is_multipart_file_input: bool = False,
982-
xml: typing.Optional[typing.Dict[str, typing.Any]] = None,
986+
xml: typing.Optional[dict[str, typing.Any]] = None,
983987
):
984988
self._type = type
985989
self._rest_name_input = name
@@ -1037,11 +1041,11 @@ def rest_field(
10371041
*,
10381042
name: typing.Optional[str] = None,
10391043
type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin
1040-
visibility: typing.Optional[typing.List[str]] = None,
1044+
visibility: typing.Optional[list[str]] = None,
10411045
default: typing.Any = _UNSET,
10421046
format: typing.Optional[str] = None,
10431047
is_multipart_file_input: bool = False,
1044-
xml: typing.Optional[typing.Dict[str, typing.Any]] = None,
1048+
xml: typing.Optional[dict[str, typing.Any]] = None,
10451049
) -> typing.Any:
10461050
return _RestField(
10471051
name=name,
@@ -1058,8 +1062,8 @@ def rest_discriminator(
10581062
*,
10591063
name: typing.Optional[str] = None,
10601064
type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin
1061-
visibility: typing.Optional[typing.List[str]] = None,
1062-
xml: typing.Optional[typing.Dict[str, typing.Any]] = None,
1065+
visibility: typing.Optional[list[str]] = None,
1066+
xml: typing.Optional[dict[str, typing.Any]] = None,
10631067
) -> typing.Any:
10641068
return _RestField(name=name, type=type, is_discriminator=True, visibility=visibility, xml=xml)
10651069

@@ -1078,9 +1082,9 @@ def serialize_xml(model: Model, exclude_readonly: bool = False) -> str:
10781082
def _get_element(
10791083
o: typing.Any,
10801084
exclude_readonly: bool = False,
1081-
parent_meta: typing.Optional[typing.Dict[str, typing.Any]] = None,
1085+
parent_meta: typing.Optional[dict[str, typing.Any]] = None,
10821086
wrapped_element: typing.Optional[ET.Element] = None,
1083-
) -> typing.Union[ET.Element, typing.List[ET.Element]]:
1087+
) -> typing.Union[ET.Element, list[ET.Element]]:
10841088
if _is_model(o):
10851089
model_meta = getattr(o, "_xml", {})
10861090

@@ -1169,7 +1173,7 @@ def _get_element(
11691173
def _get_wrapped_element(
11701174
v: typing.Any,
11711175
exclude_readonly: bool,
1172-
meta: typing.Optional[typing.Dict[str, typing.Any]],
1176+
meta: typing.Optional[dict[str, typing.Any]],
11731177
) -> ET.Element:
11741178
wrapped_element = _create_xml_element(
11751179
meta.get("name") if meta else None, meta.get("prefix") if meta else None, meta.get("ns") if meta else None
@@ -1212,7 +1216,7 @@ def _deserialize_xml(
12121216
def _convert_element(e: ET.Element):
12131217
# dict case
12141218
if len(e.attrib) > 0 or len({child.tag for child in e}) > 1:
1215-
dict_result: typing.Dict[str, typing.Any] = {}
1219+
dict_result: dict[str, typing.Any] = {}
12161220
for child in e:
12171221
if dict_result.get(child.tag) is not None:
12181222
if isinstance(dict_result[child.tag], list):
@@ -1225,7 +1229,7 @@ def _convert_element(e: ET.Element):
12251229
return dict_result
12261230
# array case
12271231
if len(e) > 0:
1228-
array_result: typing.List[typing.Any] = []
1232+
array_result: list[typing.Any] = []
12291233
for child in e:
12301234
array_result.append(_convert_element(child))
12311235
return array_result

sdk/ai/azure-ai-agents/azure/ai/agents/_utils/serialization.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import sys
2222
import codecs
2323
from typing import (
24-
Dict,
2524
Any,
2625
cast,
2726
Optional,
@@ -31,7 +30,6 @@
3130
Mapping,
3231
Callable,
3332
MutableMapping,
34-
List,
3533
)
3634

3735
try:
@@ -229,12 +227,12 @@ class Model:
229227
serialization and deserialization.
230228
"""
231229

232-
_subtype_map: Dict[str, Dict[str, Any]] = {}
233-
_attribute_map: Dict[str, Dict[str, Any]] = {}
234-
_validation: Dict[str, Dict[str, Any]] = {}
230+
_subtype_map: dict[str, dict[str, Any]] = {}
231+
_attribute_map: dict[str, dict[str, Any]] = {}
232+
_validation: dict[str, dict[str, Any]] = {}
235233

236234
def __init__(self, **kwargs: Any) -> None:
237-
self.additional_properties: Optional[Dict[str, Any]] = {}
235+
self.additional_properties: Optional[dict[str, Any]] = {}
238236
for k in kwargs: # pylint: disable=consider-using-dict-items
239237
if k not in self._attribute_map:
240238
_LOGGER.warning("%s is not a known attribute of class %s and will be ignored", k, self.__class__)
@@ -311,7 +309,7 @@ def serialize(self, keep_readonly: bool = False, **kwargs: Any) -> JSON:
311309
def as_dict(
312310
self,
313311
keep_readonly: bool = True,
314-
key_transformer: Callable[[str, Dict[str, Any], Any], Any] = attribute_transformer,
312+
key_transformer: Callable[[str, dict[str, Any], Any], Any] = attribute_transformer,
315313
**kwargs: Any
316314
) -> JSON:
317315
"""Return a dict that can be serialized using json.dump.
@@ -380,7 +378,7 @@ def deserialize(cls, data: Any, content_type: Optional[str] = None) -> Self:
380378
def from_dict(
381379
cls,
382380
data: Any,
383-
key_extractors: Optional[Callable[[str, Dict[str, Any], Any], Any]] = None,
381+
key_extractors: Optional[Callable[[str, dict[str, Any], Any], Any]] = None,
384382
content_type: Optional[str] = None,
385383
) -> Self:
386384
"""Parse a dict using given key extractor return a model.
@@ -414,7 +412,7 @@ def _flatten_subtype(cls, key, objects):
414412
return {}
415413
result = dict(cls._subtype_map[key])
416414
for valuetype in cls._subtype_map[key].values():
417-
result.update(objects[valuetype]._flatten_subtype(key, objects)) # pylint: disable=protected-access
415+
result |= objects[valuetype]._flatten_subtype(key, objects) # pylint: disable=protected-access
418416
return result
419417

420418
@classmethod
@@ -528,7 +526,7 @@ def __init__(self, classes: Optional[Mapping[str, type]] = None) -> None:
528526
"[]": self.serialize_iter,
529527
"{}": self.serialize_dict,
530528
}
531-
self.dependencies: Dict[str, type] = dict(classes) if classes else {}
529+
self.dependencies: dict[str, type] = dict(classes) if classes else {}
532530
self.key_transformer = full_restapi_key_transformer
533531
self.client_side_validation = True
534532

@@ -579,7 +577,7 @@ def _serialize( # pylint: disable=too-many-nested-blocks, too-many-branches, to
579577

580578
if attr_name == "additional_properties" and attr_desc["key"] == "":
581579
if target_obj.additional_properties is not None:
582-
serialized.update(target_obj.additional_properties)
580+
serialized |= target_obj.additional_properties
583581
continue
584582
try:
585583

@@ -789,7 +787,7 @@ def serialize_data(self, data, data_type, **kwargs):
789787

790788
# If dependencies is empty, try with current data class
791789
# It has to be a subclass of Enum anyway
792-
enum_type = self.dependencies.get(data_type, data.__class__)
790+
enum_type = self.dependencies.get(data_type, cast(type, data.__class__))
793791
if issubclass(enum_type, Enum):
794792
return Serializer.serialize_enum(data, enum_obj=enum_type)
795793

@@ -1184,7 +1182,7 @@ def rest_key_extractor(attr, attr_desc, data): # pylint: disable=unused-argumen
11841182

11851183
while "." in key:
11861184
# Need the cast, as for some reasons "split" is typed as list[str | Any]
1187-
dict_keys = cast(List[str], _FLATTEN.split(key))
1185+
dict_keys = cast(list[str], _FLATTEN.split(key))
11881186
if len(dict_keys) == 1:
11891187
key = _decode_attribute_map_key(dict_keys[0])
11901188
break
@@ -1386,7 +1384,7 @@ def __init__(self, classes: Optional[Mapping[str, type]] = None) -> None:
13861384
"duration": (isodate.Duration, datetime.timedelta),
13871385
"iso-8601": (datetime.datetime),
13881386
}
1389-
self.dependencies: Dict[str, type] = dict(classes) if classes else {}
1387+
self.dependencies: dict[str, type] = dict(classes) if classes else {}
13901388
self.key_extractors = [rest_key_extractor, xml_key_extractor]
13911389
# Additional properties only works if the "rest_key_extractor" is used to
13921390
# extract the keys. Making it to work whatever the key extractor is too much

0 commit comments

Comments
 (0)