Skip to content

Commit 4d3b782

Browse files
committed
Add isRequiredArgument and isRequiredInputField predicates
Replicates graphql/graphql-js@36dc149
1 parent 07c3a08 commit 4d3b782

File tree

5 files changed

+23
-12
lines changed

5 files changed

+23
-12
lines changed

graphql/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@
9999
is_wrapping_type,
100100
is_nullable_type,
101101
is_named_type,
102+
is_required_argument,
103+
is_required_input_field,
102104
is_specified_scalar_type,
103105
is_introspection_type,
104106
is_specified_directive,
@@ -375,6 +377,7 @@
375377
'is_list_type', 'is_non_null_type', 'is_input_type', 'is_output_type',
376378
'is_leaf_type', 'is_composite_type', 'is_abstract_type',
377379
'is_wrapping_type', 'is_nullable_type', 'is_named_type',
380+
'is_required_argument', 'is_required_input_field',
378381
'is_specified_scalar_type', 'is_introspection_type',
379382
'is_specified_directive',
380383
'assert_type', 'assert_scalar_type', 'assert_object_type',

graphql/type/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
is_non_null_type, is_input_type, is_output_type, is_leaf_type,
1818
is_composite_type, is_abstract_type, is_wrapping_type,
1919
is_nullable_type, is_named_type,
20+
is_required_argument, is_required_input_field,
2021
# Assertions
2122
assert_type, assert_scalar_type, assert_object_type,
2223
assert_interface_type, assert_union_type, assert_enum_type,
@@ -82,6 +83,7 @@
8283
'is_non_null_type', 'is_input_type', 'is_output_type', 'is_leaf_type',
8384
'is_composite_type', 'is_abstract_type', 'is_wrapping_type',
8485
'is_nullable_type', 'is_named_type',
86+
'is_required_argument', 'is_required_input_field',
8587
'assert_type', 'assert_scalar_type', 'assert_object_type',
8688
'assert_interface_type', 'assert_union_type', 'assert_enum_type',
8789
'assert_input_object_type', 'assert_list_type', 'assert_non_null_type',

graphql/type/definition.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
'is_non_null_type', 'is_input_type', 'is_output_type', 'is_leaf_type',
2828
'is_composite_type', 'is_abstract_type', 'is_wrapping_type',
2929
'is_nullable_type', 'is_named_type',
30+
'is_required_argument', 'is_required_input_field',
3031
'assert_type', 'assert_scalar_type', 'assert_object_type',
3132
'assert_interface_type', 'assert_union_type', 'assert_enum_type',
3233
'assert_input_object_type', 'assert_list_type', 'assert_non_null_type',
@@ -415,6 +416,10 @@ def __eq__(self, other):
415416
self.description == other.description))
416417

417418

419+
def is_required_argument(arg: GraphQLArgument) -> bool:
420+
return is_non_null_type(arg.type) and arg.default_value is INVALID
421+
422+
418423
T = TypeVar('T')
419424
Thunk = Union[Callable[[], T], T]
420425

@@ -981,6 +986,10 @@ def __eq__(self, other):
981986
self.description == other.description))
982987

983988

989+
def is_required_input_field(field: GraphQLInputField) -> bool:
990+
return is_non_null_type(field.type) and field.default_value is INVALID
991+
992+
984993
# Wrapper types
985994

986995
class GraphQLList(Generic[GT], GraphQLWrappingType[GT]):

graphql/validation/rules/provided_required_arguments.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from ...error import GraphQLError, INVALID
1+
from ...error import GraphQLError
22
from ...language import DirectiveNode, FieldNode
3-
from ...type import is_non_null_type
3+
from ...type import is_required_argument
44
from . import ValidationRule
55

66
__all__ = [
@@ -37,8 +37,7 @@ def leave_field(self, node: FieldNode, *_args):
3737
arg_node_map = {arg.name.value: arg for arg in arg_nodes}
3838
for arg_name, arg_def in field_def.args.items():
3939
arg_node = arg_node_map.get(arg_name)
40-
if not arg_node and is_non_null_type(
41-
arg_def.type) and arg_def.default_value is INVALID:
40+
if not arg_node and is_required_argument(arg_def):
4241
self.report_error(GraphQLError(missing_field_arg_message(
4342
node.name.value, arg_name, str(arg_def.type)), [node]))
4443

@@ -52,7 +51,6 @@ def leave_directive(self, node: DirectiveNode, *_args):
5251
arg_node_map = {arg.name.value: arg for arg in arg_nodes}
5352
for arg_name, arg_def in directive_def.args.items():
5453
arg_node = arg_node_map.get(arg_name)
55-
if not arg_node and is_non_null_type(
56-
arg_def.type) and arg_def.default_value is INVALID:
54+
if not arg_node and is_required_argument(arg_def):
5755
self.report_error(GraphQLError(missing_directive_arg_message(
5856
node.name.value, arg_name, str(arg_def.type)), [node]))

graphql/validation/rules/values_of_correct_type.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Optional, cast
22

3-
from ...error import GraphQLError, INVALID
3+
from ...error import GraphQLError
44
from ...language import (
55
BooleanValueNode, EnumValueNode, FloatValueNode, IntValueNode,
66
NullValueNode, ListValueNode, ObjectFieldNode, ObjectValueNode,
@@ -9,7 +9,7 @@
99
from ...type import (
1010
GraphQLEnumType, GraphQLScalarType, GraphQLType,
1111
get_named_type, get_nullable_type, is_enum_type, is_input_object_type,
12-
is_list_type, is_non_null_type, is_scalar_type)
12+
is_list_type, is_non_null_type, is_required_input_field, is_scalar_type)
1313
from . import ValidationRule
1414

1515
__all__ = [
@@ -65,12 +65,11 @@ def enter_object_value(self, node: ObjectValueNode, *_args):
6565
input_fields = type_.fields
6666
field_node_map = {field.name.value: field for field in node.fields}
6767
for field_name, field_def in input_fields.items():
68-
field_type = field_def.type
6968
field_node = field_node_map.get(field_name)
70-
if not field_node and is_non_null_type(
71-
field_type) and field_def.default_value is INVALID:
69+
if not field_node and is_required_input_field(field_def):
70+
field_type = field_def.type
7271
self.report_error(GraphQLError(required_field_message(
73-
type_.name, field_name, field_type), node))
72+
type_.name, field_name, str(field_type)), node))
7473

7574
def enter_object_field(self, node: ObjectFieldNode, *_args):
7675
parent_type = get_named_type(self.context.get_parent_input_type())

0 commit comments

Comments
 (0)