|
| 1 | +from typing import Union |
| 2 | + |
1 | 3 | from pytest import raises
|
2 | 4 |
|
3 | 5 | from graphql import graphql_sync
|
4 |
| -from graphql.language import parse, print_ast, DirectiveLocation, DocumentNode, Node |
| 6 | +from graphql.language import parse, print_ast, DirectiveLocation, DocumentNode |
5 | 7 | from graphql.pyutils import dedent
|
6 | 8 | from graphql.type import (
|
7 | 9 | GraphQLArgument,
|
|
17 | 19 | GraphQLInt,
|
18 | 20 | GraphQLInterfaceType,
|
19 | 21 | GraphQLList,
|
| 22 | + GraphQLNamedType, |
20 | 23 | GraphQLNonNull,
|
21 | 24 | GraphQLObjectType,
|
22 | 25 | GraphQLScalarType,
|
@@ -143,9 +146,14 @@ def print_test_schema_changes(extended_schema):
|
143 | 146 | return print_ast(ast)
|
144 | 147 |
|
145 | 148 |
|
146 |
| -def print_node(node: Node) -> str: |
147 |
| - assert node |
148 |
| - return print_ast(node) |
| 149 | +TypeWithAstNode = Union[ |
| 150 | + GraphQLArgument, GraphQLEnumValue, GraphQLField, GraphQLInputField, GraphQLNamedType |
| 151 | +] |
| 152 | + |
| 153 | + |
| 154 | +def print_ast_node(obj: TypeWithAstNode) -> str: |
| 155 | + assert obj is not None and obj.ast_node is not None |
| 156 | + return print_ast(obj.ast_node) |
149 | 157 |
|
150 | 158 |
|
151 | 159 | def describe_extend_schema():
|
@@ -479,53 +487,49 @@ def correctly_assigns_ast_nodes_to_new_and_extended_types():
|
479 | 487 | ) == print_schema(extended_twice_schema)
|
480 | 488 |
|
481 | 489 | new_field = query.fields["newField"]
|
| 490 | + assert print_ast_node(new_field) == "newField(testArg: TestInput): TestEnum" |
| 491 | + assert print_ast_node(new_field.args["testArg"]) == "testArg: TestInput" |
482 | 492 | assert (
|
483 |
| - print_node(new_field.ast_node) == "newField(testArg: TestInput): TestEnum" |
484 |
| - ) |
485 |
| - assert print_node(new_field.args["testArg"].ast_node) == "testArg: TestInput" |
486 |
| - assert ( |
487 |
| - print_node(query.fields["oneMoreNewField"].ast_node) |
| 493 | + print_ast_node(query.fields["oneMoreNewField"]) |
488 | 494 | == "oneMoreNewField: TestUnion"
|
489 | 495 | )
|
490 | 496 |
|
491 | 497 | new_value = some_enum.values["NEW_VALUE"]
|
492 | 498 | assert some_enum
|
493 |
| - assert print_node(new_value.ast_node) == "NEW_VALUE" |
| 499 | + assert print_ast_node(new_value) == "NEW_VALUE" |
494 | 500 |
|
495 | 501 | one_more_new_value = some_enum.values["ONE_MORE_NEW_VALUE"]
|
496 | 502 | assert one_more_new_value
|
497 |
| - assert print_node(one_more_new_value.ast_node) == "ONE_MORE_NEW_VALUE" |
498 |
| - assert print_node(some_input.fields["newField"].ast_node) == "newField: String" |
| 503 | + assert print_ast_node(one_more_new_value) == "ONE_MORE_NEW_VALUE" |
| 504 | + assert print_ast_node(some_input.fields["newField"]) == "newField: String" |
499 | 505 | assert (
|
500 |
| - print_node(some_input.fields["oneMoreNewField"].ast_node) |
| 506 | + print_ast_node(some_input.fields["oneMoreNewField"]) |
501 | 507 | == "oneMoreNewField: String"
|
502 | 508 | )
|
| 509 | + assert print_ast_node(some_interface.fields["newField"]) == "newField: String" |
503 | 510 | assert (
|
504 |
| - print_node(some_interface.fields["newField"].ast_node) == "newField: String" |
505 |
| - ) |
506 |
| - assert ( |
507 |
| - print_node(some_interface.fields["oneMoreNewField"].ast_node) |
| 511 | + print_ast_node(some_interface.fields["oneMoreNewField"]) |
508 | 512 | == "oneMoreNewField: String"
|
509 | 513 | )
|
510 | 514 |
|
511 | 515 | assert (
|
512 |
| - print_node(test_input.fields["testInputField"].ast_node) |
| 516 | + print_ast_node(test_input.fields["testInputField"]) |
513 | 517 | == "testInputField: TestEnum"
|
514 | 518 | )
|
515 | 519 |
|
516 | 520 | test_value = test_enum.values["TEST_VALUE"]
|
517 | 521 | assert test_value
|
518 |
| - assert print_node(test_value.ast_node) == "TEST_VALUE" |
| 522 | + assert print_ast_node(test_value) == "TEST_VALUE" |
519 | 523 |
|
520 | 524 | assert (
|
521 |
| - print_node(test_interface.fields["interfaceField"].ast_node) |
| 525 | + print_ast_node(test_interface.fields["interfaceField"]) |
522 | 526 | == "interfaceField: String"
|
523 | 527 | )
|
524 | 528 | assert (
|
525 |
| - print_node(test_type.fields["interfaceField"].ast_node) |
| 529 | + print_ast_node(test_type.fields["interfaceField"]) |
526 | 530 | == "interfaceField: String"
|
527 | 531 | )
|
528 |
| - assert print_node(test_directive.args["arg"].ast_node) == "arg: Int" |
| 532 | + assert print_ast_node(test_directive.args["arg"]) == "arg: Int" |
529 | 533 |
|
530 | 534 | def builds_types_with_deprecated_fields_and_values():
|
531 | 535 | extended_schema = extend_test_schema(
|
@@ -1150,7 +1154,7 @@ def adds_schema_definition_missing_in_the_original_schema():
|
1150 | 1154 | schema = extend_schema(schema, parse(extension_sdl))
|
1151 | 1155 | query_type = schema.query_type
|
1152 | 1156 | assert query_type.name == "Foo"
|
1153 |
| - assert print_node(schema.ast_node) == extension_sdl.rstrip() |
| 1157 | + assert print_ast_node(schema) == extension_sdl.rstrip() |
1154 | 1158 |
|
1155 | 1159 | def adds_new_root_types_via_schema_extension():
|
1156 | 1160 | schema = extend_test_schema(
|
|
0 commit comments