9
9
Mapping ,
10
10
Optional ,
11
11
Tuple ,
12
+ TypeVar ,
12
13
Union ,
13
14
cast ,
14
15
)
@@ -133,6 +134,38 @@ def extend_schema(
133
134
)
134
135
135
136
137
+ TEN = TypeVar ("TEN" , bound = TypeExtensionNode )
138
+
139
+
140
+ class TypeExtensionsMap :
141
+ """Mappings from types to their extensions."""
142
+
143
+ scalar : DefaultDict [str , List [ScalarTypeExtensionNode ]]
144
+ object : DefaultDict [str , List [ObjectTypeExtensionNode ]]
145
+ interface : DefaultDict [str , List [InterfaceTypeExtensionNode ]]
146
+ union : DefaultDict [str , List [UnionTypeExtensionNode ]]
147
+ enum : DefaultDict [str , List [EnumTypeExtensionNode ]]
148
+ input_object : DefaultDict [str , List [InputObjectTypeExtensionNode ]]
149
+
150
+ def __init__ (self ) -> None :
151
+ self .scalar = defaultdict (list )
152
+ self .object = defaultdict (list )
153
+ self .interface = defaultdict (list )
154
+ self .union = defaultdict (list )
155
+ self .enum = defaultdict (list )
156
+ self .input_object = defaultdict (list )
157
+
158
+ def for_node (self , node : TEN ) -> DefaultDict [str , List [TEN ]]:
159
+ """Get type extensions map for the given node kind."""
160
+ kind = node .kind
161
+ try :
162
+ kind = kind .removesuffix ("_type_extension" )
163
+ except AttributeError : # pragma: no cover (Python < 3.9)
164
+ if kind .endswith ("_type_extension" ):
165
+ kind = kind [:- 15 ]
166
+ return getattr (self , kind )
167
+
168
+
136
169
class ExtendSchemaImpl :
137
170
"""Helper class implementing the methods to extend a schema.
138
171
@@ -143,11 +176,11 @@ class ExtendSchemaImpl:
143
176
"""
144
177
145
178
type_map : Dict [str , GraphQLNamedType ]
146
- type_extensions_map : Dict [ str , Any ]
179
+ type_extensions : TypeExtensionsMap
147
180
148
- def __init__ (self , type_extensions_map : Dict [ str , Any ] ):
181
+ def __init__ (self , type_extensions : TypeExtensionsMap ):
149
182
self .type_map = {}
150
- self .type_extensions_map = type_extensions_map
183
+ self .type_extensions = type_extensions
151
184
152
185
@classmethod
153
186
def extend_schema_args (
@@ -164,7 +197,8 @@ def extend_schema_args(
164
197
165
198
# Collect the type definitions and extensions found in the document.
166
199
type_defs : List [TypeDefinitionNode ] = []
167
- type_extensions_map : DefaultDict [str , Any ] = defaultdict (list )
200
+
201
+ type_extensions = TypeExtensionsMap ()
168
202
169
203
# New directives and types are separate because a directives and types can have
170
204
# the same name. For example, a type named "skip".
@@ -174,31 +208,28 @@ def extend_schema_args(
174
208
# Schema extensions are collected which may add additional operation types.
175
209
schema_extensions : List [SchemaExtensionNode ] = []
176
210
211
+ is_schema_changed = False
177
212
for def_ in document_ast .definitions :
178
213
if isinstance (def_ , SchemaDefinitionNode ):
179
214
schema_def = def_
180
215
elif isinstance (def_ , SchemaExtensionNode ):
181
216
schema_extensions .append (def_ )
217
+ elif isinstance (def_ , DirectiveDefinitionNode ):
218
+ directive_defs .append (def_ )
182
219
elif isinstance (def_ , TypeDefinitionNode ):
183
220
type_defs .append (def_ )
184
221
elif isinstance (def_ , TypeExtensionNode ):
185
- extended_type_name = def_ .name .value
186
- type_extensions_map [ extended_type_name ]. append ( def_ )
187
- elif isinstance ( def_ , DirectiveDefinitionNode ):
188
- directive_defs . append ( def_ )
222
+ type_extensions . for_node ( def_ )[ def_ .name .value ]. append ( def_ )
223
+ else :
224
+ continue
225
+ is_schema_changed = True
189
226
190
227
# If this document contains no new types, extensions, or directives then return
191
228
# the same unmodified GraphQLSchema instance.
192
- if (
193
- not type_extensions_map
194
- and not type_defs
195
- and not directive_defs
196
- and not schema_extensions
197
- and not schema_def
198
- ):
229
+ if not is_schema_changed :
199
230
return schema_kwargs
200
231
201
- self = cls (type_extensions_map )
232
+ self = cls (type_extensions )
202
233
for existing_type in schema_kwargs ["types" ] or ():
203
234
self .type_map [existing_type .name ] = self .extend_named_type (existing_type )
204
235
for type_node in type_defs :
@@ -311,7 +342,7 @@ def extend_input_object_type(
311
342
type_ : GraphQLInputObjectType ,
312
343
) -> GraphQLInputObjectType :
313
344
kwargs = type_ .to_kwargs ()
314
- extensions = tuple (self .type_extensions_map [kwargs ["name" ]])
345
+ extensions = tuple (self .type_extensions . input_object [kwargs ["name" ]])
315
346
316
347
return GraphQLInputObjectType (
317
348
** merge_kwargs (
@@ -325,7 +356,7 @@ def extend_input_object_type(
325
356
326
357
def extend_enum_type (self , type_ : GraphQLEnumType ) -> GraphQLEnumType :
327
358
kwargs = type_ .to_kwargs ()
328
- extensions = tuple (self .type_extensions_map [kwargs ["name" ]])
359
+ extensions = tuple (self .type_extensions . enum [kwargs ["name" ]])
329
360
330
361
return GraphQLEnumType (
331
362
** merge_kwargs (
@@ -337,7 +368,7 @@ def extend_enum_type(self, type_: GraphQLEnumType) -> GraphQLEnumType:
337
368
338
369
def extend_scalar_type (self , type_ : GraphQLScalarType ) -> GraphQLScalarType :
339
370
kwargs = type_ .to_kwargs ()
340
- extensions = tuple (self .type_extensions_map [kwargs ["name" ]])
371
+ extensions = tuple (self .type_extensions . scalar [kwargs ["name" ]])
341
372
342
373
specified_by_url = kwargs ["specified_by_url" ]
343
374
for extension_node in extensions :
@@ -373,7 +404,7 @@ def extend_object_type_fields(
373
404
# noinspection PyShadowingNames
374
405
def extend_object_type (self , type_ : GraphQLObjectType ) -> GraphQLObjectType :
375
406
kwargs = type_ .to_kwargs ()
376
- extensions = tuple (self .type_extensions_map [kwargs ["name" ]])
407
+ extensions = tuple (self .type_extensions . object [kwargs ["name" ]])
377
408
378
409
return GraphQLObjectType (
379
410
** merge_kwargs (
@@ -410,7 +441,7 @@ def extend_interface_type(
410
441
self , type_ : GraphQLInterfaceType
411
442
) -> GraphQLInterfaceType :
412
443
kwargs = type_ .to_kwargs ()
413
- extensions = tuple (self .type_extensions_map [kwargs ["name" ]])
444
+ extensions = tuple (self .type_extensions . interface [kwargs ["name" ]])
414
445
415
446
return GraphQLInterfaceType (
416
447
** merge_kwargs (
@@ -433,7 +464,7 @@ def extend_union_type_types(
433
464
434
465
def extend_union_type (self , type_ : GraphQLUnionType ) -> GraphQLUnionType :
435
466
kwargs = type_ .to_kwargs ()
436
- extensions = tuple (self .type_extensions_map [kwargs ["name" ]])
467
+ extensions = tuple (self .type_extensions . union [kwargs ["name" ]])
437
468
438
469
return GraphQLUnionType (
439
470
** merge_kwargs (
@@ -626,7 +657,7 @@ def build_union_types(
626
657
def build_object_type (
627
658
self , ast_node : ObjectTypeDefinitionNode
628
659
) -> GraphQLObjectType :
629
- extension_nodes = self .type_extensions_map [ast_node .name .value ]
660
+ extension_nodes = self .type_extensions . object [ast_node .name .value ]
630
661
all_nodes : List [Union [ObjectTypeDefinitionNode , ObjectTypeExtensionNode ]] = [
631
662
ast_node ,
632
663
* extension_nodes ,
@@ -644,7 +675,7 @@ def build_interface_type(
644
675
self ,
645
676
ast_node : InterfaceTypeDefinitionNode ,
646
677
) -> GraphQLInterfaceType :
647
- extension_nodes = self .type_extensions_map [ast_node .name .value ]
678
+ extension_nodes = self .type_extensions . interface [ast_node .name .value ]
648
679
all_nodes : List [
649
680
Union [InterfaceTypeDefinitionNode , InterfaceTypeExtensionNode ]
650
681
] = [ast_node , * extension_nodes ]
@@ -658,7 +689,7 @@ def build_interface_type(
658
689
)
659
690
660
691
def build_enum_type (self , ast_node : EnumTypeDefinitionNode ) -> GraphQLEnumType :
661
- extension_nodes = self .type_extensions_map [ast_node .name .value ]
692
+ extension_nodes = self .type_extensions . enum [ast_node .name .value ]
662
693
all_nodes : List [Union [EnumTypeDefinitionNode , EnumTypeExtensionNode ]] = [
663
694
ast_node ,
664
695
* extension_nodes ,
@@ -672,7 +703,7 @@ def build_enum_type(self, ast_node: EnumTypeDefinitionNode) -> GraphQLEnumType:
672
703
)
673
704
674
705
def build_union_type (self , ast_node : UnionTypeDefinitionNode ) -> GraphQLUnionType :
675
- extension_nodes = self .type_extensions_map [ast_node .name .value ]
706
+ extension_nodes = self .type_extensions . union [ast_node .name .value ]
676
707
all_nodes : List [Union [UnionTypeDefinitionNode , UnionTypeExtensionNode ]] = [
677
708
ast_node ,
678
709
* extension_nodes ,
@@ -688,7 +719,7 @@ def build_union_type(self, ast_node: UnionTypeDefinitionNode) -> GraphQLUnionTyp
688
719
def build_scalar_type (
689
720
self , ast_node : ScalarTypeDefinitionNode
690
721
) -> GraphQLScalarType :
691
- extension_nodes = self .type_extensions_map [ast_node .name .value ]
722
+ extension_nodes = self .type_extensions . scalar [ast_node .name .value ]
692
723
return GraphQLScalarType (
693
724
name = ast_node .name .value ,
694
725
description = ast_node .description .value if ast_node .description else None ,
@@ -701,7 +732,7 @@ def build_input_object_type(
701
732
self ,
702
733
ast_node : InputObjectTypeDefinitionNode ,
703
734
) -> GraphQLInputObjectType :
704
- extension_nodes = self .type_extensions_map [ast_node .name .value ]
735
+ extension_nodes = self .type_extensions . input_object [ast_node .name .value ]
705
736
all_nodes : List [
706
737
Union [InputObjectTypeDefinitionNode , InputObjectTypeExtensionNode ]
707
738
] = [ast_node , * extension_nodes ]
0 commit comments