Skip to content

Commit d3e6fc6

Browse files
committed
Simplified logic away from request
1 parent 1b94fc8 commit d3e6fc6

File tree

1 file changed

+68
-49
lines changed

1 file changed

+68
-49
lines changed

flask_graphql/graphqlview.py

Lines changed: 68 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ def render_graphiql(self, **kwargs):
7171

7272
def dispatch_request(self):
7373
try:
74-
if request.method.lower() not in ('get', 'post'):
74+
request_method = request.method.lower()
75+
if request_method not in ('get', 'post'):
7576
raise HttpError(
7677
405,
7778
'GraphQL only supports GET and POST requests.',
@@ -95,7 +96,15 @@ def dispatch_request(self):
9596
'Batch requests are not allowed.'
9697
)
9798

98-
responses = [self.get_response(request, entry, show_graphiql) for entry in data]
99+
only_allow_query = request_method == 'get'
100+
101+
responses = [self.get_response(
102+
self.execute,
103+
entry,
104+
show_graphiql,
105+
only_allow_query,
106+
) for entry in data]
107+
99108
response, status_codes = zip(*responses)
100109
status_code = max(status_codes)
101110

@@ -106,7 +115,7 @@ def dispatch_request(self):
106115
result = self.json_encode(response, pretty)
107116

108117
if show_graphiql:
109-
query, variables, operation_name, id = self.get_graphql_params(request, data[0])
118+
query, variables, operation_name, id = self.get_graphql_params(data[0])
110119
return self.render_graphiql(
111120
query=query,
112121
variables=variables,
@@ -130,15 +139,23 @@ def dispatch_request(self):
130139
content_type='application/json'
131140
)
132141

133-
def get_response(self, request, data, show_graphiql=False):
134-
query, variables, operation_name, id = self.get_graphql_params(request, data)
135-
execution_result = self.execute_graphql_request(
136-
data,
137-
query,
138-
variables,
139-
operation_name,
140-
show_graphiql
141-
)
142+
def get_response(self, execute, data, show_graphiql=False, only_allow_query=False):
143+
query, variables, operation_name, id = self.get_graphql_params(data)
144+
try:
145+
execution_result = self.execute_graphql_request(
146+
self.schema,
147+
execute,
148+
data,
149+
query,
150+
variables,
151+
operation_name,
152+
only_allow_query,
153+
)
154+
except HttpError:
155+
if show_graphiql:
156+
execution_result = None
157+
else:
158+
raise
142159
return self.format_execution_result(execution_result, id)
143160

144161
def format_execution_result(self, execution_result, id):
@@ -163,20 +180,11 @@ def format_execution_result(self, execution_result, id):
163180

164181
return response, status_code
165182

166-
@staticmethod
167-
def json_encode(data, pretty=False):
168-
if not pretty:
169-
return json.dumps(data, separators=(',', ':'))
170-
171-
return json.dumps(
172-
data,
173-
indent=2,
174-
separators=(',', ': ')
175-
)
176-
177183
# noinspection PyBroadException
178184
def parse_body(self, request):
179-
content_type = self.get_content_type(request)
185+
# We use mimetype here since we don't need the other
186+
# information provided by content_type
187+
content_type = request.mimetype
180188
if content_type == 'application/graphql':
181189
return {'query': request.data.decode()}
182190

@@ -197,19 +205,31 @@ def parse_body(self, request):
197205

198206
return {}
199207

200-
def execute(self, *args, **kwargs):
201-
return execute(self.schema, *args, **kwargs)
208+
def execute(self, schema, *args, **kwargs):
209+
root_value = self.get_root_value(request)
210+
context_value = self.get_context(request)
211+
middleware = self.get_middleware(request)
212+
executor = self.get_executor(request)
213+
214+
return execute(
215+
schema,
216+
*args,
217+
root_value=root_value,
218+
context_value=context_value,
219+
middleware=middleware,
220+
executor=executor,
221+
**kwargs
222+
)
202223

203-
def execute_graphql_request(self, data, query, variables, operation_name, show_graphiql=False):
224+
@staticmethod
225+
def execute_graphql_request(schema, execute, data, query, variables, operation_name, only_allow_query=False):
204226
if not query:
205-
if show_graphiql:
206-
return None
207227
raise HttpError(400, 'Must provide query string.')
208228

209229
try:
210230
source = Source(query, name='GraphQL request')
211231
ast = parse(source)
212-
validation_errors = validate(self.schema, ast)
232+
validation_errors = validate(schema, ast)
213233
if validation_errors:
214234
return ExecutionResult(
215235
errors=validation_errors,
@@ -218,11 +238,9 @@ def execute_graphql_request(self, data, query, variables, operation_name, show_g
218238
except Exception as e:
219239
return ExecutionResult(errors=[e], invalid=True)
220240

221-
if request.method.lower() == 'get':
241+
if only_allow_query:
222242
operation_ast = get_operation_ast(ast, operation_name)
223243
if operation_ast and operation_ast.operation != 'query':
224-
if show_graphiql:
225-
return None
226244
raise HttpError(
227245
405,
228246
'Can only perform a {} operation from a POST request.'.format(operation_ast.operation),
@@ -232,33 +250,40 @@ def execute_graphql_request(self, data, query, variables, operation_name, show_g
232250
)
233251

234252
try:
235-
return self.execute(
253+
return execute(
254+
schema,
236255
ast,
237-
root_value=self.get_root_value(request),
238-
variable_values=variables or {},
239256
operation_name=operation_name,
240-
context_value=self.get_context(request),
241-
middleware=self.get_middleware(request),
242-
executor=self.get_executor(request)
257+
variable_values=variables,
243258
)
244259
except Exception as e:
245260
return ExecutionResult(errors=[e], invalid=True)
246261

262+
@staticmethod
263+
def json_encode(data, pretty=False):
264+
if not pretty:
265+
return json.dumps(data, separators=(',', ':'))
266+
267+
return json.dumps(
268+
data,
269+
indent=2,
270+
separators=(',', ': ')
271+
)
272+
247273
@classmethod
248274
def can_display_graphiql(cls, data):
249-
raw = 'raw' in request.args or 'raw' in data
250-
return not raw and cls.request_wants_html(request)
275+
return 'raw' not in data and cls.request_wants_html()
251276

252277
@classmethod
253-
def request_wants_html(cls, request):
278+
def request_wants_html(cls):
254279
best = request.accept_mimetypes \
255280
.best_match(['application/json', 'text/html'])
256281
return best == 'text/html' and \
257282
request.accept_mimetypes[best] > \
258283
request.accept_mimetypes['application/json']
259284

260285
@staticmethod
261-
def get_graphql_params(request, data):
286+
def get_graphql_params(data):
262287
query = data.get('query')
263288
variables = data.get('variables')
264289
id = data.get('id')
@@ -279,9 +304,3 @@ def format_error(error):
279304
return format_graphql_error(error)
280305

281306
return {'message': six.text_type(error)}
282-
283-
@staticmethod
284-
def get_content_type(request):
285-
# We use mimetype here since we don't need the other
286-
# information provided by content_type
287-
return request.mimetype

0 commit comments

Comments
 (0)