Skip to content

Commit 81d7b0f

Browse files
committed
Add typing for print_ast function
Replicates graphql/graphql-js@213c8a7
1 parent e075bec commit 81d7b0f

File tree

6 files changed

+59
-46
lines changed

6 files changed

+59
-46
lines changed

graphql/language/printer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
__all__ = ["print_ast"]
99

1010

11-
def print_ast(ast: Node):
11+
def print_ast(ast: Node) -> str:
1212
"""Convert an AST into a string.
1313
1414
The conversion is done using a set of reasonable formatting rules.

graphql/language/visitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ class Stack(NamedTuple):
190190
prev: Any # 'Stack' (python/mypy/issues/731)
191191

192192

193-
def visit(root: Node, visitor: Visitor, visitor_keys=None) -> Node:
193+
def visit(root: Node, visitor: Visitor, visitor_keys=None) -> Any:
194194
"""Visit each node in an AST.
195195
196196
`visit()` will walk through an AST using a depth first traversal, calling the

graphql/type/introspection.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
GraphQLEnumType,
77
GraphQLEnumValue,
88
GraphQLField,
9-
GraphQLInputType,
109
GraphQLList,
1110
GraphQLNonNull,
1211
GraphQLObjectType,
@@ -21,9 +20,8 @@
2120
is_scalar_type,
2221
is_union_type,
2322
)
24-
from ..pyutils import is_invalid
2523
from .scalars import GraphQLBoolean, GraphQLString
26-
from ..language import DirectiveLocation
24+
from ..language import DirectiveLocation, print_ast
2725

2826
__all__ = [
2927
"SchemaMetaFieldDef",
@@ -35,13 +33,6 @@
3533
]
3634

3735

38-
def print_value(value: Any, type_: GraphQLInputType) -> str:
39-
# Since print_value needs graphql.type, it can only be imported later
40-
from ..utilities.schema_printer import print_value
41-
42-
return print_value(value, type_)
43-
44-
4536
__Schema: GraphQLObjectType = GraphQLObjectType(
4637
name="__Schema",
4738
description="A GraphQL Schema defines the capabilities of a GraphQL"
@@ -345,6 +336,14 @@ def of_type(type_, _info):
345336
)
346337

347338

339+
def _resolve_input_value_default_value(item, _info):
340+
# Since ast_from_value needs graphql.type, it can only be imported later
341+
from ..utilities import ast_from_value
342+
343+
value_ast = ast_from_value(item[1].default_value, item[1].type)
344+
return print_ast(value_ast) if value_ast else None
345+
346+
348347
__InputValue: GraphQLObjectType = GraphQLObjectType(
349348
name="__InputValue",
350349
description="Arguments provided to Fields or Directives and the input"
@@ -364,9 +363,7 @@ def of_type(type_, _info):
364363
GraphQLString,
365364
description="A GraphQL-formatted string representing"
366365
" the default value for this input value.",
367-
resolve=lambda item, _info: None
368-
if is_invalid(item[1].default_value)
369-
else print_value(item[1].default_value, item[1].type),
366+
resolve=_resolve_input_value_default_value,
370367
),
371368
},
372369
)

graphql/utilities/schema_printer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Callable, Dict, List, Optional, Union, cast
44

55
from ..language import print_ast
6-
from ..pyutils import inspect, is_invalid, is_nullish
6+
from ..pyutils import inspect
77
from ..type import (
88
DEFAULT_DEPRECATION_REASON,
99
GraphQLArgument,
@@ -248,9 +248,10 @@ def print_args(args: Dict[str, GraphQLArgument], indentation="") -> str:
248248

249249

250250
def print_input_value(name: str, arg: GraphQLArgument) -> str:
251+
default_ast = ast_from_value(arg.default_value, arg.type)
251252
arg_decl = f"{name}: {arg.type}"
252-
if not is_invalid(arg.default_value):
253-
arg_decl += f" = {print_value(arg.default_value, arg.type)}"
253+
if default_ast:
254+
arg_decl += f" = {print_ast(default_ast)}"
254255
return arg_decl
255256

256257

@@ -268,10 +269,10 @@ def print_deprecated(field_or_enum_value: Union[GraphQLField, GraphQLEnumValue])
268269
if not field_or_enum_value.is_deprecated:
269270
return ""
270271
reason = field_or_enum_value.deprecation_reason
271-
if is_nullish(reason) or reason == "" or reason == DEFAULT_DEPRECATION_REASON:
272+
reason_ast = ast_from_value(reason, GraphQLString)
273+
if not reason_ast or reason == "" or reason == DEFAULT_DEPRECATION_REASON:
272274
return " @deprecated"
273-
else:
274-
return f" @deprecated(reason: {print_value(reason, GraphQLString)})"
275+
return f" @deprecated(reason: {print_ast(reason_ast)})"
275276

276277

277278
def print_description(

tests/utilities/test_build_ast_schema.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pytest import raises
44

55
from graphql import graphql_sync
6-
from graphql.language import parse, print_ast, DocumentNode
6+
from graphql.language import parse, print_ast, DocumentNode, Node
77
from graphql.type import (
88
GraphQLDeprecatedDirective,
99
GraphQLIncludeDirective,
@@ -33,6 +33,11 @@ def cycle_sdl(sdl: str) -> str:
3333
return print_schema(schema)
3434

3535

36+
def print_node(node: Node) -> str:
37+
assert node
38+
return print_ast(node)
39+
40+
3641
def describe_schema_builder():
3742
def can_use_built_schema_for_limited_execution():
3843
schema = build_ast_schema(
@@ -704,7 +709,9 @@ def correctly_assign_ast_nodes():
704709
directive @test(arg: TestScalar) on FIELD
705710
"""
706711
)
707-
schema = build_schema(sdl)
712+
ast = parse(sdl, no_location=True)
713+
714+
schema = build_ast_schema(ast)
708715
query = assert_object_type(schema.get_type("Query"))
709716
test_input = assert_input_object_type(schema.get_type("TestInput"))
710717
test_enum = assert_enum_type(schema.get_type("TestEnum"))
@@ -725,22 +732,23 @@ def correctly_assign_ast_nodes():
725732
test_type.ast_node,
726733
test_scalar.ast_node,
727734
test_directive.ast_node,
728-
]
735+
],
736+
loc=None,
729737
)
730-
assert print_ast(restored_schema_ast) == sdl
738+
assert restored_schema_ast == ast
731739

732740
test_field = query.fields["testField"]
733-
assert print_ast(test_field.ast_node) == (
741+
assert print_node(test_field.ast_node) == (
734742
"testField(testArg: TestInput): TestUnion"
735743
)
736-
assert print_ast(test_field.args["testArg"].ast_node) == "testArg: TestInput"
737-
assert print_ast(test_input.fields["testInputField"].ast_node) == (
744+
assert print_node(test_field.args["testArg"].ast_node) == "testArg: TestInput"
745+
assert print_node(test_input.fields["testInputField"].ast_node) == (
738746
"testInputField: TestEnum"
739747
)
740748
test_enum_value = test_enum.values["TEST_VALUE"]
741749
assert test_enum_value
742-
assert print_ast(test_enum_value.ast_node) == "TEST_VALUE"
743-
assert print_ast(test_interface.fields["interfaceField"].ast_node) == (
750+
assert print_node(test_enum_value.ast_node) == "TEST_VALUE"
751+
assert print_node(test_interface.fields["interfaceField"].ast_node) == (
744752
"interfaceField: String"
745753
)
746754
assert print_ast(test_directive.args["arg"].ast_node) == "arg: TestScalar"

tests/utilities/test_extend_schema.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pytest import raises
22

33
from graphql import graphql_sync
4-
from graphql.language import parse, print_ast, DirectiveLocation, DocumentNode
4+
from graphql.language import parse, print_ast, DirectiveLocation, DocumentNode, Node
55
from graphql.pyutils import dedent
66
from graphql.type import (
77
GraphQLArgument,
@@ -140,6 +140,11 @@ def print_test_schema_changes(extended_schema):
140140
return print_ast(ast)
141141

142142

143+
def print_node(node: Node) -> str:
144+
assert node
145+
return print_ast(node)
146+
147+
143148
def describe_extend_schema():
144149
def returns_the_original_schema_when_there_are_no_type_definitions():
145150
extended_schema = extend_test_schema("{ field }")
@@ -423,51 +428,53 @@ def correctly_assigns_ast_nodes_to_new_and_extended_types():
423428
) == print_schema(extended_twice_schema)
424429

425430
new_field = query.fields["newField"]
426-
assert print_ast(new_field.ast_node) == "newField(testArg: TestInput): TestEnum"
427-
assert print_ast(new_field.args["testArg"].ast_node) == "testArg: TestInput"
428431
assert (
429-
print_ast(query.fields["oneMoreNewField"].ast_node)
432+
print_node(new_field.ast_node) == "newField(testArg: TestInput): TestEnum"
433+
)
434+
assert print_node(new_field.args["testArg"].ast_node) == "testArg: TestInput"
435+
assert (
436+
print_node(query.fields["oneMoreNewField"].ast_node)
430437
== "oneMoreNewField: TestUnion"
431438
)
432439

433440
new_value = some_enum.values["NEW_VALUE"]
434441
assert some_enum
435-
assert print_ast(new_value.ast_node) == "NEW_VALUE"
442+
assert print_node(new_value.ast_node) == "NEW_VALUE"
436443

437444
one_more_new_value = some_enum.values["ONE_MORE_NEW_VALUE"]
438445
assert one_more_new_value
439-
assert print_ast(one_more_new_value.ast_node) == "ONE_MORE_NEW_VALUE"
440-
assert print_ast(some_input.fields["newField"].ast_node) == "newField: String"
446+
assert print_node(one_more_new_value.ast_node) == "ONE_MORE_NEW_VALUE"
447+
assert print_node(some_input.fields["newField"].ast_node) == "newField: String"
441448
assert (
442-
print_ast(some_input.fields["oneMoreNewField"].ast_node)
449+
print_node(some_input.fields["oneMoreNewField"].ast_node)
443450
== "oneMoreNewField: String"
444451
)
445452
assert (
446-
print_ast(some_interface.fields["newField"].ast_node) == "newField: String"
453+
print_node(some_interface.fields["newField"].ast_node) == "newField: String"
447454
)
448455
assert (
449-
print_ast(some_interface.fields["oneMoreNewField"].ast_node)
456+
print_node(some_interface.fields["oneMoreNewField"].ast_node)
450457
== "oneMoreNewField: String"
451458
)
452459

453460
assert (
454-
print_ast(test_input.fields["testInputField"].ast_node)
461+
print_node(test_input.fields["testInputField"].ast_node)
455462
== "testInputField: TestEnum"
456463
)
457464

458465
test_value = test_enum.values["TEST_VALUE"]
459466
assert test_value
460-
assert print_ast(test_value.ast_node) == "TEST_VALUE"
467+
assert print_node(test_value.ast_node) == "TEST_VALUE"
461468

462469
assert (
463-
print_ast(test_interface.fields["interfaceField"].ast_node)
470+
print_node(test_interface.fields["interfaceField"].ast_node)
464471
== "interfaceField: String"
465472
)
466473
assert (
467-
print_ast(test_type.fields["interfaceField"].ast_node)
474+
print_node(test_type.fields["interfaceField"].ast_node)
468475
== "interfaceField: String"
469476
)
470-
assert print_ast(test_directive.args["arg"].ast_node) == "arg: Int"
477+
assert print_node(test_directive.args["arg"].ast_node) == "arg: Int"
471478

472479
def builds_types_with_deprecated_fields_and_values():
473480
extended_schema = extend_test_schema(
@@ -1092,7 +1099,7 @@ def adds_schema_definition_missing_in_the_original_schema():
10921099
schema = extend_schema(schema, parse(extension_sdl))
10931100
query_type = schema.query_type
10941101
assert query_type.name == "Foo"
1095-
assert print_ast(schema.ast_node) == extension_sdl.rstrip()
1102+
assert print_node(schema.ast_node) == extension_sdl.rstrip()
10961103

10971104
def adds_new_root_types_via_schema_extension():
10981105
schema = extend_test_schema(

0 commit comments

Comments
 (0)