Skip to content

Commit 35403fa

Browse files
committed
Fix batch request content issues
1 parent ece5e3a commit 35403fa

File tree

1 file changed

+40
-24
lines changed

1 file changed

+40
-24
lines changed

src/msgraph_core/requests/batch_request_content.py

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from kiota_abstractions.serialization import Parsable, ParseNode
66
from kiota_abstractions.serialization import SerializationWriter
77

8+
from urllib.request import Request
9+
810
from .batch_request_item import BatchRequestItem
911

1012

@@ -15,27 +17,36 @@ class BatchRequestContent(Parsable):
1517

1618
MAX_REQUESTS = 20
1719

18-
def __init__(self, requests: Dict[str, Union['BatchRequestItem', 'RequestInformation']] = {}):
20+
def __init__(self, requests: Dict[str, Union[BatchRequestItem, RequestInformation]] = {}):
1921
"""
2022
Initializes a new instance of the BatchRequestContent class.
23+
Args:
24+
Requests (Dict[str, Union[BatchRequestItem, RequestInformation]]): The requests to add.
2125
"""
22-
self._requests: Dict[str, Union[BatchRequestItem, 'RequestInformation']] = requests or {}
26+
self._requests: Dict[str, BatchRequestItem] = {}
2327

2428
self.is_finalized = False
2529
for request_id, request in requests.items():
30+
if isinstance(request, RequestInformation):
31+
self.add_request_information(request, request_id)
32+
continue
2633
self.add_request(request_id, request)
2734

2835
@property
29-
def requests(self) -> Dict:
36+
def requests(self) -> Dict[str, BatchRequestItem]:
3037
"""
3138
Gets the requests.
39+
Returns:
40+
Dict[str, BatchRequestItem]: requests in the batch request content.
3241
"""
3342
return self._requests
3443

3544
@requests.setter
3645
def requests(self, requests: List[BatchRequestItem]) -> None:
3746
"""
3847
Sets the requests.
48+
Args:
49+
requests (List[BatchRequestItem]): The requests to set.
3950
"""
4051
if len(requests) >= BatchRequestContent.MAX_REQUESTS:
4152
raise ValueError(f"Maximum number of requests is {BatchRequestContent.MAX_REQUESTS}")
@@ -45,49 +56,54 @@ def requests(self, requests: List[BatchRequestItem]) -> None:
4556
def add_request(self, request_id: Optional[str], request: BatchRequestItem) -> None:
4657
"""
4758
Adds a request to the batch request content.
59+
Args:
60+
request_id (Optional[str]): The request id to add.
61+
request (BatchRequestItem): The request to add.
4862
"""
4963
if len(self.requests) >= BatchRequestContent.MAX_REQUESTS:
5064
raise RuntimeError(f"Maximum number of requests is {BatchRequestContent.MAX_REQUESTS}")
5165
if not request.id:
52-
request.id = str(uuid.uuid4())
66+
request.id = request_id if request_id else str(uuid.uuid4())
5367
if hasattr(request, 'depends_on') and request.depends_on:
5468
for dependent_id in request.depends_on:
55-
if dependent_id not in self.requests:
56-
dependent_request = self._request_by_id(dependent_id)
57-
if dependent_request:
58-
self._requests[dependent_id] = dependent_request
69+
if not self._request_by_id(dependent_id):
70+
raise ValueError(f"Request depends on request id: {dependent_id} which was not found in requests. Add request id: {dependent_id} first")
5971
self._requests[request.id] = request
6072

61-
def add_request_information(self, request_information: RequestInformation) -> None:
62-
"""
73+
def add_request_information(self, request_information: RequestInformation, request_id: Optional[str] = None) -> None:
74+
"""
6375
Adds a request to the batch request content.
6476
Args:
6577
request_information (RequestInformation): The request information to add.
78+
request_id: Optional[str]: The request id to add.
6679
"""
67-
request_id = str(uuid.uuid4())
80+
request_id = request_id if request_id else str(uuid.uuid4())
6881
self.add_request(request_id, BatchRequestItem(request_information))
6982

70-
def add_urllib_request(self, request) -> None:
83+
def add_urllib_request(self, request: Request, request_id: Optional[str] = None) -> None:
7184
"""
7285
Adds a request to the batch request content.
86+
Args:
87+
request (Request): The request to add.
88+
request_id: Optional[str]: The request id to add.
7389
"""
74-
request_id = str(uuid.uuid4())
90+
request_id = request_id if request_id else str(uuid.uuid4())
7591
self.add_request(request_id, BatchRequestItem.create_with_urllib_request(request))
7692

7793
def remove(self, request_id: str) -> None:
7894
"""
7995
Removes a request from the batch request content.
80-
Also removes the request from the depends_on list of
96+
Also removes the request from the depends_on list of
8197
other requests.
98+
Args:
99+
request_id (str): The request id to remove.
82100
"""
83-
request_to_remove = None
84-
for request in self.requests:
85-
if request.id == request_id:
86-
request_to_remove = request
87-
if hasattr(request, 'depends_on') and request.depends_on:
88-
if request_id in request.depends_on:
89-
request.depends_on.remove(request_id)
101+
request_to_remove = self._request_by_id(request_id)
90102
if request_to_remove:
103+
if hasattr(request_to_remove, 'depends_on') and request_to_remove.depends_on:
104+
for dependent_id in request_to_remove.depends_on:
105+
if self._request_by_id(dependent_id):
106+
del self._requests[dependent_id]
91107
del self._requests[request_to_remove.id]
92108
else:
93109
raise ValueError(f"Request ID {request_id} not found in requests.")
@@ -108,12 +124,12 @@ def finalize(self):
108124
def _request_by_id(self, request_id: str) -> Optional[BatchRequestItem]:
109125
"""
110126
Finds a request by its ID.
111-
127+
112128
Args:
113129
request_id (str): The ID of the request to find.
114130
115131
Returns:
116-
The request with the given ID, or None if not found.
132+
Optional[BatchRequestItem]: The request with the given ID, or None if not found.
117133
"""
118134
return self._requests.get(request_id)
119135

@@ -137,4 +153,4 @@ def serialize(self, writer: SerializationWriter) -> None:
137153
Args:
138154
writer: Serialization writer to use to serialize this model
139155
"""
140-
writer.write_collection_of_object_values("requests", self.requests)
156+
writer.write_collection_of_object_values("requests", list(self.requests.values()))

0 commit comments

Comments
 (0)