Skip to content

Commit 8de6087

Browse files
committed
Implement extendSchema. Fixed #40.
Related GraphQL-js implementation: graphql/graphql-js@9ea8196
1 parent 05152a1 commit 8de6087

File tree

2 files changed

+946
-0
lines changed

2 files changed

+946
-0
lines changed

graphql/core/utils/extend_schema.py

Lines changed: 345 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,345 @@
1+
from collections import OrderedDict, defaultdict
2+
3+
from graphql.core.language import ast
4+
5+
from ..error import GraphQLError
6+
from ..type.definition import (GraphQLArgument, GraphQLEnumType,
7+
GraphQLEnumValue, GraphQLField,
8+
GraphQLInputObjectField, GraphQLInputObjectType,
9+
GraphQLInterfaceType, GraphQLList,
10+
GraphQLNonNull, GraphQLObjectType,
11+
GraphQLScalarType, GraphQLUnionType)
12+
from ..type.scalars import (GraphQLBoolean, GraphQLFloat, GraphQLID,
13+
GraphQLInt, GraphQLString)
14+
from ..type.schema import GraphQLSchema
15+
from .value_from_ast import value_from_ast
16+
17+
18+
def extend_schema(schema, documentAST=None):
19+
"""Produces a new schema given an existing schema and a document which may
20+
contain GraphQL type extensions and definitions. The original schema will
21+
remain unaltered.
22+
23+
Because a schema represents a graph of references, a schema cannot be
24+
extended without effectively making an entire copy. We do not know until it's
25+
too late if subgraphs remain unchanged.
26+
27+
This algorithm copies the provided schema, applying extensions while
28+
producing the copy. The original schema remains unaltered."""
29+
30+
assert isinstance(
31+
schema, GraphQLSchema), 'Must provide valid GraphQLSchema'
32+
assert documentAST and isinstance(
33+
documentAST, ast.Document), 'Must provide valid Document AST'
34+
35+
# Collect the type definitions and extensions found in the document.
36+
type_definition_map = {}
37+
type_extensions_map = defaultdict(list)
38+
39+
for _def in documentAST.definitions:
40+
if isinstance(_def, (
41+
ast.ObjectTypeDefinition,
42+
ast.InterfaceTypeDefinition,
43+
ast.EnumTypeDefinition,
44+
ast.UnionTypeDefinition,
45+
ast.ScalarTypeDefinition,
46+
ast.InputObjectTypeDefinition,
47+
)):
48+
# Sanity check that none of the defined types conflict with the
49+
# schema's existing types.
50+
type_name = _def.name.value
51+
if schema.get_type(type_name):
52+
raise GraphQLError(
53+
('Type "{}" already exists in the schema. It cannot also ' +
54+
'be defined in this type definition.').format(type_name),
55+
[_def]
56+
)
57+
58+
type_definition_map[type_name] = _def
59+
elif isinstance(_def, ast.TypeExtensionDefinition):
60+
# Sanity check that this type extension exists within the
61+
# schema's existing types.
62+
extended_type_name = _def.definition.name.value
63+
existing_type = schema.get_type(extended_type_name)
64+
if not existing_type:
65+
raise GraphQLError(
66+
('Cannot extend type "{}" because it does not ' +
67+
'exist in the existing schema.').format(extended_type_name),
68+
[_def.definition]
69+
)
70+
if not isinstance(existing_type, GraphQLObjectType):
71+
raise GraphQLError(
72+
'Cannot extend non-object type "{}".'.format(
73+
extended_type_name),
74+
[_def.definition]
75+
)
76+
77+
type_extensions_map[extended_type_name].append(_def)
78+
79+
# Below are functions used for producing this schema that have closed over
80+
# this scope and have access to the schema, cache, and newly defined types.
81+
82+
def get_type_from_def(type_def):
83+
type = _get_named_type(type_def.name)
84+
assert type, 'Invalid schema'
85+
return type
86+
87+
def get_type_from_AST(astNode):
88+
type = _get_named_type(astNode.name.value)
89+
if not type:
90+
raise GraphQLError(
91+
('Unknown type: "{}". Ensure that this type exists ' +
92+
'either in the original schema, or is added in a type definition.').format(
93+
astNode.name.value),
94+
[astNode]
95+
)
96+
return type
97+
98+
# Given a name, returns a type from either the existing schema or an
99+
# added type.
100+
def _get_named_type(typeName):
101+
cached_type_def = type_def_cache.get(typeName)
102+
if cached_type_def:
103+
return cached_type_def
104+
105+
existing_type = schema.get_type(typeName)
106+
if existing_type:
107+
type_def = extend_type(existing_type)
108+
type_def_cache[typeName] = type_def
109+
return type_def
110+
111+
type_ast = type_definition_map.get(typeName)
112+
if type_ast:
113+
type_def = build_type(type_ast)
114+
type_def_cache[typeName] = type_def
115+
return type_def
116+
117+
# Given a type's introspection result, construct the correct
118+
# GraphQLType instance.
119+
def extend_type(type):
120+
if isinstance(type, GraphQLObjectType):
121+
return extend_object_type(type)
122+
if isinstance(type, GraphQLInterfaceType):
123+
return extend_interface_type(type)
124+
if isinstance(type, GraphQLUnionType):
125+
return extend_union_type(type)
126+
return type
127+
128+
def extend_object_type(type):
129+
return GraphQLObjectType(
130+
name=type.name,
131+
description=type.description,
132+
interfaces=lambda: extend_implemented_interfaces(type),
133+
fields=lambda: extend_field_map(type),
134+
)
135+
136+
def extend_interface_type(type):
137+
return GraphQLInterfaceType(
138+
name=type.name,
139+
description=type.description,
140+
fields=lambda: extend_field_map(type),
141+
resolve_type=raise_client_schema_execution_error,
142+
)
143+
144+
def extend_union_type(type):
145+
return GraphQLUnionType(
146+
name=type.name,
147+
description=type.description,
148+
types=map(get_type_from_def, type.get_possible_types()),
149+
resolve_type=raise_client_schema_execution_error,
150+
)
151+
152+
def extend_implemented_interfaces(type):
153+
interfaces = map(get_type_from_def, type.get_interfaces())
154+
155+
# If there are any extensions to the interfaces, apply those here.
156+
extensions = type_extensions_map[type.name]
157+
for extension in extensions:
158+
for namedType in extension.definition.interfaces:
159+
interface_name = namedType.name.value
160+
if any([_def.name == interface_name for _def in interfaces]):
161+
raise GraphQLError(
162+
('Type "{}" already implements "{}". ' +
163+
'It cannot also be implemented in this type extension.').format(
164+
type.name, interface_name),
165+
[namedType]
166+
)
167+
interfaces.append(get_type_from_AST(namedType))
168+
169+
return interfaces
170+
171+
def extend_field_map(type):
172+
new_field_map = OrderedDict()
173+
old_field_map = type.get_fields()
174+
for field_name, field in old_field_map.iteritems():
175+
new_field_map[field_name] = GraphQLField(
176+
extend_field_type(field.type),
177+
description=field.description,
178+
deprecation_reason=field.deprecation_reason,
179+
args={arg.name: arg for arg in field.args},
180+
resolver=raise_client_schema_execution_error,
181+
)
182+
183+
# If there are any extensions to the fields, apply those here.
184+
extensions = type_extensions_map[type.name]
185+
for extension in extensions:
186+
for field in extension.definition.fields:
187+
field_name = field.name.value
188+
if field_name in old_field_map:
189+
raise GraphQLError(
190+
('Field "{}.{}" already exists in the ' +
191+
'schema. It cannot also be defined in this type extension.').format(
192+
type.name, field_name),
193+
[field]
194+
)
195+
new_field_map[field_name] = GraphQLField(
196+
build_field_type(field.type),
197+
args=build_input_values(field.arguments),
198+
resolver=raise_client_schema_execution_error,
199+
)
200+
201+
return new_field_map
202+
203+
def extend_field_type(type):
204+
if isinstance(type, GraphQLList):
205+
return GraphQLList(extend_field_type(type.of_type))
206+
if isinstance(type, GraphQLNonNull):
207+
return GraphQLNonNull(extend_field_type(type.of_type))
208+
return get_type_from_def(type)
209+
210+
def build_type(type_ast):
211+
_type_build = {
212+
ast.ObjectTypeDefinition: build_object_type,
213+
ast.InterfaceTypeDefinition: build_interface_type,
214+
ast.UnionTypeDefinition: build_union_type,
215+
ast.ScalarTypeDefinition: build_scalar_type,
216+
ast.EnumTypeDefinition: build_enum_type,
217+
ast.InputObjectTypeDefinition: build_input_object_type
218+
}
219+
func = _type_build.get(type(type_ast))
220+
if func:
221+
return func(type_ast)
222+
223+
def build_object_type(type_ast):
224+
return GraphQLObjectType(
225+
type_ast.name.value,
226+
interfaces=lambda: build_implemented_interfaces(type_ast),
227+
fields=lambda: build_field_map(type_ast),
228+
)
229+
230+
def build_interface_type(type_ast):
231+
return GraphQLInterfaceType(
232+
type_ast.name.value,
233+
fields=lambda: build_field_map(type_ast),
234+
resolve_type=raise_client_schema_execution_error,
235+
)
236+
237+
def build_union_type(type_ast):
238+
return GraphQLUnionType(
239+
type_ast.name.value,
240+
types=map(get_type_from_AST, type_ast.types),
241+
resolve_type=raise_client_schema_execution_error,
242+
)
243+
244+
def build_scalar_type(type_ast):
245+
return GraphQLScalarType(
246+
type_ast.name.value,
247+
serialize=lambda *args, **kwargs: None,
248+
# Note: validation calls the parse functions to determine if a
249+
# literal value is correct. Returning null would cause use of custom
250+
# scalars to always fail validation. Returning false causes them to
251+
# always pass validation.
252+
parse_value=lambda *args, **kwargs: False,
253+
parse_literal=lambda *args, **kwargs: False,
254+
)
255+
256+
def build_enum_type(type_ast):
257+
return GraphQLEnumType(
258+
type_ast.name.value,
259+
values={v.name.value: GraphQLEnumValue() for v in type_ast.values},
260+
)
261+
262+
def build_input_object_type(type_ast):
263+
return GraphQLInputObjectType(
264+
type_ast.name.value,
265+
fields=lambda: build_input_values(
266+
type_ast.fields, GraphQLInputObjectField),
267+
)
268+
269+
def build_implemented_interfaces(type_ast):
270+
return map(get_type_from_AST, type_ast.interfaces)
271+
272+
def build_field_map(type_ast):
273+
return {
274+
field.name.value: GraphQLField(
275+
build_field_type(field.type),
276+
args=build_input_values(field.arguments),
277+
resolver=raise_client_schema_execution_error,
278+
) for field in type_ast.fields
279+
}
280+
281+
def build_input_values(values, input_type=GraphQLArgument):
282+
input_values = OrderedDict()
283+
for value in values:
284+
type = build_field_type(value.type)
285+
input_values[value.name.value] = input_type(
286+
type,
287+
default_value=value_from_ast(value.default_value, type)
288+
)
289+
return input_values
290+
291+
def build_field_type(type_ast):
292+
if isinstance(type_ast, ast.ListType):
293+
return GraphQLList(build_field_type(type_ast.type))
294+
if isinstance(type_ast, ast.NonNullType):
295+
return GraphQLNonNull(build_field_type(type_ast.type))
296+
return get_type_from_AST(type_ast)
297+
298+
# If this document contains no new types, then return the same unmodified
299+
# GraphQLSchema instance.
300+
if not type_extensions_map and not type_definition_map:
301+
return schema
302+
303+
# A cache to use to store the actual GraphQLType definition objects by name.
304+
# Initialize to the GraphQL built in scalars. All functions below are inline
305+
# so that this type def cache is within the scope of the closure.
306+
type_def_cache = {
307+
'String': GraphQLString,
308+
'Int': GraphQLInt,
309+
'Float': GraphQLFloat,
310+
'Boolean': GraphQLBoolean,
311+
'ID': GraphQLID,
312+
}
313+
314+
# Get the root Query, Mutation, and Subscription types.
315+
query_type = get_type_from_def(schema.get_query_type())
316+
317+
existing_mutation_type = schema.get_mutation_type()
318+
mutationType = existing_mutation_type and get_type_from_def(
319+
existing_mutation_type) or None
320+
321+
existing_subscription_type = schema.get_subscription_type()
322+
subscription_type = existing_subscription_type and get_type_from_def(
323+
existing_subscription_type) or None
324+
325+
# Iterate through all types, getting the type definition for each, ensuring
326+
# that any type not directly referenced by a field will get created.
327+
for typeName, _def in schema.get_type_map().iteritems():
328+
get_type_from_def(_def)
329+
330+
# Do the same with new types.
331+
for typeName, _def in type_definition_map.iteritems():
332+
get_type_from_AST(_def)
333+
334+
# Then produce and return a Schema with these types.
335+
return GraphQLSchema(
336+
query=query_type,
337+
mutation=mutationType,
338+
subscription=subscription_type,
339+
# Copy directives.
340+
directives=schema.get_directives(),
341+
)
342+
343+
344+
def raise_client_schema_execution_error(*args, **kwargs):
345+
raise Exception('Client Schema cannot be used for execution.')

0 commit comments

Comments
 (0)