Skip to content

Commit b7a18ed

Browse files
committed
Implement OneOf Input Objects via @OneOf directive
Replicates graphql/graphql-js@8cfa3de
1 parent 6e6d5be commit b7a18ed

23 files changed

+720
-13
lines changed

src/graphql/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@
259259
GraphQLStreamDirective,
260260
GraphQLDeprecatedDirective,
261261
GraphQLSpecifiedByDirective,
262+
GraphQLOneOfDirective,
262263
# "Enum" of Type Kinds
263264
TypeKind,
264265
# Constant Deprecation Reason
@@ -504,6 +505,7 @@
504505
"GraphQLStreamDirective",
505506
"GraphQLDeprecatedDirective",
506507
"GraphQLSpecifiedByDirective",
508+
"GraphQLOneOfDirective",
507509
"TypeKind",
508510
"DEFAULT_DEPRECATION_REASON",
509511
"introspection_types",

src/graphql/execution/values.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,16 +128,20 @@ def coerce_variable_values(
128128
continue
129129

130130
def on_input_value_error(
131-
path: list[str | int], invalid_value: Any, error: GraphQLError
131+
path: list[str | int],
132+
invalid_value: Any,
133+
error: GraphQLError,
134+
var_name: str = var_name,
135+
var_def_node: VariableDefinitionNode = var_def_node,
132136
) -> None:
133137
invalid_str = inspect(invalid_value)
134-
prefix = f"Variable '${var_name}' got invalid value {invalid_str}" # noqa: B023
138+
prefix = f"Variable '${var_name}' got invalid value {invalid_str}"
135139
if path:
136-
prefix += f" at '{var_name}{print_path_list(path)}'" # noqa: B023
140+
prefix += f" at '{var_name}{print_path_list(path)}'"
137141
on_error(
138142
GraphQLError(
139143
prefix + "; " + error.message,
140-
var_def_node, # noqa: B023
144+
var_def_node,
141145
original_error=error,
142146
)
143147
)
@@ -193,7 +197,8 @@ def get_argument_values(
193197
)
194198
raise GraphQLError(msg, value_node)
195199
continue # pragma: no cover
196-
is_null = variable_values[variable_name] is None
200+
variable_value = variable_values[variable_name]
201+
is_null = variable_value is None or variable_value is Undefined
197202

198203
if is_null and is_non_null_type(arg_type):
199204
msg = f"Argument '{name}' of non-null type '{arg_type}' must not be null."

src/graphql/type/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137
GraphQLStreamDirective,
138138
GraphQLDeprecatedDirective,
139139
GraphQLSpecifiedByDirective,
140+
GraphQLOneOfDirective,
140141
# Keyword Args
141142
GraphQLDirectiveKwargs,
142143
# Constant Deprecation Reason
@@ -286,6 +287,7 @@
286287
"GraphQLStreamDirective",
287288
"GraphQLDeprecatedDirective",
288289
"GraphQLSpecifiedByDirective",
290+
"GraphQLOneOfDirective",
289291
"GraphQLDirectiveKwargs",
290292
"DEFAULT_DEPRECATION_REASON",
291293
"is_specified_scalar_type",

src/graphql/type/definition.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,6 +1272,7 @@ class GraphQLInputObjectTypeKwargs(GraphQLNamedTypeKwargs, total=False):
12721272

12731273
fields: GraphQLInputFieldMap
12741274
out_type: GraphQLInputFieldOutType | None
1275+
is_one_of: bool
12751276

12761277

12771278
class GraphQLInputObjectType(GraphQLNamedType):
@@ -1301,6 +1302,7 @@ class GeoPoint(GraphQLInputObjectType):
13011302

13021303
ast_node: InputObjectTypeDefinitionNode | None
13031304
extension_ast_nodes: tuple[InputObjectTypeExtensionNode, ...]
1305+
is_one_of: bool
13041306

13051307
def __init__(
13061308
self,
@@ -1311,6 +1313,7 @@ def __init__(
13111313
extensions: dict[str, Any] | None = None,
13121314
ast_node: InputObjectTypeDefinitionNode | None = None,
13131315
extension_ast_nodes: Collection[InputObjectTypeExtensionNode] | None = None,
1316+
is_one_of: bool = False,
13141317
) -> None:
13151318
super().__init__(
13161319
name=name,
@@ -1322,6 +1325,7 @@ def __init__(
13221325
self._fields = fields
13231326
if out_type is not None:
13241327
self.out_type = out_type # type: ignore
1328+
self.is_one_of = is_one_of
13251329

13261330
@staticmethod
13271331
def out_type(value: dict[str, Any]) -> Any:
@@ -1340,6 +1344,7 @@ def to_kwargs(self) -> GraphQLInputObjectTypeKwargs:
13401344
out_type=None
13411345
if self.out_type is GraphQLInputObjectType.out_type
13421346
else self.out_type,
1347+
is_one_of=self.is_one_of,
13431348
)
13441349

13451350
def __copy__(self) -> GraphQLInputObjectType: # pragma: no cover

src/graphql/type/directives.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,11 +261,20 @@ def assert_directive(directive: Any) -> GraphQLDirective:
261261
description="Exposes a URL that specifies the behaviour of this scalar.",
262262
)
263263

264+
# Used to declare an Input Object as a OneOf Input Objects.
265+
GraphQLOneOfDirective = GraphQLDirective(
266+
name="oneOf",
267+
locations=[DirectiveLocation.INPUT_OBJECT],
268+
args={},
269+
description="Indicates an Input Object is a OneOf Input Object.",
270+
)
271+
264272
specified_directives: tuple[GraphQLDirective, ...] = (
265273
GraphQLIncludeDirective,
266274
GraphQLSkipDirective,
267275
GraphQLDeprecatedDirective,
268276
GraphQLSpecifiedByDirective,
277+
GraphQLOneOfDirective,
269278
)
270279
"""A tuple with all directives from the GraphQL specification"""
271280

src/graphql/type/introspection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def __new__(cls):
305305
resolve=cls.input_fields,
306306
),
307307
"ofType": GraphQLField(_Type, resolve=cls.of_type),
308+
"isOneOf": GraphQLField(GraphQLBoolean, resolve=cls.is_one_of),
308309
}
309310

310311
@staticmethod
@@ -396,6 +397,10 @@ def input_fields(type_, _info, includeDeprecated=False):
396397
def of_type(type_, _info):
397398
return getattr(type_, "of_type", None)
398399

400+
@staticmethod
401+
def is_one_of(type_, _info):
402+
return type_.is_one_of if is_input_object_type(type_) else None
403+
399404

400405
_Type: GraphQLObjectType = GraphQLObjectType(
401406
name="__Type",

src/graphql/type/validate.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
SchemaDefinitionNode,
1717
SchemaExtensionNode,
1818
)
19-
from ..pyutils import and_list, inspect
19+
from ..pyutils import Undefined, and_list, inspect
2020
from ..utilities.type_comparators import is_equal_type, is_type_sub_type_of
2121
from .definition import (
2222
GraphQLEnumType,
@@ -482,6 +482,28 @@ def validate_input_fields(self, input_obj: GraphQLInputObjectType) -> None:
482482
],
483483
)
484484

485+
if input_obj.is_one_of:
486+
self.validate_one_of_input_object_field(input_obj, field_name, field)
487+
488+
def validate_one_of_input_object_field(
489+
self,
490+
type_: GraphQLInputObjectType,
491+
field_name: str,
492+
field: GraphQLInputField,
493+
) -> None:
494+
if is_non_null_type(field.type):
495+
self.report_error(
496+
f"OneOf input field {type_.name}.{field_name} must be nullable.",
497+
field.ast_node and field.ast_node.type,
498+
)
499+
500+
if field.default_value is not Undefined:
501+
self.report_error(
502+
f"OneOf input field {type_.name}.{field_name}"
503+
" cannot have a default value.",
504+
field.ast_node,
505+
)
506+
485507

486508
def get_operation_type_node(
487509
schema: GraphQLSchema, operation: OperationType

src/graphql/utilities/coerce_input_value.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,30 @@ def coerce_input_value(
130130
+ did_you_mean(suggestions)
131131
),
132132
)
133+
134+
if type_.is_one_of:
135+
keys = list(coerced_dict)
136+
if len(keys) != 1:
137+
on_error(
138+
path.as_list() if path else [],
139+
input_value,
140+
GraphQLError(
141+
"Exactly one key must be specified"
142+
f" for OneOf type '{type_.name}'.",
143+
),
144+
)
145+
else:
146+
key = keys[0]
147+
value = coerced_dict[key]
148+
if value is None:
149+
on_error(
150+
(path.as_list() if path else []) + [key],
151+
value,
152+
GraphQLError(
153+
f"Field '{key}' must be non-null.",
154+
),
155+
)
156+
133157
return type_.out_type(coerced_dict)
134158

135159
if is_leaf_type(type_):

src/graphql/utilities/extend_schema.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
GraphQLNullableType,
6666
GraphQLObjectType,
6767
GraphQLObjectTypeKwargs,
68+
GraphQLOneOfDirective,
6869
GraphQLOutputType,
6970
GraphQLScalarType,
7071
GraphQLSchema,
@@ -777,6 +778,7 @@ def build_input_object_type(
777778
fields=partial(self.build_input_field_map, all_nodes),
778779
ast_node=ast_node,
779780
extension_ast_nodes=extension_nodes,
781+
is_one_of=is_one_of(ast_node),
780782
)
781783

782784
def build_type(self, ast_node: TypeDefinitionNode) -> GraphQLNamedType:
@@ -822,3 +824,10 @@ def get_specified_by_url(
822824

823825
specified_by_url = get_directive_values(GraphQLSpecifiedByDirective, node)
824826
return specified_by_url["url"] if specified_by_url else None
827+
828+
829+
def is_one_of(node: InputObjectTypeDefinitionNode) -> bool:
830+
"""Given an input object node, returns if the node should be OneOf."""
831+
from ..execution import get_directive_values
832+
833+
return get_directive_values(GraphQLOneOfDirective, node) is not None

src/graphql/utilities/value_from_ast.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,14 @@ def value_from_ast(
118118
return Undefined
119119
coerced_obj[field.out_name or field_name] = field_value
120120

121+
if type_.is_one_of:
122+
keys = list(coerced_obj)
123+
if len(keys) != 1:
124+
return Undefined
125+
126+
if coerced_obj[keys[0]] is None:
127+
return Undefined
128+
121129
return type_.out_type(coerced_obj)
122130

123131
if is_leaf_type(type_):

0 commit comments

Comments
 (0)