1
- from typing import Any , Callable , Dict , List , NoReturn , Optional , Union , cast
1
+ from typing import Callable , Dict , List , NoReturn , Optional , Union , cast
2
2
3
3
from ..language import (
4
4
DirectiveDefinitionNode ,
57
57
)
58
58
from .value_from_ast import value_from_ast
59
59
60
- TypeDefinitionsMap = Dict [str , TypeDefinitionNode ]
61
60
TypeResolver = Callable [[str ], GraphQLNamedType ]
62
61
63
62
__all__ = [
@@ -98,34 +97,39 @@ def build_ast_schema(
98
97
assert_valid_sdl (document_ast )
99
98
100
99
schema_def : Optional [SchemaDefinitionNode ] = None
101
- node_map : TypeDefinitionsMap = {}
100
+ type_defs : List [ TypeDefinitionNode ] = []
102
101
directive_defs : List [DirectiveDefinitionNode ] = []
103
102
append_directive_def = directive_defs .append
104
103
for def_ in document_ast .definitions :
105
104
if isinstance (def_ , SchemaDefinitionNode ):
106
105
schema_def = def_
107
106
elif isinstance (def_ , TypeDefinitionNode ):
108
107
def_ = cast (TypeDefinitionNode , def_ )
109
- node_map [ def_ . name . value ] = def_
108
+ type_defs . append ( def_ )
110
109
elif isinstance (def_ , DirectiveDefinitionNode ):
111
110
append_directive_def (def_ )
112
111
112
+ def resolve_type (type_name : str ) -> GraphQLNamedType :
113
+ type_ = type_map .get (type_name )
114
+ if not type :
115
+ raise TypeError (f"Type '{ type_name } ' not found in document." )
116
+ return type_
117
+
118
+ ast_builder = ASTDefinitionBuilder (
119
+ assume_valid = assume_valid , resolve_type = resolve_type
120
+ )
121
+
122
+ type_map = {node .name .value : ast_builder .build_type (node ) for node in type_defs }
123
+
113
124
if schema_def :
114
- operation_types : Dict [ OperationType , Any ] = get_operation_types (schema_def )
125
+ operation_types = get_operation_types (schema_def )
115
126
else :
116
127
operation_types = {
117
- OperationType .QUERY : node_map . get ( "Query" ) ,
118
- OperationType .MUTATION : node_map . get ( "Mutation" ) ,
119
- OperationType .SUBSCRIPTION : node_map . get ( "Subscription" ) ,
128
+ OperationType .QUERY : "Query" ,
129
+ OperationType .MUTATION : "Mutation" ,
130
+ OperationType .SUBSCRIPTION : "Subscription" ,
120
131
}
121
132
122
- def resolve_type (type_name : str ):
123
- raise TypeError (f"Type '{ type_name } ' not found in document." )
124
-
125
- ast_builder = ASTDefinitionBuilder (
126
- node_map , assume_valid = assume_valid , resolve_type = resolve_type
127
- )
128
-
129
133
directives = [
130
134
ast_builder .build_directive (directive_def ) for directive_def in directive_defs
131
135
]
@@ -138,35 +142,31 @@ def resolve_type(type_name: str):
138
142
if not any (directive .name == "deprecated" for directive in directives ):
139
143
directives .append (GraphQLDeprecatedDirective )
140
144
141
- # Note: While this could make early assertions to get the correctly typed values
142
- # below, that would throw immediately while type system validation with
143
- # `validate_schema()` will produce more actionable results.
144
145
query_type = operation_types .get (OperationType .QUERY )
145
146
mutation_type = operation_types .get (OperationType .MUTATION )
146
147
subscription_type = operation_types .get (OperationType .SUBSCRIPTION )
147
148
return GraphQLSchema (
148
- query = cast (GraphQLObjectType , ast_builder .build_type (query_type ))
149
- if query_type
150
- else None ,
151
- mutation = cast (GraphQLObjectType , ast_builder .build_type (mutation_type ))
149
+ # Note: While this could make early assertions to get the correctly
150
+ # typed values below, that would throw immediately while type system
151
+ # validation with `validate_schema()` will produce more actionable results.
152
+ query = cast (GraphQLObjectType , type_map .get (query_type )) if query_type else None ,
153
+ mutation = cast (GraphQLObjectType , type_map .get (mutation_type ))
152
154
if mutation_type
153
155
else None ,
154
- subscription = cast (GraphQLObjectType , ast_builder . build_type (subscription_type ))
156
+ subscription = cast (GraphQLObjectType , type_map . get (subscription_type ))
155
157
if subscription_type
156
158
else None ,
157
- types = [ ast_builder . build_type ( node ) for node in node_map .values ()] ,
159
+ types = list ( type_map .values ()) ,
158
160
directives = directives ,
159
161
ast_node = schema_def ,
160
162
assume_valid = assume_valid ,
161
163
)
162
164
163
165
164
- def get_operation_types (
165
- schema : SchemaDefinitionNode
166
- ) -> Dict [OperationType , NamedTypeNode ]:
167
- op_types : Dict [OperationType , NamedTypeNode ] = {}
166
+ def get_operation_types (schema : SchemaDefinitionNode ) -> Dict [OperationType , str ]:
167
+ op_types : Dict [OperationType , str ] = {}
168
168
for operation_type in schema .operation_types :
169
- op_types [operation_type .operation ] = operation_type .type
169
+ op_types [operation_type .operation ] = operation_type .type . name . value
170
170
return op_types
171
171
172
172
@@ -175,48 +175,34 @@ def default_type_resolver(type_name: str, *_args) -> NoReturn:
175
175
raise TypeError (f"Type '{ type_name } ' not found in document." )
176
176
177
177
178
+ std_type_map : Dict [str , Union [GraphQLNamedType , GraphQLObjectType ]] = {
179
+ ** specified_scalar_types ,
180
+ ** introspection_types ,
181
+ }
182
+
183
+
178
184
class ASTDefinitionBuilder :
179
185
def __init__ (
180
186
self ,
181
- type_definitions_map : TypeDefinitionsMap ,
182
187
assume_valid : bool = False ,
183
188
resolve_type : TypeResolver = default_type_resolver ,
184
189
) -> None :
185
- self ._type_definitions_map = type_definitions_map
186
190
self ._assume_valid = assume_valid
187
191
self ._resolve_type = resolve_type
188
- # Initialize to the GraphQL built in scalars and introspection types.
189
- self ._cache : Dict [str , GraphQLNamedType ] = {
190
- ** specified_scalar_types ,
191
- ** introspection_types ,
192
- }
193
192
194
- def build_type (
195
- self , node : Union [NamedTypeNode , TypeDefinitionNode ]
196
- ) -> GraphQLNamedType :
197
- type_name = node .name .value
198
- cache = self ._cache
199
- if type_name not in cache :
200
- if isinstance (node , NamedTypeNode ):
201
- def_node = self ._type_definitions_map .get (type_name )
202
- cache [type_name ] = (
203
- self ._make_schema_def (def_node )
204
- if def_node
205
- else self ._resolve_type (node .name .value )
206
- )
207
- else :
208
- cache [type_name ] = self ._make_schema_def (node )
209
- return cache [type_name ]
210
-
211
- def _build_wrapped_type (self , type_node : TypeNode ) -> GraphQLType :
212
- if isinstance (type_node , ListTypeNode ):
213
- return GraphQLList (self ._build_wrapped_type (type_node .type ))
214
- if isinstance (type_node , NonNullTypeNode ):
193
+ def get_named_type (self , node : NamedTypeNode ) -> GraphQLNamedType :
194
+ name = node .name .value
195
+ return std_type_map .get (name ) or self ._resolve_type (name )
196
+
197
+ def get_wrapped_type (self , node : TypeNode ) -> GraphQLType :
198
+ if isinstance (node , ListTypeNode ):
199
+ return GraphQLList (self .get_wrapped_type (node .type ))
200
+ if isinstance (node , NonNullTypeNode ):
215
201
return GraphQLNonNull (
216
202
# Note: GraphQLNonNull constructor validates this type
217
- cast (GraphQLNullableType , self ._build_wrapped_type ( type_node .type ))
203
+ cast (GraphQLNullableType , self .get_wrapped_type ( node .type ))
218
204
)
219
- return self .build_type (cast (NamedTypeNode , type_node ))
205
+ return self .get_named_type (cast (NamedTypeNode , node ))
220
206
221
207
def build_directive (self , directive : DirectiveDefinitionNode ) -> GraphQLDirective :
222
208
locations = [DirectiveLocation [node .value ] for node in directive .locations ]
@@ -235,7 +221,7 @@ def build_field(self, field: FieldDefinitionNode) -> GraphQLField:
235
221
# Note: While this could make assertions to get the correctly typed value, that
236
222
# would throw immediately while type system validation with `validate_schema()`
237
223
# will produce more actionable results.
238
- type_ = self ._build_wrapped_type (field .type )
224
+ type_ = self .get_wrapped_type (field .type )
239
225
type_ = cast (GraphQLOutputType , type_ )
240
226
return GraphQLField (
241
227
type_ = type_ ,
@@ -249,7 +235,7 @@ def build_arg(self, value: InputValueDefinitionNode) -> GraphQLArgument:
249
235
# Note: While this could make assertions to get the correctly typed value, that
250
236
# would throw immediately while type system validation with `validate_schema()`
251
237
# will produce more actionable results.
252
- type_ = self ._build_wrapped_type (value .type )
238
+ type_ = self .get_wrapped_type (value .type )
253
239
type_ = cast (GraphQLInputType , type_ )
254
240
return GraphQLArgument (
255
241
type_ = type_ ,
@@ -262,7 +248,7 @@ def build_input_field(self, value: InputValueDefinitionNode) -> GraphQLInputFiel
262
248
# Note: While this could make assertions to get the correctly typed value, that
263
249
# would throw immediately while type system validation with `validate_schema()`
264
250
# will produce more actionable results.
265
- type_ = self ._build_wrapped_type (value .type )
251
+ type_ = self .get_wrapped_type (value .type )
266
252
type_ = cast (GraphQLInputType , type_ )
267
253
return GraphQLInputField (
268
254
type_ = type_ ,
@@ -279,7 +265,11 @@ def build_enum_value(value: EnumValueDefinitionNode) -> GraphQLEnumValue:
279
265
ast_node = value ,
280
266
)
281
267
282
- def _make_schema_def (self , ast_node : TypeDefinitionNode ) -> GraphQLNamedType :
268
+ def build_type (self , ast_node : TypeDefinitionNode ) -> GraphQLNamedType :
269
+ name = ast_node .name .value
270
+ if name in std_type_map :
271
+ return std_type_map [name ]
272
+
283
273
method = {
284
274
"object_type_definition" : self ._make_type_def ,
285
275
"interface_type_definition" : self ._make_interface_def ,
@@ -289,6 +279,7 @@ def _make_schema_def(self, ast_node: TypeDefinitionNode) -> GraphQLNamedType:
289
279
"input_object_type_definition" : self ._make_input_object_def ,
290
280
}.get (ast_node .kind )
291
281
if not method :
282
+ # Not reachable. All possible type definition nodes have been considered.
292
283
raise TypeError (f"Type kind '{ ast_node .kind } ' not supported." )
293
284
return method (ast_node ) # type: ignore
294
285
@@ -302,7 +293,7 @@ def _make_type_def(self, ast_node: ObjectTypeDefinitionNode) -> GraphQLObjectTyp
302
293
interfaces = cast (
303
294
Thunk [GraphQLInterfaceList ],
304
295
(
305
- (lambda : [self .build_type (ref ) for ref in interface_nodes ])
296
+ (lambda : [self .get_named_type (ref ) for ref in interface_nodes ])
306
297
if interface_nodes
307
298
else []
308
299
),
@@ -373,7 +364,7 @@ def _make_union_def(self, type_def: UnionTypeDefinitionNode) -> GraphQLUnionType
373
364
# `validate_schema()` will get more actionable results.
374
365
types = cast (
375
366
Thunk [GraphQLTypeList ],
376
- (lambda : [self .build_type (ref ) for ref in type_nodes ])
367
+ (lambda : [self .get_named_type (ref ) for ref in type_nodes ])
377
368
if type_nodes
378
369
else [],
379
370
)
0 commit comments