Skip to content

Commit eaf6d1b

Browse files
committed
Removing _prepare_batch_request to avoid code duplication
1 parent 612c5bb commit eaf6d1b

File tree

3 files changed

+40
-97
lines changed

3 files changed

+40
-97
lines changed

gql/transport/aiohttp.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ async def connect(self) -> None:
127127

128128
# Adding custom parameters passed from init
129129
if self.client_session_args:
130-
client_session_args.update(self.client_session_args) # type: ignore
130+
client_session_args.update(self.client_session_args)
131131

132132
log.debug("Connecting transport")
133133

@@ -164,36 +164,22 @@ async def close(self) -> None:
164164

165165
self.session = None
166166

167-
def _prepare_batch_request(
168-
self,
169-
reqs: List[GraphQLRequest],
170-
extra_args: Optional[Dict[str, Any]] = None,
171-
) -> Dict[str, Any]:
172-
173-
payload = [req.payload for req in reqs]
174-
175-
post_args = {"json": payload}
176-
177-
# Log the payload
178-
if log.isEnabledFor(logging.DEBUG):
179-
log.debug(">>> %s", self.json_serialize(payload))
180-
181-
# Pass post_args to aiohttp post method
182-
if extra_args:
183-
post_args.update(extra_args)
184-
185-
return post_args
186-
187167
def _prepare_request(
188168
self,
189-
request: GraphQLRequest,
169+
request: Union[GraphQLRequest, List[GraphQLRequest]],
190170
extra_args: Optional[Dict[str, Any]] = None,
191171
upload_files: bool = False,
192172
) -> Dict[str, Any]:
193173

194-
payload = request.payload
174+
payload: Dict | List
175+
if isinstance(request, GraphQLRequest):
176+
payload = request.payload
177+
else:
178+
payload = [req.payload for req in request]
195179

196180
if upload_files:
181+
assert isinstance(payload, Dict)
182+
assert isinstance(request, GraphQLRequest)
197183
post_args = self._prepare_file_uploads(request, payload)
198184
else:
199185
post_args = {"json": payload}
@@ -416,7 +402,7 @@ async def execute_batch(
416402
if self.session is None:
417403
raise TransportClosed("Transport is not connected")
418404

419-
post_args = self._prepare_batch_request(
405+
post_args = self._prepare_request(
420406
reqs,
421407
extra_args,
422408
)

gql/transport/httpx.py

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,22 @@ def __init__(
5959

6060
def _prepare_request(
6161
self,
62-
req: GraphQLRequest,
62+
request: Union[GraphQLRequest, List[GraphQLRequest]],
63+
*,
6364
extra_args: Optional[Dict[str, Any]] = None,
6465
upload_files: bool = False,
6566
) -> Dict[str, Any]:
6667

67-
payload = req.payload
68+
payload: Dict | List
69+
if isinstance(request, GraphQLRequest):
70+
payload = request.payload
71+
else:
72+
payload = [req.payload for req in request]
6873

6974
if upload_files:
70-
post_args = self._prepare_file_uploads(req, payload)
75+
assert isinstance(payload, Dict)
76+
assert isinstance(request, GraphQLRequest)
77+
post_args = self._prepare_file_uploads(request, payload)
7178
else:
7279
post_args = {"json": payload}
7380

@@ -81,26 +88,6 @@ def _prepare_request(
8188

8289
return post_args
8390

84-
def _prepare_batch_request(
85-
self,
86-
reqs: List[GraphQLRequest],
87-
extra_args: Optional[Dict[str, Any]] = None,
88-
) -> Dict[str, Any]:
89-
90-
payload = [req.payload for req in reqs]
91-
92-
post_args = {"json": payload}
93-
94-
# Log the payload
95-
if log.isEnabledFor(logging.DEBUG):
96-
log.debug(">>> %s", self.json_serialize(payload))
97-
98-
# Pass post_args to httpx post method
99-
if extra_args:
100-
post_args.update(extra_args)
101-
102-
return post_args
103-
10491
def _prepare_file_uploads(
10592
self,
10693
request: GraphQLRequest,
@@ -244,7 +231,7 @@ def connect(self):
244231

245232
self.client = httpx.Client(**self.kwargs)
246233

247-
def execute( # type: ignore
234+
def execute(
248235
self,
249236
request: GraphQLRequest,
250237
*,
@@ -269,8 +256,8 @@ def execute( # type: ignore
269256

270257
post_args = self._prepare_request(
271258
request,
272-
extra_args,
273-
upload_files,
259+
extra_args=extra_args,
260+
upload_files=upload_files,
274261
)
275262

276263
try:
@@ -302,9 +289,9 @@ def execute_batch(
302289
if not self.client:
303290
raise TransportClosed("Transport is not connected")
304291

305-
post_args = self._prepare_batch_request(
292+
post_args = self._prepare_request(
306293
reqs,
307-
extra_args,
294+
extra_args=extra_args,
308295
)
309296

310297
response = self.client.post(self.url, **post_args)
@@ -361,8 +348,8 @@ async def execute(
361348

362349
post_args = self._prepare_request(
363350
request,
364-
extra_args,
365-
upload_files,
351+
extra_args=extra_args,
352+
upload_files=upload_files,
366353
)
367354

368355
try:
@@ -394,9 +381,9 @@ async def execute_batch(
394381
if not self.client:
395382
raise TransportClosed("Transport is not connected")
396383

397-
post_args = self._prepare_batch_request(
384+
post_args = self._prepare_request(
398385
reqs,
399-
extra_args,
386+
extra_args=extra_args,
400387
)
401388

402389
response = await self.client.post(self.url, **post_args)

gql/transport/requests.py

Lines changed: 11 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -137,50 +137,20 @@ def connect(self):
137137
else:
138138
raise TransportAlreadyConnected("Transport is already connected")
139139

140-
def _prepare_batch_request(
141-
self,
142-
reqs: List[GraphQLRequest],
143-
*,
144-
timeout: Optional[int] = None,
145-
extra_args: Optional[Dict[str, Any]] = None,
146-
) -> Dict[str, Any]:
147-
148-
payload = [req.payload for req in reqs]
149-
150-
post_args: Dict[str, Any] = {
151-
"headers": self.headers,
152-
"auth": self.auth,
153-
"cookies": self.cookies,
154-
"timeout": timeout or self.default_timeout,
155-
"verify": self.verify,
156-
}
157-
158-
data_key = "json" if self.use_json else "data"
159-
post_args[data_key] = payload
160-
161-
# Log the payload
162-
if log.isEnabledFor(logging.DEBUG):
163-
log.debug(">>> %s", self.json_serialize(payload))
164-
165-
# Pass kwargs to requests post method
166-
post_args.update(self.kwargs)
167-
168-
# Pass post_args to requests post method
169-
if extra_args:
170-
post_args.update(extra_args)
171-
172-
return post_args
173-
174140
def _prepare_request(
175141
self,
176-
request: GraphQLRequest,
142+
request: Union[GraphQLRequest, List[GraphQLRequest]],
177143
*,
178144
timeout: Optional[int] = None,
179145
extra_args: Optional[Dict[str, Any]] = None,
180146
upload_files: bool = False,
181147
) -> Dict[str, Any]:
182148

183-
payload = request.payload
149+
payload: Dict | List
150+
if isinstance(request, GraphQLRequest):
151+
payload = request.payload
152+
else:
153+
payload = [req.payload for req in request]
184154

185155
post_args: Dict[str, Any] = {
186156
"headers": self.headers,
@@ -191,6 +161,8 @@ def _prepare_request(
191161
}
192162

193163
if upload_files:
164+
assert isinstance(payload, Dict)
165+
assert isinstance(request, GraphQLRequest)
194166
post_args = self._prepare_file_uploads(
195167
request=request,
196168
payload=payload,
@@ -282,7 +254,7 @@ def _prepare_file_uploads(
282254

283255
return post_args
284256

285-
def execute( # type: ignore
257+
def execute(
286258
self,
287259
request: GraphQLRequest,
288260
timeout: Optional[int] = None,
@@ -316,9 +288,7 @@ def execute( # type: ignore
316288

317289
# Using the created session to perform requests
318290
try:
319-
response = self.session.request(
320-
self.method, self.url, **post_args # type: ignore
321-
)
291+
response = self.session.request(self.method, self.url, **post_args)
322292
finally:
323293
if upload_files:
324294
close_files(list(self.files.values()))
@@ -373,7 +343,7 @@ def execute_batch(
373343
if not self.session:
374344
raise TransportClosed("Transport is not connected")
375345

376-
post_args = self._prepare_batch_request(
346+
post_args = self._prepare_request(
377347
reqs,
378348
timeout=timeout,
379349
extra_args=extra_args,

0 commit comments

Comments
 (0)