Skip to content

Commit 8b820c3

Browse files
committed
validate_schema: use ast_node from fields/args instead of type's subnodes
Replicates graphql/graphql-js@27465b2
1 parent d0d0322 commit 8b820c3

File tree

1 file changed

+24
-104
lines changed

1 file changed

+24
-104
lines changed

graphql/type/validate.py

Lines changed: 24 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,7 @@
1515

1616
from ..error import GraphQLError
1717
from ..pyutils import inspect
18-
from ..language import (
19-
FieldDefinitionNode,
20-
InputValueDefinitionNode,
21-
NamedTypeNode,
22-
Node,
23-
OperationType,
24-
OperationTypeDefinitionNode,
25-
TypeNode,
26-
)
18+
from ..language import NamedTypeNode, Node, OperationType, OperationTypeDefinitionNode
2719
from .definition import (
2820
GraphQLEnumType,
2921
GraphQLInputField,
@@ -44,7 +36,7 @@
4436
)
4537
from ..utilities.assert_valid_name import is_valid_name_error
4638
from ..utilities.type_comparators import is_equal_type, is_type_sub_type_of
47-
from .directives import GraphQLDirective, is_directive
39+
from .directives import is_directive, GraphQLDirective
4840
from .introspection import is_introspection_type
4941
from .schema import GraphQLSchema, assert_schema
5042

@@ -170,7 +162,12 @@ def validate_directives(self):
170162
self.report_error(
171163
f"Argument @{directive.name}({arg_name}:)"
172164
" can only be defined once.",
173-
get_all_directive_arg_nodes(directive, arg_name),
165+
directive.ast_node
166+
and [
167+
arg.ast_node
168+
for name, arg in directive.args.items()
169+
if name == arg_name
170+
],
174171
)
175172
continue
176173
arg_names.add(arg_name)
@@ -180,7 +177,7 @@ def validate_directives(self):
180177
self.report_error(
181178
f"The type of @{directive.name}({arg_name}:)"
182179
f" must be Input Type but got: {inspect(arg.type)}.",
183-
get_directive_arg_type_node(directive, arg_name),
180+
arg.ast_node,
184181
)
185182

186183
def validate_name(self, node: Any, name: str = None):
@@ -260,7 +257,7 @@ def validate_fields(self, type_: Union[GraphQLObjectType, GraphQLInterfaceType])
260257
self.report_error(
261258
f"The type of {type_.name}.{field_name}"
262259
" must be Output Type but got: {inspect(field.type)}.",
263-
get_field_type_node(type_, field_name),
260+
field.ast_node and field.ast_node.type,
264261
)
265262

266263
# Ensure the arguments are valid.
@@ -275,7 +272,11 @@ def validate_fields(self, type_: Union[GraphQLObjectType, GraphQLInterfaceType])
275272
"Field argument"
276273
f" {type_.name}.{field_name}({arg_name}:)"
277274
" can only be defined once.",
278-
get_all_field_arg_nodes(type_, field_name, arg_name),
275+
[
276+
arg.ast_node
277+
for name, arg in field.args.items()
278+
if name == arg_name
279+
],
279280
)
280281
break
281282
arg_names.add(arg_name)
@@ -286,7 +287,7 @@ def validate_fields(self, type_: Union[GraphQLObjectType, GraphQLInterfaceType])
286287
"Field argument"
287288
f" {type_.name}.{field_name}({arg_name}:)"
288289
f" must be Input Type but got: {inspect(arg.type)}.",
289-
get_field_arg_type_node(type_, field_name, arg_name),
290+
arg.ast_node and arg.ast_node.type,
290291
)
291292

292293
def validate_object_interfaces(self, obj: GraphQLObjectType):
@@ -296,7 +297,7 @@ def validate_object_interfaces(self, obj: GraphQLObjectType):
296297
self.report_error(
297298
f"Type {obj.name} must only implement Interface"
298299
f" types, it cannot implement {inspect(iface)}.",
299-
get_implements_interface_node(obj, iface),
300+
get_all_implements_interface_nodes(obj, iface),
300301
)
301302
continue
302303
if iface.name in implemented_type_names:
@@ -322,8 +323,7 @@ def validate_object_implements_interface(
322323
self.report_error(
323324
f"Interface field {iface.name}.{field_name}"
324325
f" expected but {obj.name} does not provide it.",
325-
[get_field_node(iface, field_name)]
326-
+ cast(List[Optional[FieldDefinitionNode]], get_all_nodes(obj)),
326+
[iface_field.ast_node, *get_all_nodes(obj)],
327327
)
328328
continue
329329

@@ -336,8 +336,8 @@ def validate_object_implements_interface(
336336
f" but {obj.name}.{field_name}"
337337
f" is type {obj_field.type}.",
338338
[
339-
get_field_type_node(iface, field_name),
340-
get_field_type_node(obj, field_name),
339+
iface_field.ast_node and iface_field.ast_node.type,
340+
obj_field.ast_node and obj_field.ast_node.type,
341341
],
342342
)
343343

@@ -352,10 +352,7 @@ def validate_object_implements_interface(
352352
f" {iface.name}.{field_name}({arg_name}:)"
353353
f" expected but {obj.name}.{field_name}"
354354
" does not provide it.",
355-
[
356-
get_field_arg_node(iface, field_name, arg_name),
357-
get_field_node(obj, field_name),
358-
],
355+
[iface_arg.ast_node, obj_field.ast_node],
359356
)
360357
continue
361358

@@ -369,8 +366,8 @@ def validate_object_implements_interface(
369366
f" but {obj.name}.{field_name}({arg_name}:)"
370367
f" is type {obj_arg.type}.",
371368
[
372-
get_field_arg_type_node(iface, field_name, arg_name),
373-
get_field_arg_type_node(obj, field_name, arg_name),
369+
iface_arg.ast_node and iface_arg.ast_node.type,
370+
obj_arg.ast_node and obj_arg.ast_node.type,
374371
],
375372
)
376373

@@ -382,10 +379,7 @@ def validate_object_implements_interface(
382379
f"Object field {obj.name}.{field_name} includes"
383380
f" required argument {arg_name} that is missing from"
384381
f" the Interface field {iface.name}.{field_name}.",
385-
[
386-
get_field_arg_node(obj, field_name, arg_name),
387-
get_field_node(iface, field_name),
388-
],
382+
[obj_arg.ast_node, iface_field.ast_node],
389383
)
390384

391385
def validate_union_members(self, union: GraphQLUnionType):
@@ -548,13 +542,6 @@ def get_all_sub_nodes(
548542
return result
549543

550544

551-
def get_implements_interface_node(
552-
type_: GraphQLObjectType, iface: GraphQLInterfaceType
553-
) -> Optional[NamedTypeNode]:
554-
nodes = get_all_implements_interface_nodes(type_, iface)
555-
return nodes[0] if nodes else None
556-
557-
558545
def get_all_implements_interface_nodes(
559546
type_: GraphQLObjectType, iface: GraphQLInterfaceType
560547
) -> List[NamedTypeNode]:
@@ -568,73 +555,6 @@ def get_all_implements_interface_nodes(
568555
]
569556

570557

571-
def get_field_node(
572-
type_: Union[GraphQLObjectType, GraphQLInterfaceType], field_name: str
573-
) -> Optional[FieldDefinitionNode]:
574-
all_field_nodes = filter(
575-
lambda field_node: field_node.name.value == field_name,
576-
cast(List[FieldDefinitionNode], get_all_sub_nodes(type_, attrgetter("fields"))),
577-
)
578-
return next(all_field_nodes, None)
579-
580-
581-
def get_field_type_node(
582-
type_: Union[GraphQLObjectType, GraphQLInterfaceType], field_name: str
583-
) -> Optional[TypeNode]:
584-
field_node = get_field_node(type_, field_name)
585-
return field_node.type if field_node else None
586-
587-
588-
def get_field_arg_node(
589-
type_: Union[GraphQLObjectType, GraphQLInterfaceType],
590-
field_name: str,
591-
arg_name: str,
592-
) -> Optional[InputValueDefinitionNode]:
593-
nodes = get_all_field_arg_nodes(type_, field_name, arg_name)
594-
return nodes[0] if nodes else None
595-
596-
597-
def get_all_field_arg_nodes(
598-
type_: Union[GraphQLObjectType, GraphQLInterfaceType],
599-
field_name: str,
600-
arg_name: str,
601-
) -> List[InputValueDefinitionNode]:
602-
arg_nodes = []
603-
field_node = get_field_node(type_, field_name)
604-
if field_node and field_node.arguments:
605-
for node in field_node.arguments:
606-
if node.name.value == arg_name:
607-
arg_nodes.append(node)
608-
return arg_nodes
609-
610-
611-
def get_field_arg_type_node(
612-
type_: Union[GraphQLObjectType, GraphQLInterfaceType],
613-
field_name: str,
614-
arg_name: str,
615-
) -> Optional[TypeNode]:
616-
field_arg_node = get_field_arg_node(type_, field_name, arg_name)
617-
return field_arg_node.type if field_arg_node else None
618-
619-
620-
def get_all_directive_arg_nodes(
621-
directive: GraphQLDirective, arg_name: str
622-
) -> List[InputValueDefinitionNode]:
623-
arg_nodes = cast(
624-
List[InputValueDefinitionNode],
625-
get_all_sub_nodes(directive, attrgetter("arguments")),
626-
)
627-
return [arg_node for arg_node in arg_nodes if arg_node.name.value == arg_name]
628-
629-
630-
def get_directive_arg_type_node(
631-
directive: GraphQLDirective, arg_name: str
632-
) -> Optional[TypeNode]:
633-
arg_nodes = get_all_directive_arg_nodes(directive, arg_name)
634-
arg_node = arg_nodes[0] if arg_nodes else None
635-
return arg_node.type if arg_node else None
636-
637-
638558
def get_union_member_type_nodes(
639559
union: GraphQLUnionType, type_name: str
640560
) -> Optional[List[NamedTypeNode]]:

0 commit comments

Comments
 (0)