Skip to content

Commit 697ce76

Browse files
committed
Improved typing
1 parent 9b7e374 commit 697ce76

30 files changed

+316
-236
lines changed

graphql/error/format_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
def format_error(error):
1212
# type: (Union[GraphQLError, GraphQLLocatedError]) -> Dict[str, Any]
13-
formatted_error = {"message": text_type(error)}
13+
formatted_error = {"message": text_type(error)} # type: Dict[str, Any]
1414
if isinstance(error, GraphQLError):
1515
if error.locations is not None:
1616
formatted_error["locations"] = [

graphql/error/located_error.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,12 @@ def __init__(
2121
try:
2222
message = str(original_error)
2323
except UnicodeEncodeError:
24-
message = original_error.message.encode("utf-8")
24+
message = original_error.message.encode("utf-8") # type: ignore
2525
else:
2626
message = "An unknown error occurred."
2727

28-
if hasattr(original_error, "stack"):
29-
stack = original_error.stack
30-
else:
28+
stack = original_error and getattr(original_error, "stack", None)
29+
if not stack:
3130
stack = sys.exc_info()[2]
3231

3332
super(GraphQLLocatedError, self).__init__(

graphql/language/ast.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -462,9 +462,8 @@ def __eq__(self, other):
462462
# type: (Any) -> bool
463463
return self is other or (
464464
isinstance(other, Variable)
465-
and
466-
# self.loc == other.loc and
467-
self.name == other.name
465+
and self.name == other.name
466+
# and self.loc == other.loc
468467
)
469468

470469
def __repr__(self):

graphql/language/lexer.py

Lines changed: 41 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ..error import GraphQLSyntaxError
66

77
if False: # flake8: noqa
8-
from typing import Optional
8+
from typing import Optional, Any, List
99
from .source import Source
1010

1111
__all__ = ["Token", "Lexer", "TokenKind", "get_token_desc", "get_token_kind_desc"]
@@ -28,9 +28,10 @@ def __repr__(self):
2828
)
2929

3030
def __eq__(self, other):
31-
# type: (Token) -> bool
31+
# type: (Any) -> bool
3232
return (
33-
self.kind == other.kind
33+
isinstance(other, Token)
34+
and self.kind == other.kind
3435
and self.start == other.start
3536
and self.end == other.end
3637
and self.value == other.value
@@ -163,29 +164,33 @@ def read_token(source, from_position):
163164
return Token(TokenKind.EOF, position, position)
164165

165166
code = char_code_at(body, position)
167+
if code:
168+
if code < 0x0020 and code not in (0x0009, 0x000A, 0x000D):
169+
raise GraphQLSyntaxError(
170+
source, position, u"Invalid character {}.".format(print_char_code(code))
171+
)
166172

167-
if code < 0x0020 and code not in (0x0009, 0x000A, 0x000D):
168-
raise GraphQLSyntaxError(
169-
source, position, u"Invalid character {}.".format(print_char_code(code))
170-
)
171-
172-
kind = PUNCT_CODE_TO_KIND.get(code)
173-
if kind is not None:
174-
return Token(kind, position, position + 1)
173+
kind = PUNCT_CODE_TO_KIND.get(code)
174+
if kind is not None:
175+
return Token(kind, position, position + 1)
175176

176-
if code == 46: # .
177-
if char_code_at(body, position + 1) == char_code_at(body, position + 2) == 46:
178-
return Token(TokenKind.SPREAD, position, position + 3)
177+
if code == 46: # .
178+
if (
179+
char_code_at(body, position + 1)
180+
== char_code_at(body, position + 2)
181+
== 46
182+
):
183+
return Token(TokenKind.SPREAD, position, position + 3)
179184

180-
elif 65 <= code <= 90 or code == 95 or 97 <= code <= 122:
181-
# A-Z, _, a-z
182-
return read_name(source, position)
185+
elif 65 <= code <= 90 or code == 95 or 97 <= code <= 122:
186+
# A-Z, _, a-z
187+
return read_name(source, position)
183188

184-
elif code == 45 or 48 <= code <= 57: # -, 0-9
185-
return read_number(source, position, code)
189+
elif code == 45 or 48 <= code <= 57: # -, 0-9
190+
return read_number(source, position, code)
186191

187-
elif code == 34: # "
188-
return read_string(source, position)
192+
elif code == 34: # "
193+
return read_string(source, position)
189194

190195
raise GraphQLSyntaxError(
191196
source, position, u"Unexpected character {}.".format(print_char_code(code))
@@ -238,7 +243,7 @@ def position_after_whitespace(body, start_position):
238243

239244

240245
def read_number(source, start, first_code):
241-
# type: (Source, int, int) -> Token
246+
# type: (Source, int, Optional[int]) -> Token
242247
"""Reads a number token from the source file, either a float
243248
or an int depending on whether a decimal point appears.
244249
@@ -341,26 +346,23 @@ def read_string(source, start):
341346

342347
position = start + 1
343348
chunk_start = position
344-
code = 0
345-
value = []
349+
code = 0 # type: Optional[int]
350+
value = [] # type: List[str]
346351
append = value.append
347352

348353
while position < body_length:
349354
code = char_code_at(body, position)
350-
if not (
351-
code is not None
352-
and code
353-
not in (
354-
# LineTerminator
355-
0x000A,
356-
0x000D,
357-
# Quote
358-
34,
359-
)
355+
if code in (
356+
None,
357+
# LineTerminator
358+
0x000A,
359+
0x000D,
360+
# Quote
361+
34,
360362
):
361363
break
362364

363-
if code < 0x0020 and code != 0x0009:
365+
if code < 0x0020 and code != 0x0009: # type: ignore
364366
raise GraphQLSyntaxError(
365367
source,
366368
position,
@@ -372,7 +374,7 @@ def read_string(source, start):
372374
append(body[chunk_start : position - 1])
373375

374376
code = char_code_at(body, position)
375-
escaped = ESCAPED_CHAR_CODES.get(code)
377+
escaped = ESCAPED_CHAR_CODES.get(code) # type: ignore
376378
if escaped is not None:
377379
append(escaped)
378380

@@ -399,7 +401,9 @@ def read_string(source, start):
399401
raise GraphQLSyntaxError(
400402
source,
401403
position,
402-
u"Invalid character escape sequence: \\{}.".format(unichr(code)),
404+
u"Invalid character escape sequence: \\{}.".format(
405+
unichr(code) # type: ignore
406+
),
403407
)
404408

405409
position += 1

graphql/language/printer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,15 @@ class PrintingVisitor(Visitor):
5656

5757
def leave_Name(self, node, *args):
5858
# type: (Name, *Any) -> str
59-
return node.value
59+
return node.value # type: ignore
6060

6161
def leave_Variable(self, node, *args):
6262
# type: (Variable, *Any) -> str
63-
return "$" + node.name
63+
return "$" + node.name # type: ignore
6464

6565
def leave_Document(self, node, *args):
6666
# type: (Document, *Any) -> str
67-
return join(node.definitions, "\n\n") + "\n"
67+
return join(node.definitions, "\n\n") + "\n" # type: ignore
6868

6969
def leave_OperationDefinition(self, node, *args):
7070
# type: (OperationDefinition, *Any) -> str

graphql/language/tests/test_ast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import copy
22

3-
from graphql.language.visitor_meta import QUERY_DOCUMENT_KEYS, VisitorMeta
3+
from graphql.language.visitor_meta import QUERY_DOCUMENT_KEYS
44

55

66
def test_ast_is_hashable():

graphql/language/tests/test_visitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from typing import Union
2929
from graphql.language.ast import Field
3030
from graphql.language.ast import Name
31-
from graphql.language.visitor import Falsey
31+
from graphql.language.visitor import _Falsey
3232
from typing import List
3333
from graphql.language.ast import Argument
3434
from graphql.language.ast import IntValue

graphql/language/visitor.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .printer import PrintingVisitor
1414

1515

16-
class Falsey(object):
16+
class _Falsey(object):
1717
def __nonzero__(self):
1818
return False
1919

@@ -22,15 +22,19 @@ def __bool__(self):
2222
return False
2323

2424

25-
BREAK = object()
26-
REMOVE = Falsey()
25+
class _Break(object):
26+
pass
27+
28+
29+
BREAK = _Break()
30+
REMOVE = _Falsey()
2731

2832

2933
class Stack(object):
3034
__slots__ = "in_array", "index", "keys", "edits", "prev"
3135

3236
def __init__(self, in_array, index, keys, edits, prev):
33-
# type: (bool, int, Any, List[Tuple[str, str]], Optional[Stack]) -> None
37+
# type: (bool, int, Any, List[Tuple[Union[str, int], Any]], Optional[Stack]) -> None
3438
self.in_array = in_array
3539
self.index = index
3640
self.keys = keys
@@ -46,7 +50,7 @@ def visit(root, visitor, key_map=None):
4650
in_array = isinstance(root, list)
4751
keys = [root]
4852
index = -1
49-
edits = [] # type: List[Tuple[int, Any]]
53+
edits = [] # type: List[Tuple[Union[str, int], Any]]
5054
parent = None # type: Optional[Node]
5155
path = [] # type: List
5256
ancestors = [] # type: List[Node]
@@ -75,7 +79,7 @@ def visit(root, visitor, key_map=None):
7579
edit_offset = 0
7680
for edit_key, edit_value in edits:
7781
if in_array:
78-
edit_key -= edit_offset
82+
edit_key -= edit_offset # type: ignore
7983

8084
if in_array and edit_value is REMOVE:
8185
node.pop(edit_key) # type: ignore
@@ -84,7 +88,6 @@ def visit(root, visitor, key_map=None):
8488
else:
8589
if isinstance(node, list):
8690
node[edit_key] = edit_value
87-
8891
else:
8992
setattr(node, edit_key, edit_value) # type: ignore
9093

@@ -99,7 +102,6 @@ def visit(root, visitor, key_map=None):
99102
key = index if in_array else keys[index]
100103
if isinstance(parent, list):
101104
node = parent[key]
102-
103105
else:
104106
node = getattr(parent, key, None)
105107

@@ -148,7 +150,11 @@ def visit(root, visitor, key_map=None):
148150
if not is_leaving:
149151
stack = Stack(in_array, index, keys, edits, stack)
150152
in_array = isinstance(node, list)
151-
keys = node if in_array else visitor_keys.get(type(node), None) or []
153+
keys = (
154+
node
155+
if in_array
156+
else visitor_keys.get(type(node), None) or [] # type: ignore
157+
) # type: ignore
152158
index = -1
153159
edits = []
154160

@@ -178,10 +184,11 @@ def enter(
178184
path, # type: List[Union[int, str]]
179185
ancestors, # type: List[Any]
180186
):
181-
# type: (...) -> Optional[Any]
182-
method = self._get_enter_handler(type(node))
187+
# type: (...) -> Any
188+
method = self._get_enter_handler(type(node)) # type: ignore
183189
if method:
184190
return method(self, node, key, parent, path, ancestors)
191+
return None
185192

186193
def leave(
187194
self,
@@ -191,10 +198,11 @@ def leave(
191198
path, # type: List[Union[int, str]]
192199
ancestors, # type: List[Any]
193200
):
194-
# type: (...) -> Optional[Any]
195-
method = self._get_leave_handler(type(node))
201+
# type: (...) -> Any
202+
method = self._get_leave_handler(type(node)) # type: ignore
196203
if method:
197204
return method(self, node, key, parent, path, ancestors)
205+
return None
198206

199207

200208
class ParallelVisitor(Visitor):
@@ -203,7 +211,10 @@ class ParallelVisitor(Visitor):
203211
def __init__(self, visitors):
204212
# type: (List[Any]) -> None
205213
self.visitors = visitors
206-
self.skipping = [None] * len(visitors)
214+
self.skipping = [None] * len(
215+
visitors
216+
) # type: List[Union[Node, _Break, _Falsey, None]]
217+
return None
207218

208219
def enter(
209220
self,
@@ -213,7 +224,7 @@ def enter(
213224
path, # type: List[Union[int, str]]
214225
ancestors, # type: List[Any]
215226
):
216-
# type: (...) -> Optional[Any]
227+
# type: (...) -> Any
217228
for i, visitor in enumerate(self.visitors):
218229
if not self.skipping[i]:
219230
result = visitor.enter(node, key, parent, path, ancestors)
@@ -223,6 +234,7 @@ def enter(
223234
self.skipping[i] = BREAK
224235
elif result is not None:
225236
return result
237+
return None
226238

227239
def leave(
228240
self,
@@ -232,7 +244,7 @@ def leave(
232244
path, # type: List[Union[int, str]]
233245
ancestors, # type: List[Any]
234246
):
235-
# type: (...) -> Optional[Any]
247+
# type: (...) -> Any
236248
for i, visitor in enumerate(self.visitors):
237249
if not self.skipping[i]:
238250
result = visitor.leave(node, key, parent, path, ancestors)
@@ -242,17 +254,14 @@ def leave(
242254
return result
243255
elif self.skipping[i] == node:
244256
self.skipping[i] = REMOVE
257+
return None
245258

246259

247260
class TypeInfoVisitor(Visitor):
248261
__slots__ = "visitor", "type_info"
249262

250-
def __init__(
251-
self,
252-
type_info, # type: TypeInfo
253-
visitor, # type: Union[TestVisitor, ParallelVisitor, UsageVisitor]
254-
):
255-
# type: (...) -> None
263+
def __init__(self, type_info, visitor):
264+
# type: (TypeInfo, Visitor) -> None
256265
self.type_info = type_info
257266
self.visitor = visitor
258267

0 commit comments

Comments
 (0)