Skip to content

Commit 6687b16

Browse files
committed
Simplify get_operation_type_node function
Replicates graphql/graphql-js@02d59dc
1 parent 9228af4 commit 6687b16

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

src/graphql/type/validate.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,25 +115,26 @@ def validate_root_types(self) -> None:
115115
elif not is_object_type(query_type):
116116
self.report_error(
117117
f"Query root type must be Object type, it cannot be {query_type}.",
118-
get_operation_type_node(schema, query_type, OperationType.QUERY),
118+
get_operation_type_node(schema, OperationType.QUERY)
119+
or query_type.ast_node,
119120
)
120121

121122
mutation_type = schema.mutation_type
122123
if mutation_type and not is_object_type(mutation_type):
123124
self.report_error(
124125
"Mutation root type must be Object type if provided,"
125126
f" it cannot be {mutation_type}.",
126-
get_operation_type_node(schema, mutation_type, OperationType.MUTATION),
127+
get_operation_type_node(schema, OperationType.MUTATION)
128+
or mutation_type.ast_node,
127129
)
128130

129131
subscription_type = schema.subscription_type
130132
if subscription_type and not is_object_type(subscription_type):
131133
self.report_error(
132134
"Subscription root type must be Object type if provided,"
133135
f" it cannot be {subscription_type}.",
134-
get_operation_type_node(
135-
schema, subscription_type, OperationType.SUBSCRIPTION
136-
),
136+
get_operation_type_node(schema, OperationType.SUBSCRIPTION)
137+
or subscription_type.ast_node,
137138
)
138139

139140
def validate_directives(self) -> None:
@@ -458,7 +459,7 @@ def validate_input_fields(self, input_obj: GraphQLInputObjectType) -> None:
458459

459460

460461
def get_operation_type_node(
461-
schema: GraphQLSchema, type_: GraphQLObjectType, operation: OperationType
462+
schema: GraphQLSchema, operation: OperationType
462463
) -> Optional[Node]:
463464
operation_nodes = cast(
464465
List[OperationTypeDefinitionNode],
@@ -467,7 +468,7 @@ def get_operation_type_node(
467468
for node in operation_nodes:
468469
if node.operation == operation:
469470
return node.type
470-
return type_.ast_node
471+
return None
471472

472473

473474
class InputObjectCircularRefsValidator:

0 commit comments

Comments
 (0)