1
+ import json
2
+ import logging
1
3
from ssl import SSLContext
2
4
from typing import Any , AsyncGenerator , Dict , Optional , Union
3
5
8
10
from aiohttp .typedefs import LooseCookies , LooseHeaders
9
11
from graphql import DocumentNode , ExecutionResult , print_ast
10
12
13
+ from ..utils import extract_files
11
14
from .async_transport import AsyncTransport
12
15
from .exceptions import (
13
16
TransportAlreadyConnected ,
16
19
TransportServerError ,
17
20
)
18
21
22
+ log = logging .getLogger (__name__ )
23
+
19
24
20
25
class AIOHTTPTransport (AsyncTransport ):
21
26
""":ref:`Async Transport <async_transports>` to execute GraphQL queries
@@ -32,7 +37,7 @@ def __init__(
32
37
auth : Optional [BasicAuth ] = None ,
33
38
ssl : Union [SSLContext , bool , Fingerprint ] = False ,
34
39
timeout : Optional [int ] = None ,
35
- client_session_args : Dict [str , Any ] = {} ,
40
+ client_session_args : Optional [ Dict [str , Any ]] = None ,
36
41
) -> None :
37
42
"""Initialize the transport with the given aiohttp parameters.
38
43
@@ -54,7 +59,6 @@ def __init__(
54
59
self .ssl : Union [SSLContext , bool , Fingerprint ] = ssl
55
60
self .timeout : Optional [int ] = timeout
56
61
self .client_session_args = client_session_args
57
-
58
62
self .session : Optional [aiohttp .ClientSession ] = None
59
63
60
64
async def connect (self ) -> None :
@@ -81,7 +85,8 @@ async def connect(self) -> None:
81
85
)
82
86
83
87
# Adding custom parameters passed from init
84
- client_session_args .update (self .client_session_args )
88
+ if self .client_session_args :
89
+ client_session_args .update (self .client_session_args ) # type: ignore
85
90
86
91
self .session = aiohttp .ClientSession (** client_session_args )
87
92
@@ -104,7 +109,8 @@ async def execute(
104
109
document : DocumentNode ,
105
110
variable_values : Optional [Dict [str , str ]] = None ,
106
111
operation_name : Optional [str ] = None ,
107
- extra_args : Dict [str , Any ] = {},
112
+ extra_args : Dict [str , Any ] = None ,
113
+ upload_files : bool = False ,
108
114
) -> ExecutionResult :
109
115
"""Execute the provided document AST against the configured remote server
110
116
using the current session.
@@ -118,25 +124,70 @@ async def execute(
118
124
:param variables_values: An optional Dict of variable values
119
125
:param operation_name: An optional Operation name for the request
120
126
:param extra_args: additional arguments to send to the aiohttp post method
127
+ :param upload_files: Set to True if you want to put files in the variable values
121
128
:returns: an ExecutionResult object.
122
129
"""
123
130
124
131
query_str = print_ast (document )
132
+
125
133
payload : Dict [str , Any ] = {
126
134
"query" : query_str ,
127
135
}
128
136
129
- if variable_values :
130
- payload ["variables" ] = variable_values
131
137
if operation_name :
132
138
payload ["operationName" ] = operation_name
133
139
134
- post_args = {
135
- "json" : payload ,
136
- }
140
+ if upload_files :
141
+
142
+ # If the upload_files flag is set, then we need variable_values
143
+ assert variable_values is not None
144
+
145
+ # If we upload files, we will extract the files present in the
146
+ # variable_values dict and replace them by null values
147
+ nulled_variable_values , files = extract_files (variable_values )
148
+
149
+ # Save the nulled variable values in the payload
150
+ payload ["variables" ] = nulled_variable_values
151
+
152
+ # Prepare aiohttp to send multipart-encoded data
153
+ data = aiohttp .FormData ()
154
+
155
+ # Generate the file map
156
+ # path is nested in a list because the spec allows multiple pointers
157
+ # to the same file. But we don't support that.
158
+ # Will generate something like {"0": ["variables.file"]}
159
+ file_map = {str (i ): [path ] for i , path in enumerate (files )}
160
+
161
+ # Enumerate the file streams
162
+ # Will generate something like {'0': <_io.BufferedReader ...>}
163
+ file_streams = {str (i ): files [path ] for i , path in enumerate (files )}
164
+
165
+ # Add the payload to the operations field
166
+ operations_str = json .dumps (payload )
167
+ log .debug ("operations %s" , operations_str )
168
+ data .add_field (
169
+ "operations" , operations_str , content_type = "application/json"
170
+ )
171
+
172
+ # Add the file map field
173
+ file_map_str = json .dumps (file_map )
174
+ log .debug ("file_map %s" , file_map_str )
175
+ data .add_field ("map" , file_map_str , content_type = "application/json" )
176
+
177
+ # Add the extracted files as remaining fields
178
+ data .add_fields (* file_streams .items ())
179
+
180
+ post_args : Dict [str , Any ] = {"data" : data }
181
+
182
+ else :
183
+ if variable_values :
184
+ payload ["variables" ] = variable_values
185
+
186
+ post_args = {"json" : payload }
137
187
138
188
# Pass post_args to aiohttp post method
139
- post_args .update (extra_args )
189
+ if extra_args :
190
+ post_args .update (extra_args )
140
191
141
192
if self .session is None :
142
193
raise TransportClosed ("Transport is not connected" )
0 commit comments