Skip to content

Commit 4fa8cc8

Browse files
committed
annotation parser
1 parent 01ac8d0 commit 4fa8cc8

File tree

4 files changed

+376
-3
lines changed

4 files changed

+376
-3
lines changed

.idea/basedtyping.iml

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/watcherTasks.xml

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

basedtyping/runtime_only.py

Lines changed: 285 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,295 @@
66

77
from __future__ import annotations
88

9-
from typing import Final, Final as Final_ext, Literal, Union
9+
import functools
10+
import operator
11+
import sys
12+
import types
13+
from _ast import AST, Attribute, BinOp, BitAnd, Constant, Load, Name, Subscript, Tuple
14+
from ast import NodeTransformer, parse
15+
from types import GenericAlias
16+
from typing import (
17+
Final,
18+
Final as Final_ext,
19+
ForwardRef,
20+
Literal,
21+
Union,
22+
Unpack,
23+
_eval_type,
24+
_Final,
25+
_GenericAlias,
26+
_should_unflatten_callable_args,
27+
_strip_annotations,
28+
_type_check,
29+
)
1030

1131
LiteralType: Final = type(Literal[1])
1232
"""A type that can be used to check if type hints are a ``typing.Literal`` instance"""
1333

1434
# TODO: this is type[object], we need it to be 'SpecialForm[Union]' (or something)
1535
OldUnionType: Final_ext[type[object]] = type(Union[str, int])
1636
"""A type that can be used to check if type hints are a ``typing.Union`` instance."""
37+
38+
39+
def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
40+
if getattr(obj, "__no_type_check__", None):
41+
return {}
42+
# Classes require a special treatment.
43+
if isinstance(obj, type):
44+
hints = {}
45+
for base in reversed(obj.__mro__):
46+
if globalns is None:
47+
base_globals = getattr(
48+
sys.modules.get(base.__module__, None), "__dict__", {}
49+
)
50+
else:
51+
base_globals = globalns
52+
ann = base.__dict__.get("__annotations__", {})
53+
if isinstance(ann, types.GetSetDescriptorType):
54+
ann = {}
55+
base_locals = dict(vars(base)) if localns is None else localns
56+
if localns is None and globalns is None:
57+
# This is surprising, but required. Before Python 3.10,
58+
# get_type_hints only evaluated the globalns of
59+
# a class. To maintain backwards compatibility, we reverse
60+
# the globalns and localns order so that eval() looks into
61+
# *base_globals* first rather than *base_locals*.
62+
# This only affects ForwardRefs.
63+
base_globals, base_locals = base_locals, base_globals
64+
p = BasedTypeParser()
65+
for name, value in ann.items():
66+
if value is None:
67+
value = type(None)
68+
if isinstance(value, str):
69+
value = p.visit(parse(value, mode="eval"))
70+
# value = unparse(p.visit(parse(value)))
71+
value = ForwardRef(value, is_argument=False, is_class=True)
72+
value = _eval_type(value, base_globals, base_locals)
73+
hints[name] = value
74+
return (
75+
hints
76+
if include_extras
77+
else {k: _strip_annotations(t) for k, t in hints.items()}
78+
)
79+
80+
if globalns is None:
81+
if isinstance(obj, types.ModuleType):
82+
globalns = obj.__dict__
83+
else:
84+
nsobj = obj
85+
# Find globalns for the unwrapped object.
86+
while hasattr(nsobj, "__wrapped__"):
87+
nsobj = nsobj.__wrapped__
88+
globalns = getattr(nsobj, "__globals__", {})
89+
if localns is None:
90+
localns = globalns
91+
elif localns is None:
92+
localns = globalns
93+
hints = getattr(obj, "__annotations__", None)
94+
if hints is None:
95+
# Return empty annotations for something that _could_ have them.
96+
if isinstance(obj, _allowed_types):
97+
return {}
98+
else:
99+
raise TypeError(f"{obj!r} is not a module, class, method, or function.")
100+
hints = dict(hints)
101+
for name, value in hints.items():
102+
if value is None:
103+
value = type(None)
104+
if isinstance(value, str):
105+
# class-level forward refs were handled above, this must be either
106+
# a module-level annotation or a function argument annotation
107+
value = ForwardRef(
108+
value, is_argument=not isinstance(obj, types.ModuleType), is_class=False
109+
)
110+
hints[name] = _eval_type(value, globalns, localns)
111+
return (
112+
hints
113+
if include_extras
114+
else {k: _strip_annotations(t) for k, t in hints.items()}
115+
)
116+
117+
118+
def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
119+
"""Evaluate all forward references in the given type t.
120+
For use of globalns and localns see the docstring for get_type_hints().
121+
recursive_guard is used to prevent infinite recursion with a recursive
122+
ForwardRef.
123+
"""
124+
if isinstance(t, ForwardRef):
125+
return t._evaluate(globalns, localns, recursive_guard)
126+
if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)):
127+
if isinstance(t, GenericAlias):
128+
args = tuple(
129+
ForwardRef(arg) if isinstance(arg, str) else arg for arg in t.__args__
130+
)
131+
is_unpacked = t.__unpacked__
132+
if _should_unflatten_callable_args(t, args):
133+
t = t.__origin__[(args[:-1], args[-1])]
134+
else:
135+
t = t.__origin__[args]
136+
if is_unpacked:
137+
t = Unpack[t]
138+
ev_args = tuple(
139+
_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__
140+
)
141+
if ev_args == t.__args__:
142+
return t
143+
if isinstance(t, GenericAlias):
144+
return GenericAlias(t.__origin__, ev_args)
145+
if isinstance(t, types.UnionType):
146+
return functools.reduce(operator.or_, ev_args)
147+
else:
148+
return t.copy_with(ev_args)
149+
return t
150+
151+
152+
class ForwardRef(_Final, _root=True):
153+
"""Internal wrapper to hold a forward reference."""
154+
155+
__slots__ = (
156+
"__forward_arg__",
157+
"__forward_code__",
158+
"__forward_evaluated__",
159+
"__forward_value__",
160+
"__forward_is_argument__",
161+
"__forward_is_class__",
162+
"__forward_module__",
163+
)
164+
165+
def __init__(self, arg, is_argument=True, module=None, *, is_class=False):
166+
if isinstance(arg, str):
167+
# If we do `def f(*args: *Ts)`, then we'll have `arg = '*Ts'`.
168+
# Unfortunately, this isn't a valid expression on its own, so we
169+
# do the unpacking manually.
170+
if arg[0] == "*":
171+
arg_to_compile = ( # E.g. (*Ts,)[0] or (*tuple[int, int],)[0]
172+
f"({arg},)[0]"
173+
)
174+
else:
175+
arg_to_compile = arg
176+
elif isinstance(arg, AST):
177+
arg_to_compile = arg
178+
else:
179+
raise TypeError(f"Forward reference must be a string or AST -- got {arg!r}")
180+
try:
181+
code = compile(arg_to_compile, "<string>", "eval")
182+
except SyntaxError:
183+
raise SyntaxError(f"Forward reference must be an expression -- got {arg!r}")
184+
except TypeError as t:
185+
print(arg_to_compile.body, t)
186+
...
187+
self.__forward_arg__ = arg
188+
self.__forward_code__ = code
189+
self.__forward_evaluated__ = False
190+
self.__forward_value__ = None
191+
self.__forward_is_argument__ = is_argument
192+
self.__forward_is_class__ = is_class
193+
self.__forward_module__ = module
194+
195+
def _evaluate(self, globalns, localns, recursive_guard):
196+
if self.__forward_arg__ in recursive_guard:
197+
return self
198+
if not self.__forward_evaluated__ or localns is not globalns:
199+
if globalns is None and localns is None:
200+
globalns = localns = {}
201+
elif globalns is None:
202+
globalns = localns
203+
elif localns is None:
204+
localns = globalns
205+
if self.__forward_module__ is not None:
206+
globalns = getattr(
207+
sys.modules.get(self.__forward_module__, None), "__dict__", globalns
208+
)
209+
import typing
210+
211+
import basedtyping
212+
213+
type_ = _type_check(
214+
eval(
215+
self.__forward_code__,
216+
globalns | {"__secret__": typing, "__basedsecret__": basedtyping},
217+
localns,
218+
),
219+
"Forward references must evaluate to types.",
220+
is_argument=self.__forward_is_argument__,
221+
allow_special_forms=self.__forward_is_class__,
222+
)
223+
self.__forward_value__ = _eval_type(
224+
type_, globalns, localns, recursive_guard | {self.__forward_arg__}
225+
)
226+
self.__forward_evaluated__ = True
227+
return self.__forward_value__
228+
229+
def __eq__(self, other):
230+
if not isinstance(other, ForwardRef):
231+
return NotImplemented
232+
if self.__forward_evaluated__ and other.__forward_evaluated__:
233+
return (
234+
self.__forward_arg__ == other.__forward_arg__
235+
and self.__forward_value__ == other.__forward_value__
236+
)
237+
return (
238+
self.__forward_arg__ == other.__forward_arg__
239+
and self.__forward_module__ == other.__forward_module__
240+
)
241+
242+
def __hash__(self):
243+
return hash((self.__forward_arg__, self.__forward_module__))
244+
245+
def __or__(self, other):
246+
return Union[self, other]
247+
248+
def __ror__(self, other):
249+
return Union[other, self]
250+
251+
def __repr__(self):
252+
if self.__forward_module__ is None:
253+
module_repr = ""
254+
else:
255+
module_repr = f", module={self.__forward_module__!r}"
256+
return f"ForwardRef({self.__forward_arg__!r}{module_repr})"
257+
258+
259+
class BasedTypeParser(NodeTransformer):
260+
in_subscript = 0
261+
262+
def __init__(self):
263+
self.load = Load()
264+
265+
def visit_BinOp(self, node: BinOp) -> AST:
266+
if isinstance(node.op, BitAnd):
267+
extra = dict(lineno=node.lineno, col_offset=node.col_offset, ctx=self.load)
268+
return Subscript(
269+
Attribute(Name("__basedsecret__", **extra), "Intersection", **extra),
270+
Tuple([self.visit(node.left), self.visit(node.right)], **extra),
271+
**extra,
272+
)
273+
return self.generic_visit(node)
274+
275+
def visit_Constant(self, node: Constant) -> AST:
276+
if isinstance(node.value, int):
277+
# todo enum
278+
279+
extra = dict(lineno=node.lineno, col_offset=node.col_offset, ctx=self.load)
280+
return Subscript(
281+
Attribute(Name("__secret__", **extra), "Literal", **extra),
282+
node,
283+
**extra,
284+
)
285+
return self.generic_visit(node)
286+
287+
def visit_Tuple(self, node: Tuple) -> AST:
288+
if self.in_subscript:
289+
self.in_subscript = False
290+
return self.generic_visit(node)
291+
extra = dict(lineno=node.lineno, col_offset=node.col_offset, ctx=self.load)
292+
return Subscript(Name("__secret__.Tuple"), self.generic_visit(node), **extra)
293+
294+
def visit_Subscript(self, node: Subscript) -> AST:
295+
if isinstance(node.slice, Tuple):
296+
self.in_subscript = True
297+
try:
298+
return self.generic_visit(node)
299+
finally:
300+
self.in_subscript = False

0 commit comments

Comments
 (0)