Skip to content

Commit b47c922

Browse files
committed
Stricter typing in extend_schema
Replicates graphql/graphql-js@c9f968b
1 parent b4cf3c3 commit b47c922

File tree

1 file changed

+59
-28
lines changed

1 file changed

+59
-28
lines changed

src/graphql/utilities/extend_schema.py

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Mapping,
1010
Optional,
1111
Tuple,
12+
TypeVar,
1213
Union,
1314
cast,
1415
)
@@ -133,6 +134,38 @@ def extend_schema(
133134
)
134135

135136

137+
TEN = TypeVar("TEN", bound=TypeExtensionNode)
138+
139+
140+
class TypeExtensionsMap:
141+
"""Mappings from types to their extensions."""
142+
143+
scalar: DefaultDict[str, List[ScalarTypeExtensionNode]]
144+
object: DefaultDict[str, List[ObjectTypeExtensionNode]]
145+
interface: DefaultDict[str, List[InterfaceTypeExtensionNode]]
146+
union: DefaultDict[str, List[UnionTypeExtensionNode]]
147+
enum: DefaultDict[str, List[EnumTypeExtensionNode]]
148+
input_object: DefaultDict[str, List[InputObjectTypeExtensionNode]]
149+
150+
def __init__(self) -> None:
151+
self.scalar = defaultdict(list)
152+
self.object = defaultdict(list)
153+
self.interface = defaultdict(list)
154+
self.union = defaultdict(list)
155+
self.enum = defaultdict(list)
156+
self.input_object = defaultdict(list)
157+
158+
def for_node(self, node: TEN) -> DefaultDict[str, List[TEN]]:
159+
"""Get type extensions map for the given node kind."""
160+
kind = node.kind
161+
try:
162+
kind = kind.removesuffix("_type_extension")
163+
except AttributeError: # pragma: no cover (Python < 3.9)
164+
if kind.endswith("_type_extension"):
165+
kind = kind[:-15]
166+
return getattr(self, kind)
167+
168+
136169
class ExtendSchemaImpl:
137170
"""Helper class implementing the methods to extend a schema.
138171
@@ -143,11 +176,11 @@ class ExtendSchemaImpl:
143176
"""
144177

145178
type_map: Dict[str, GraphQLNamedType]
146-
type_extensions_map: Dict[str, Any]
179+
type_extensions: TypeExtensionsMap
147180

148-
def __init__(self, type_extensions_map: Dict[str, Any]):
181+
def __init__(self, type_extensions: TypeExtensionsMap):
149182
self.type_map = {}
150-
self.type_extensions_map = type_extensions_map
183+
self.type_extensions = type_extensions
151184

152185
@classmethod
153186
def extend_schema_args(
@@ -164,7 +197,8 @@ def extend_schema_args(
164197

165198
# Collect the type definitions and extensions found in the document.
166199
type_defs: List[TypeDefinitionNode] = []
167-
type_extensions_map: DefaultDict[str, Any] = defaultdict(list)
200+
201+
type_extensions = TypeExtensionsMap()
168202

169203
# New directives and types are separate because a directives and types can have
170204
# the same name. For example, a type named "skip".
@@ -174,31 +208,28 @@ def extend_schema_args(
174208
# Schema extensions are collected which may add additional operation types.
175209
schema_extensions: List[SchemaExtensionNode] = []
176210

211+
is_schema_changed = False
177212
for def_ in document_ast.definitions:
178213
if isinstance(def_, SchemaDefinitionNode):
179214
schema_def = def_
180215
elif isinstance(def_, SchemaExtensionNode):
181216
schema_extensions.append(def_)
217+
elif isinstance(def_, DirectiveDefinitionNode):
218+
directive_defs.append(def_)
182219
elif isinstance(def_, TypeDefinitionNode):
183220
type_defs.append(def_)
184221
elif isinstance(def_, TypeExtensionNode):
185-
extended_type_name = def_.name.value
186-
type_extensions_map[extended_type_name].append(def_)
187-
elif isinstance(def_, DirectiveDefinitionNode):
188-
directive_defs.append(def_)
222+
type_extensions.for_node(def_)[def_.name.value].append(def_)
223+
else:
224+
continue
225+
is_schema_changed = True
189226

190227
# If this document contains no new types, extensions, or directives then return
191228
# the same unmodified GraphQLSchema instance.
192-
if (
193-
not type_extensions_map
194-
and not type_defs
195-
and not directive_defs
196-
and not schema_extensions
197-
and not schema_def
198-
):
229+
if not is_schema_changed:
199230
return schema_kwargs
200231

201-
self = cls(type_extensions_map)
232+
self = cls(type_extensions)
202233
for existing_type in schema_kwargs["types"] or ():
203234
self.type_map[existing_type.name] = self.extend_named_type(existing_type)
204235
for type_node in type_defs:
@@ -311,7 +342,7 @@ def extend_input_object_type(
311342
type_: GraphQLInputObjectType,
312343
) -> GraphQLInputObjectType:
313344
kwargs = type_.to_kwargs()
314-
extensions = tuple(self.type_extensions_map[kwargs["name"]])
345+
extensions = tuple(self.type_extensions.input_object[kwargs["name"]])
315346

316347
return GraphQLInputObjectType(
317348
**merge_kwargs(
@@ -325,7 +356,7 @@ def extend_input_object_type(
325356

326357
def extend_enum_type(self, type_: GraphQLEnumType) -> GraphQLEnumType:
327358
kwargs = type_.to_kwargs()
328-
extensions = tuple(self.type_extensions_map[kwargs["name"]])
359+
extensions = tuple(self.type_extensions.enum[kwargs["name"]])
329360

330361
return GraphQLEnumType(
331362
**merge_kwargs(
@@ -337,7 +368,7 @@ def extend_enum_type(self, type_: GraphQLEnumType) -> GraphQLEnumType:
337368

338369
def extend_scalar_type(self, type_: GraphQLScalarType) -> GraphQLScalarType:
339370
kwargs = type_.to_kwargs()
340-
extensions = tuple(self.type_extensions_map[kwargs["name"]])
371+
extensions = tuple(self.type_extensions.scalar[kwargs["name"]])
341372

342373
specified_by_url = kwargs["specified_by_url"]
343374
for extension_node in extensions:
@@ -373,7 +404,7 @@ def extend_object_type_fields(
373404
# noinspection PyShadowingNames
374405
def extend_object_type(self, type_: GraphQLObjectType) -> GraphQLObjectType:
375406
kwargs = type_.to_kwargs()
376-
extensions = tuple(self.type_extensions_map[kwargs["name"]])
407+
extensions = tuple(self.type_extensions.object[kwargs["name"]])
377408

378409
return GraphQLObjectType(
379410
**merge_kwargs(
@@ -410,7 +441,7 @@ def extend_interface_type(
410441
self, type_: GraphQLInterfaceType
411442
) -> GraphQLInterfaceType:
412443
kwargs = type_.to_kwargs()
413-
extensions = tuple(self.type_extensions_map[kwargs["name"]])
444+
extensions = tuple(self.type_extensions.interface[kwargs["name"]])
414445

415446
return GraphQLInterfaceType(
416447
**merge_kwargs(
@@ -433,7 +464,7 @@ def extend_union_type_types(
433464

434465
def extend_union_type(self, type_: GraphQLUnionType) -> GraphQLUnionType:
435466
kwargs = type_.to_kwargs()
436-
extensions = tuple(self.type_extensions_map[kwargs["name"]])
467+
extensions = tuple(self.type_extensions.union[kwargs["name"]])
437468

438469
return GraphQLUnionType(
439470
**merge_kwargs(
@@ -626,7 +657,7 @@ def build_union_types(
626657
def build_object_type(
627658
self, ast_node: ObjectTypeDefinitionNode
628659
) -> GraphQLObjectType:
629-
extension_nodes = self.type_extensions_map[ast_node.name.value]
660+
extension_nodes = self.type_extensions.object[ast_node.name.value]
630661
all_nodes: List[Union[ObjectTypeDefinitionNode, ObjectTypeExtensionNode]] = [
631662
ast_node,
632663
*extension_nodes,
@@ -644,7 +675,7 @@ def build_interface_type(
644675
self,
645676
ast_node: InterfaceTypeDefinitionNode,
646677
) -> GraphQLInterfaceType:
647-
extension_nodes = self.type_extensions_map[ast_node.name.value]
678+
extension_nodes = self.type_extensions.interface[ast_node.name.value]
648679
all_nodes: List[
649680
Union[InterfaceTypeDefinitionNode, InterfaceTypeExtensionNode]
650681
] = [ast_node, *extension_nodes]
@@ -658,7 +689,7 @@ def build_interface_type(
658689
)
659690

660691
def build_enum_type(self, ast_node: EnumTypeDefinitionNode) -> GraphQLEnumType:
661-
extension_nodes = self.type_extensions_map[ast_node.name.value]
692+
extension_nodes = self.type_extensions.enum[ast_node.name.value]
662693
all_nodes: List[Union[EnumTypeDefinitionNode, EnumTypeExtensionNode]] = [
663694
ast_node,
664695
*extension_nodes,
@@ -672,7 +703,7 @@ def build_enum_type(self, ast_node: EnumTypeDefinitionNode) -> GraphQLEnumType:
672703
)
673704

674705
def build_union_type(self, ast_node: UnionTypeDefinitionNode) -> GraphQLUnionType:
675-
extension_nodes = self.type_extensions_map[ast_node.name.value]
706+
extension_nodes = self.type_extensions.union[ast_node.name.value]
676707
all_nodes: List[Union[UnionTypeDefinitionNode, UnionTypeExtensionNode]] = [
677708
ast_node,
678709
*extension_nodes,
@@ -688,7 +719,7 @@ def build_union_type(self, ast_node: UnionTypeDefinitionNode) -> GraphQLUnionTyp
688719
def build_scalar_type(
689720
self, ast_node: ScalarTypeDefinitionNode
690721
) -> GraphQLScalarType:
691-
extension_nodes = self.type_extensions_map[ast_node.name.value]
722+
extension_nodes = self.type_extensions.scalar[ast_node.name.value]
692723
return GraphQLScalarType(
693724
name=ast_node.name.value,
694725
description=ast_node.description.value if ast_node.description else None,
@@ -701,7 +732,7 @@ def build_input_object_type(
701732
self,
702733
ast_node: InputObjectTypeDefinitionNode,
703734
) -> GraphQLInputObjectType:
704-
extension_nodes = self.type_extensions_map[ast_node.name.value]
735+
extension_nodes = self.type_extensions.input_object[ast_node.name.value]
705736
all_nodes: List[
706737
Union[InputObjectTypeDefinitionNode, InputObjectTypeExtensionNode]
707738
] = [ast_node, *extension_nodes]

0 commit comments

Comments
 (0)