Skip to content

Commit 97a3496

Browse files
committed
Manually resolve string types into ForwardRef in __init__ of a dataclass
1 parent 720924e commit 97a3496

File tree

3 files changed

+60
-16
lines changed

3 files changed

+60
-16
lines changed

Lib/dataclasses.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import abc
1010
import _thread
1111
from types import FunctionType, GenericAlias
12+
from typing import ForwardRef
1213

1314

1415
__all__ = ['dataclass',
@@ -447,10 +448,24 @@ def _field_assign(frozen, name, value, self_name):
447448
return f'{self_name}.{name}={value}'
448449

449450

450-
def _field_init(f, frozen, globals, self_name, slots):
451+
def _field_init(f, frozen, globals, self_name, slots, module):
451452
# Return the text of the line in the body of __init__ that will
452453
# initialize this field.
453454

455+
if f.init and isinstance(f.type, str):
456+
# We need to resolve this string type into a real `ForwardRef` object,
457+
# because otherwise we might end up with unsolvable annotations.
458+
# For example:
459+
# def __init__(self, d: collections.OrderedDict) -> None:
460+
# We won't be able to resolve `collections.OrderedDict`
461+
# with wrong `module` param, when placed in a different module. #45524
462+
try:
463+
f.type = ForwardRef(f.type, module=module, is_class=True)
464+
except SyntaxError:
465+
# We don't want to fail class creation
466+
# when `ForwardRef` cannot be constructed.
467+
pass
468+
454469
default_name = f'_dflt_{f.name}'
455470
if f.default_factory is not MISSING:
456471
if f.init:
@@ -527,7 +542,7 @@ def _init_param(f):
527542

528543

529544
def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
530-
self_name, globals, slots):
545+
self_name, globals, slots, module):
531546
# fields contains both real fields and InitVar pseudo-fields.
532547

533548
# Make sure we don't have fields without defaults following fields
@@ -554,7 +569,7 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
554569

555570
body_lines = []
556571
for f in fields:
557-
line = _field_init(f, frozen, locals, self_name, slots)
572+
line = _field_init(f, frozen, locals, self_name, slots, module)
558573
# line is None means that this field doesn't require
559574
# initialization (it's a pseudo-field). Just skip it.
560575
if line:
@@ -906,7 +921,6 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
906921
# we're iterating over them, see if any are frozen.
907922
any_frozen_base = False
908923
has_dataclass_bases = False
909-
init_globals = dict(globals)
910924

911925
for b in cls.__mro__[-1:0:-1]:
912926
# Only process classes that have been processed by our
@@ -918,17 +932,6 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
918932
fields[f.name] = f
919933
if getattr(b, _PARAMS).frozen:
920934
any_frozen_base = True
921-
if has_dataclass_bases:
922-
# If dataclass has other dataclass as a base type,
923-
# it might have existing `__init__` method.
924-
# We need to change its `globals`, because otherwise
925-
# we might end up with unsolvable annotations. For example:
926-
# `def __init__(self, d: collections.OrderedDict) -> None:`
927-
# We won't be able to resolve `collections.OrderedDict`
928-
# with wrong `globals`, when placed in a different module. #45524
929-
super_init = getattr(b, '__init__', None)
930-
if super_init is not None:
931-
init_globals.update(getattr(super_init, '__globals__', {}))
932935

933936
# Annotations that are defined in this class (not in base
934937
# classes). If __annotations__ isn't present, then this class
@@ -1045,8 +1048,9 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
10451048
# if possible.
10461049
'__dataclass_self__' if 'self' in fields
10471050
else 'self',
1051+
globals,
10481052
slots,
1049-
init_globals,
1053+
cls.__module__,
10501054
))
10511055

10521056
# Get the fields as a list, and include only real fields. This is

Lib/test/dataclass_textanno2.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ class Child(dataclass_textanno.Bar):
1313
custom: Custom
1414

1515

16+
class Foo: # matching name with `dataclass_testanno.Foo`
17+
pass
18+
19+
20+
@dataclasses.dataclass
21+
class WithMatchinNameOverride(dataclass_textanno.Bar):
22+
foo: Foo # we override existing `foo: Foo` with
23+
24+
1625
@dataclasses.dataclass(init=False)
1726
class WithFutureInit(Child):
1827
def __init__(self, foo: dataclass_textanno.Foo, custom: Custom) -> None:

Lib/test/test_dataclasses.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3221,6 +3221,37 @@ class FutureInitChild(dataclass_textanno2.WithFutureInit):
32213221
},
32223222
)
32233223

3224+
def test_dataclass_proxy_modules_matching_name_override(self):
3225+
# see bpo-45524
3226+
from test import dataclass_textanno2
3227+
from dataclasses import dataclass
3228+
3229+
@dataclass
3230+
class Default(dataclass_textanno2.WithMatchinNameOverride):
3231+
pass
3232+
3233+
classes = [
3234+
Default,
3235+
dataclass_textanno2.WithMatchinNameOverride
3236+
]
3237+
for klass in classes:
3238+
with self.subTest(klass=klass):
3239+
self.assertEqual(
3240+
get_type_hints(klass),
3241+
{
3242+
'foo': dataclass_textanno2.Foo,
3243+
},
3244+
)
3245+
self.assertEqual(get_type_hints(klass.__new__), {})
3246+
self.assertEqual(
3247+
get_type_hints(klass.__init__),
3248+
{
3249+
'foo': dataclass_textanno2.Foo,
3250+
'return': type(None),
3251+
},
3252+
)
3253+
3254+
32243255

32253256
class TestMakeDataclass(unittest.TestCase):
32263257
def test_simple(self):

0 commit comments

Comments
 (0)