Skip to content

Commit 46253f3

Browse files
committed
Add some more typings for function arguments and return values
Replicates graphql/graphql-js@bbd8429
1 parent bd919e1 commit 46253f3

19 files changed

+130
-90
lines changed

src/graphql/execution/execute.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -626,8 +626,9 @@ async def await_result() -> Any:
626626
if self.is_awaitable(completed):
627627
return await completed
628628
return completed
629-
except Exception as error:
630-
self.handle_field_error(error, field_nodes, path, return_type)
629+
except Exception as raw_error:
630+
error = located_error(raw_error, field_nodes, path.as_list())
631+
self.handle_field_error(error, return_type)
631632
return None
632633

633634
return await_result()
@@ -640,26 +641,24 @@ async def await_result() -> Any:
640641
async def await_completed() -> Any:
641642
try:
642643
return await completed
643-
except Exception as error:
644-
self.handle_field_error(error, field_nodes, path, return_type)
644+
except Exception as raw_error:
645+
error = located_error(raw_error, field_nodes, path.as_list())
646+
self.handle_field_error(error, return_type)
645647
return None
646648

647649
return await_completed()
648650

649651
return completed
650-
except Exception as error:
651-
self.handle_field_error(error, field_nodes, path, return_type)
652+
except Exception as raw_error:
653+
error = located_error(raw_error, field_nodes, path.as_list())
654+
self.handle_field_error(error, return_type)
652655
return None
653656

654657
def handle_field_error(
655658
self,
656-
raw_error: Exception,
657-
field_nodes: List[FieldNode],
658-
path: Path,
659+
error: GraphQLError,
659660
return_type: GraphQLOutputType,
660661
) -> None:
661-
error = located_error(raw_error, field_nodes, path.as_list())
662-
663662
# If the field type is non-nullable, then it is resolved without any protection
664663
# from errors, however it still properly locates the error.
665664
if is_non_null_type(return_type):
@@ -796,10 +795,11 @@ async def await_completed(item: Any, item_path: Path) -> Any:
796795
if is_awaitable(completed):
797796
return await completed
798797
return completed
799-
except Exception as error:
800-
self.handle_field_error(
801-
error, field_nodes, item_path, item_type
798+
except Exception as raw_error:
799+
error = located_error(
800+
raw_error, field_nodes, item_path.as_list()
802801
)
802+
self.handle_field_error(error, item_type)
803803
return None
804804

805805
completed_item = await_completed(item, item_path)
@@ -813,15 +813,17 @@ async def await_completed(item: Any, item_path: Path) -> Any:
813813
async def await_completed(item: Any, item_path: Path) -> Any:
814814
try:
815815
return await item
816-
except Exception as error:
817-
self.handle_field_error(
818-
error, field_nodes, item_path, item_type
816+
except Exception as raw_error:
817+
error = located_error(
818+
raw_error, field_nodes, item_path.as_list()
819819
)
820+
self.handle_field_error(error, item_type)
820821
return None
821822

822823
completed_item = await_completed(completed_item, item_path)
823-
except Exception as error:
824-
self.handle_field_error(error, field_nodes, item_path, item_type)
824+
except Exception as raw_error:
825+
error = located_error(raw_error, field_nodes, item_path.as_list())
826+
self.handle_field_error(error, item_type)
825827
completed_item = None
826828

827829
if is_awaitable(completed_item):

src/graphql/language/lexer.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import List, Optional
22

33
from ..error import GraphQLSyntaxError
44
from .ast import Token
@@ -154,7 +154,9 @@ def read_token(self, prev: Token) -> Token:
154154
col = 1 + pos - self.line_start
155155
return Token(TokenKind.EOF, body_length, body_length, line, col, prev)
156156

157-
def read_comment(self, start: int, line: int, col: int, prev: Token) -> Token:
157+
def read_comment(
158+
self, start: int, line: int, col: int, prev: Optional[Token]
159+
) -> Token:
158160
"""Read a comment token from the source file."""
159161
body = self.source.body
160162
body_length = len(body)
@@ -178,7 +180,7 @@ def read_comment(self, start: int, line: int, col: int, prev: Token) -> Token:
178180
)
179181

180182
def read_number(
181-
self, start: int, char: str, line: int, col: int, prev: Token
183+
self, start: int, char: str, line: int, col: int, prev: Optional[Token]
182184
) -> Token:
183185
"""Reads a number token from the source file.
184186
@@ -253,7 +255,9 @@ def read_digits(self, start: int, char: str) -> int:
253255
)
254256
return position
255257

256-
def read_string(self, start: int, line: int, col: int, prev: Token) -> Token:
258+
def read_string(
259+
self, start: int, line: int, col: int, prev: Optional[Token]
260+
) -> Token:
257261
"""Read a string token from the source file."""
258262
source = self.source
259263
body = source.body
@@ -316,7 +320,9 @@ def read_string(self, start: int, line: int, col: int, prev: Token) -> Token:
316320

317321
raise GraphQLSyntaxError(source, position, "Unterminated string.")
318322

319-
def read_block_string(self, start: int, line: int, col: int, prev: Token) -> Token:
323+
def read_block_string(
324+
self, start: int, line: int, col: int, prev: Optional[Token]
325+
) -> Token:
320326
source = self.source
321327
body = source.body
322328
body_length = len(body)
@@ -364,7 +370,9 @@ def read_block_string(self, start: int, line: int, col: int, prev: Token) -> Tok
364370

365371
raise GraphQLSyntaxError(source, position, "Unterminated string.")
366372

367-
def read_name(self, start: int, line: int, col: int, prev: Token) -> Token:
373+
def read_name(
374+
self, start: int, line: int, col: int, prev: Optional[Token]
375+
) -> Token:
368376
"""Read an alphanumeric + underscore name from the source."""
369377
body = self.source.body
370378
body_length = len(body)

src/graphql/language/printer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def join(strings: Optional[Strings], separator: str = "") -> str:
393393
return separator.join(s for s in strings if s) if strings else ""
394394

395395

396-
def block(strings: Strings) -> str:
396+
def block(strings: Optional[Strings]) -> str:
397397
"""Return strings inside a block.
398398
399399
Given a collection of strings, return a string with each item on its own line,
@@ -402,7 +402,7 @@ def block(strings: Strings) -> str:
402402
return "{\n" + indent(join(strings, "\n")) + "\n}" if strings else ""
403403

404404

405-
def wrap(start: str, string: str, end: str = "") -> str:
405+
def wrap(start: str, string: Optional[str], end: str = "") -> str:
406406
"""Wrap string inside other strings at start and end.
407407
408408
If the string is not None or empty, then wrap with start and end, otherwise return
@@ -425,6 +425,6 @@ def is_multiline(string: str) -> bool:
425425
return "\n" in string
426426

427427

428-
def has_multiline_items(maybe_list: Optional[Strings]) -> bool:
428+
def has_multiline_items(strings: Optional[Strings]) -> bool:
429429
"""Check whether one of the items in the list has multiple lines."""
430-
return any(is_multiline(item) for item in maybe_list) if maybe_list else False
430+
return any(is_multiline(item) for item in strings) if strings else False

src/graphql/type/validate.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ def add_error(self, error: GraphQLError) -> None:
108108

109109
def validate_root_types(self) -> None:
110110
schema = self.schema
111-
112111
query_type = schema.query_type
113112
if not query_type:
114113
self.report_error("Query root type must be provided.", schema.ast_node)

tests/execution/test_lists.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any
2+
13
from pytest import mark # type: ignore
24

35
from graphql.execution import execute, execute_sync, ExecutionResult
@@ -66,11 +68,11 @@ def does_not_accept_iterable_string_literal_as_a_list_value():
6668

6769

6870
def describe_execute_handles_list_nullability():
69-
async def _complete(list_field, as_type):
71+
async def _complete(list_field: Any, as_type: str) -> ExecutionResult:
7072
schema = build_schema(f"type Query {{ listField: {as_type} }}")
7173
document = parse("{ listField }")
7274

73-
def execute_query(list_value):
75+
def execute_query(list_value: Any) -> Any:
7476
return execute(schema, document, Data(list_value))
7577

7678
result = execute_query(list_field)

tests/execution/test_variables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,7 @@ def describe_get_variable_values_limit_maximum_number_of_coercion_errors():
980980

981981
input_value = {"input": [0, 1, 2]}
982982

983-
def _invalid_value_error(value, index):
983+
def _invalid_value_error(value: int, index: int) -> Dict[str, Any]:
984984
return {
985985
"message": "Variable '$input' got invalid value"
986986
f" {value} at 'input[{index}]';"

tests/language/test_block_string.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
)
66

77

8-
def join_lines(*args):
8+
def join_lines(*args: str) -> str:
99
return "\n".join(args)
1010

1111

tests/language/test_lexer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional
1+
from typing import List, Optional, Tuple
22

33
from pytest import raises # type: ignore
44

@@ -9,6 +9,8 @@
99

1010
from ..utils import dedent
1111

12+
Location = Optional[Tuple[int, int]]
13+
1214

1315
def lex_one(s: str) -> Token:
1416
lexer = Lexer(Source(s))
@@ -21,7 +23,7 @@ def lex_second(s: str) -> Token:
2123
return lexer.advance()
2224

2325

24-
def assert_syntax_error(text, message, location):
26+
def assert_syntax_error(text: str, message: str, location: Location) -> None:
2527
with raises(GraphQLSyntaxError) as exc_info:
2628
lex_second(text)
2729
error = exc_info.value

tests/language/test_parser.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import cast
1+
from typing import cast, Optional, Tuple
22

33
from pytest import raises # type: ignore
44

@@ -32,8 +32,10 @@
3232
from ..fixtures import kitchen_sink_query # noqa: F401
3333
from ..utils import dedent
3434

35+
Location = Optional[Tuple[int, int]]
3536

36-
def assert_syntax_error(text, message, location):
37+
38+
def assert_syntax_error(text: str, message: str, location: Location) -> None:
3739
with raises(GraphQLSyntaxError) as exc_info:
3840
parse(text)
3941
error = exc_info.value

tests/language/test_schema_parser.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from textwrap import dedent
2+
from typing import List, Optional, Tuple
23

34
from pytest import raises # type: ignore
45

56
from graphql.error import GraphQLSyntaxError
67
from graphql.language import (
8+
ArgumentNode,
79
BooleanValueNode,
810
DirectiveDefinitionNode,
911
DirectiveNode,
@@ -27,14 +29,18 @@
2729
SchemaDefinitionNode,
2830
SchemaExtensionNode,
2931
StringValueNode,
32+
TypeNode,
3033
UnionTypeDefinitionNode,
34+
ValueNode,
3135
parse,
3236
)
3337

3438
from ..fixtures import kitchen_sink_sdl # noqa: F401
3539

40+
Location = Optional[Tuple[int, int]]
3641

37-
def assert_syntax_error(text, message, location):
42+
43+
def assert_syntax_error(text: str, message: str, location: Location) -> None:
3844
with raises(GraphQLSyntaxError) as exc_info:
3945
parse(text)
4046
error = exc_info.value
@@ -43,7 +49,7 @@ def assert_syntax_error(text, message, location):
4349
assert error.locations == [location]
4450

4551

46-
def assert_definitions(body, loc, num=1):
52+
def assert_definitions(body: str, loc: Location, num=1):
4753
doc = parse(body)
4854
assert isinstance(doc, DocumentNode)
4955
assert doc.loc == loc
@@ -53,35 +59,37 @@ def assert_definitions(body, loc, num=1):
5359
return definitions[0] if num == 1 else definitions
5460

5561

56-
def type_node(name, loc):
62+
def type_node(name: str, loc: Location):
5763
return NamedTypeNode(name=name_node(name, loc), loc=loc)
5864

5965

60-
def name_node(name, loc):
66+
def name_node(name: str, loc: Location):
6167
return NameNode(value=name, loc=loc)
6268

6369

64-
def field_node(name, type_, loc):
70+
def field_node(name: NameNode, type_: TypeNode, loc: Location):
6571
return field_node_with_args(name, type_, [], loc)
6672

6773

68-
def field_node_with_args(name, type_, args, loc):
74+
def field_node_with_args(name: NameNode, type_: TypeNode, args: List, loc: Location):
6975
return FieldDefinitionNode(
7076
name=name, arguments=args, type=type_, directives=[], loc=loc, description=None
7177
)
7278

7379

74-
def non_null_type(type_, loc):
80+
def non_null_type(type_: TypeNode, loc: Location):
7581
return NonNullTypeNode(type=type_, loc=loc)
7682

7783

78-
def enum_value_node(name, loc):
84+
def enum_value_node(name: str, loc: Location):
7985
return EnumValueDefinitionNode(
8086
name=name_node(name, loc), directives=[], loc=loc, description=None
8187
)
8288

8389

84-
def input_value_node(name, type_, default_value, loc):
90+
def input_value_node(
91+
name: NameNode, type_: TypeNode, default_value: Optional[ValueNode], loc: Location
92+
):
8593
return InputValueDefinitionNode(
8694
name=name,
8795
type=type_,
@@ -92,29 +100,33 @@ def input_value_node(name, type_, default_value, loc):
92100
)
93101

94102

95-
def boolean_value_node(value, loc):
103+
def boolean_value_node(value: bool, loc: Location):
96104
return BooleanValueNode(value=value, loc=loc)
97105

98106

99-
def string_value_node(value, block, loc):
107+
def string_value_node(value: str, block: Optional[bool], loc: Location):
100108
return StringValueNode(value=value, block=block, loc=loc)
101109

102110

103-
def list_type_node(type_, loc):
111+
def list_type_node(type_: TypeNode, loc: Location):
104112
return ListTypeNode(type=type_, loc=loc)
105113

106114

107-
def schema_extension_node(directives, operation_types, loc):
115+
def schema_extension_node(
116+
directives: List[DirectiveNode],
117+
operation_types: List[OperationTypeDefinitionNode],
118+
loc: Location,
119+
):
108120
return SchemaExtensionNode(
109121
directives=directives, operation_types=operation_types, loc=loc
110122
)
111123

112124

113-
def operation_type_definition(operation, type_, loc):
125+
def operation_type_definition(operation: OperationType, type_: TypeNode, loc: Location):
114126
return OperationTypeDefinitionNode(operation=operation, type=type_, loc=loc)
115127

116128

117-
def directive_node(name, arguments, loc):
129+
def directive_node(name: NameNode, arguments: List[ArgumentNode], loc: Location):
118130
return DirectiveNode(name=name, arguments=arguments, loc=loc)
119131

120132

0 commit comments

Comments
 (0)