1
1
# -*- coding: utf-8 -*-
2
2
from ..error import GraphQLError
3
3
from ..language import ast
4
+ from ..pyutils .default_ordered_dict import DefaultOrderedDict
4
5
from ..type .definition import GraphQLInterfaceType , GraphQLUnionType
5
6
from ..type .directives import GraphQLIncludeDirective , GraphQLSkipDirective
6
7
from ..type .introspection import (SchemaMetaFieldDef , TypeMetaFieldDef ,
@@ -18,9 +19,9 @@ class ExecutionContext(object):
18
19
and the fragments defined in the query document"""
19
20
20
21
__slots__ = 'schema' , 'fragments' , 'root_value' , 'operation' , 'variable_values' , 'errors' , 'context_value' , \
21
- 'argument_values_cache' , 'executor'
22
+ 'argument_values_cache' , 'executor' , 'middleware' , '_subfields_cache'
22
23
23
- def __init__ (self , schema , document_ast , root_value , context_value , variable_values , operation_name , executor ):
24
+ def __init__ (self , schema , document_ast , root_value , context_value , variable_values , operation_name , executor , middleware ):
24
25
"""Constructs a ExecutionContext object from the arguments passed
25
26
to execute, which we will pass throughout the other execution
26
27
methods."""
@@ -63,6 +64,13 @@ def __init__(self, schema, document_ast, root_value, context_value, variable_val
63
64
self .context_value = context_value
64
65
self .argument_values_cache = {}
65
66
self .executor = executor
67
+ self .middleware = middleware
68
+ self ._subfields_cache = {}
69
+
70
+ def get_field_resolver (self , field_resolver ):
71
+ if not self .middleware :
72
+ return field_resolver
73
+ return self .middleware .get_field_resolver (field_resolver )
66
74
67
75
def get_argument_values (self , field_def , field_ast ):
68
76
k = field_def , field_ast
@@ -74,6 +82,21 @@ def get_argument_values(self, field_def, field_ast):
74
82
75
83
return result
76
84
85
+ def get_sub_fields (self , return_type , field_asts ):
86
+ k = return_type , tuple (field_asts )
87
+ if k not in self ._subfields_cache :
88
+ subfield_asts = DefaultOrderedDict (list )
89
+ visited_fragment_names = set ()
90
+ for field_ast in field_asts :
91
+ selection_set = field_ast .selection_set
92
+ if selection_set :
93
+ subfield_asts = collect_fields (
94
+ self , return_type , selection_set ,
95
+ subfield_asts , visited_fragment_names
96
+ )
97
+ self ._subfields_cache [k ] = subfield_asts
98
+ return self ._subfields_cache [k ]
99
+
77
100
78
101
class ExecutionResult (object ):
79
102
"""The result of execution. `data` is the result of executing the
@@ -245,6 +268,8 @@ def get_field_entry_key(node):
245
268
246
269
247
270
class ResolveInfo (object ):
271
+ __slots__ = ('field_name' , 'field_asts' , 'return_type' , 'parent_type' ,
272
+ 'schema' , 'fragments' , 'root_value' , 'operation' , 'variable_values' )
248
273
249
274
def __init__ (self , field_name , field_asts , return_type , parent_type ,
250
275
schema , fragments , root_value , operation , variable_values ):
@@ -277,10 +302,10 @@ def get_field_def(schema, parent_type, field_name):
277
302
are allowed, like on a Union. __schema could get automatically
278
303
added to the query type, but that would require mutating type
279
304
definitions, which would cause issues."""
280
- if field_name == SchemaMetaFieldDef . name and schema .get_query_type () == parent_type :
305
+ if field_name == '__schema' and schema .get_query_type () == parent_type :
281
306
return SchemaMetaFieldDef
282
- elif field_name == TypeMetaFieldDef . name and schema .get_query_type () == parent_type :
307
+ elif field_name == '__type' and schema .get_query_type () == parent_type :
283
308
return TypeMetaFieldDef
284
- elif field_name == TypeNameMetaFieldDef . name :
309
+ elif field_name == '__typename' :
285
310
return TypeNameMetaFieldDef
286
- return parent_type .get_fields () .get (field_name )
311
+ return parent_type .fields .get (field_name )
0 commit comments