Skip to content

Commit 3092887

Browse files
committed
Refactor extend_schema implementation
Replicates graphql/graphql-js@ba7b8c1
1 parent a7fe66f commit 3092887

File tree

1 file changed

+136
-119
lines changed

1 file changed

+136
-119
lines changed

graphql/utilities/extend_schema.py

Lines changed: 136 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from collections import defaultdict
2-
from functools import partial
32
from itertools import chain
4-
from typing import Any, Callable, Dict, List, Optional, Union, Tuple, cast
3+
from typing import Any, Dict, List, Optional, Tuple, cast
54

65
from ..language import (
76
DirectiveDefinitionNode,
@@ -14,11 +13,9 @@
1413
)
1514
from ..type import (
1615
GraphQLArgument,
17-
GraphQLArgumentMap,
1816
GraphQLDirective,
1917
GraphQLEnumType,
2018
GraphQLField,
21-
GraphQLFieldMap,
2219
GraphQLInputField,
2320
GraphQLInputObjectType,
2421
GraphQLInterfaceType,
@@ -165,141 +162,161 @@ def extend_named_type(type_: GraphQLNamedType) -> GraphQLNamedType:
165162

166163
def extend_directive(directive: GraphQLDirective) -> GraphQLDirective:
167164
kwargs = directive.to_kwargs()
168-
kwargs["args"] = extend_args(kwargs["args"])
169-
return GraphQLDirective(**kwargs)
165+
return GraphQLDirective( # type: ignore
166+
**{
167+
**kwargs,
168+
"args": {name: extend_arg(arg) for name, arg in kwargs["args"].items()},
169+
}
170+
)
170171

171172
def extend_input_object_type(
172173
type_: GraphQLInputObjectType
173174
) -> GraphQLInputObjectType:
174175
kwargs = type_.to_kwargs()
175-
extension_nodes = type_extensions_map.get(kwargs["name"], ())
176-
177-
def fields_thunk():
178-
fields = {}
179-
for field_name, field in type_.fields.items():
180-
field_kwargs = field.to_kwargs()
181-
field_kwargs["type_"] = extend_type(field_kwargs["type_"])
182-
fields[field_name] = GraphQLInputField(**field_kwargs)
183-
184-
# If there are any extensions to the fields, apply those here.
185-
for extension in extension_nodes:
186-
for field in extension.fields:
187-
fields[field.name.value] = ast_builder.build_input_field(field)
188-
189-
return fields
190-
191-
kwargs["fields"] = fields_thunk
192-
kwargs["extension_ast_nodes"] += tuple(extension_nodes)
193-
return GraphQLInputObjectType(**kwargs)
176+
extensions = type_extensions_map.get(kwargs["name"], [])
177+
field_nodes = chain.from_iterable(node.fields or [] for node in extensions)
178+
179+
return GraphQLInputObjectType(
180+
**{
181+
**kwargs,
182+
"fields": lambda: {
183+
**{
184+
name: GraphQLInputField( # type: ignore
185+
**{**field.to_kwargs(), "type_": extend_type(field.type)}
186+
)
187+
for name, field in kwargs["fields"].items()
188+
},
189+
**{
190+
field.name.value: ast_builder.build_input_field(field)
191+
for field in field_nodes
192+
},
193+
},
194+
"extension_ast_nodes": kwargs["extension_ast_nodes"]
195+
+ tuple(extensions),
196+
}
197+
)
194198

195199
def extend_enum_type(type_: GraphQLEnumType) -> GraphQLEnumType:
196200
kwargs = type_.to_kwargs()
197-
extension_nodes = type_extensions_map.get(kwargs["name"], ())
198-
values = kwargs["values"]
199-
200-
# If there are any extensions to the values, apply those here.
201-
for extension in extension_nodes:
202-
for value in extension.values:
203-
values[value.name.value] = ast_builder.build_enum_value(value)
204-
205-
kwargs["extension_ast_nodes"] += tuple(extension_nodes)
206-
return GraphQLEnumType(**kwargs)
201+
extensions = type_extensions_map.get(kwargs["name"], [])
202+
value_nodes = chain.from_iterable(node.values or [] for node in extensions)
203+
204+
return GraphQLEnumType(
205+
**{
206+
**kwargs,
207+
"values": {
208+
**kwargs["values"],
209+
**{
210+
value.name.value: ast_builder.build_enum_value(value)
211+
for value in value_nodes
212+
},
213+
},
214+
"extension_ast_nodes": kwargs["extension_ast_nodes"]
215+
+ tuple(extensions),
216+
}
217+
)
207218

208219
def extend_scalar_type(type_: GraphQLScalarType) -> GraphQLScalarType:
209220
kwargs = type_.to_kwargs()
210-
extension_nodes = type_extensions_map.get(kwargs["name"], ())
211-
212-
kwargs["extension_ast_nodes"] += tuple(extension_nodes)
213-
return GraphQLScalarType(**kwargs)
221+
extensions = type_extensions_map.get(kwargs["name"], [])
222+
223+
return GraphQLScalarType(
224+
**{
225+
**kwargs,
226+
"extension_ast_nodes": kwargs["extension_ast_nodes"]
227+
+ tuple(extensions),
228+
}
229+
)
214230

215231
def extend_object_type(type_: GraphQLObjectType) -> GraphQLObjectType:
216232
kwargs = type_.to_kwargs()
217-
extension_nodes = type_extensions_map.get(kwargs["name"], ())
218-
219-
def interfaces_thunk():
220-
interfaces: List[GraphQLInterfaceType] = list(
221-
map(
222-
cast(
223-
Callable[[GraphQLNamedType], GraphQLInterfaceType],
224-
extend_named_type,
225-
),
226-
type_.interfaces,
227-
)
228-
)
229-
230-
# If there are any extensions to the interfaces, apply those here.
231-
for extension in type_extensions_map[type_.name]:
232-
for named_type in extension.interfaces:
233-
# Note: While this could make early assertions to get the correctly
234-
# typed values, that would throw immediately while type system
235-
# validation with `validate_schema()` will produce more actionable
236-
# results.
237-
interfaces.append(
238-
cast(GraphQLInterfaceType, build_type(named_type))
239-
)
240-
241-
return interfaces
242-
243-
kwargs["interfaces"] = interfaces_thunk
244-
kwargs["fields"] = partial(extend_field_map, type_)
245-
kwargs["extension_ast_nodes"] += tuple(extension_nodes)
246-
return GraphQLObjectType(**kwargs)
233+
extensions = type_extensions_map.get(kwargs["name"], [])
234+
interface_nodes = chain.from_iterable(
235+
node.interfaces or [] for node in extensions
236+
)
237+
field_nodes = chain.from_iterable(node.fields or [] for node in extensions)
238+
239+
return GraphQLObjectType(
240+
**{
241+
**kwargs,
242+
"interfaces": lambda: [
243+
extend_named_type(interface) for interface in kwargs["interfaces"]
244+
]
245+
# Note: While this could make early assertions to get the correctly
246+
# typed values, that would throw immediately while type system
247+
# validation with validate_schema will produce more actionable results.
248+
+ [build_type(node) for node in interface_nodes],
249+
"fields": lambda: {
250+
**{
251+
name: extend_field(field)
252+
for name, field in kwargs["fields"].items()
253+
},
254+
**{
255+
node.name.value: ast_builder.build_field(node)
256+
for node in field_nodes
257+
},
258+
},
259+
"extension_ast_nodes": kwargs["extension_ast_nodes"]
260+
+ tuple(extensions),
261+
}
262+
)
247263

248264
def extend_interface_type(type_: GraphQLInterfaceType) -> GraphQLInterfaceType:
249265
kwargs = type_.to_kwargs()
250-
extension_nodes = type_extensions_map.get(kwargs["name"], ())
251-
252-
kwargs["fields"] = partial(extend_field_map, type_)
253-
kwargs["extension_ast_nodes"] += tuple(extension_nodes)
254-
return GraphQLInterfaceType(**kwargs)
266+
extensions = type_extensions_map.get(kwargs["name"], [])
267+
field_nodes = chain.from_iterable(node.fields or [] for node in extensions)
268+
269+
return GraphQLInterfaceType(
270+
**{
271+
**kwargs,
272+
"fields": lambda: {
273+
**{
274+
name: extend_field(field)
275+
for name, field in kwargs["fields"].items()
276+
},
277+
**{
278+
node.name.value: ast_builder.build_field(node)
279+
for node in field_nodes
280+
},
281+
},
282+
"extension_ast_nodes": kwargs["extension_ast_nodes"]
283+
+ tuple(extensions),
284+
}
285+
)
255286

256287
def extend_union_type(type_: GraphQLUnionType) -> GraphQLUnionType:
257288
kwargs = type_.to_kwargs()
258-
extension_nodes = type_extensions_map.get(kwargs["name"], ())
259-
260-
def types_thunk():
261-
types = list(map(extend_named_type, type_.types))
262-
263-
# If there are any extensions to the union, apply those here.
264-
for extension in extension_nodes:
265-
for named_type in extension.types:
266-
# Note: While this could make early assertions to get the correctly
267-
# typed values, that would throw immediately while type system
268-
# validation with `validate_schema()` will produce more actionable
269-
# results.
270-
types.append(build_type(named_type))
271-
272-
return types
273-
274-
kwargs["types"] = types_thunk
275-
kwargs["extension_ast_nodes"] += tuple(extension_nodes)
276-
return GraphQLUnionType(**kwargs)
277-
278-
def extend_args(old_args: GraphQLArgumentMap) -> GraphQLArgumentMap:
279-
args = {}
280-
for arg_name, arg in old_args.items():
281-
arg_kwargs = arg.to_kwargs()
282-
arg_kwargs["type_"] = extend_type(arg_kwargs["type_"])
283-
args[arg_name] = GraphQLArgument(**arg_kwargs)
284-
return args
285-
286-
def extend_field_map(
287-
type_: Union[GraphQLObjectType, GraphQLInterfaceType]
288-
) -> GraphQLFieldMap:
289-
fields = {}
290-
for field_name, field in type_.fields.items():
291-
field_kwargs = field.to_kwargs()
292-
field_kwargs["type_"] = extend_type(field_kwargs["type_"])
293-
field_kwargs["args"] = extend_args(field_kwargs["args"])
294-
fields[field_name] = GraphQLField(**field_kwargs)
295-
296-
# If there are any extensions to the fields, apply those here.
297-
build_field = ast_builder.build_field
298-
for extension in type_extensions_map.get(type_.name, ()):
299-
for field in extension.fields:
300-
fields[field.name.value] = build_field(field)
301-
302-
return fields
289+
extensions = type_extensions_map.get(kwargs["name"], [])
290+
type_nodes = chain.from_iterable(node.types or [] for node in extensions)
291+
292+
return GraphQLUnionType(
293+
**{
294+
**kwargs,
295+
"types": lambda: [
296+
extend_named_type(member_type) for member_type in kwargs["types"]
297+
]
298+
# Note: While this could make early assertions to get the correctly
299+
# typed values, that would throw immediately while type system
300+
# validation with validate_schema will produce more actionable results.
301+
+ [build_type(node) for node in type_nodes],
302+
"extension_ast_nodes": kwargs["extension_ast_nodes"]
303+
+ tuple(extensions),
304+
}
305+
)
306+
307+
def extend_field(field: GraphQLField) -> GraphQLField:
308+
return GraphQLField( # type: ignore
309+
**{
310+
**field.to_kwargs(),
311+
"type_": extend_type(field.type),
312+
"args": {name: extend_arg(arg) for name, arg in field.args.items()},
313+
}
314+
)
315+
316+
def extend_arg(arg: GraphQLArgument) -> GraphQLArgument:
317+
return GraphQLArgument( # type: ignore
318+
**{**arg.to_kwargs(), "type_": extend_type(arg.type)}
319+
)
303320

304321
# noinspection PyTypeChecker,PyUnresolvedReferences
305322
def extend_type(type_def: GraphQLType) -> GraphQLType:

0 commit comments

Comments
 (0)