Skip to content
23 changes: 20 additions & 3 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,25 @@ def _field_assign(frozen, name, value, self_name):
return f'{self_name}.{name}={value}'


def _field_init(f, frozen, globals, self_name, slots):
def _field_init(f, frozen, globals, self_name, slots, module):
# Return the text of the line in the body of __init__ that will
# initialize this field.

if f.init and isinstance(f.type, str):
from typing import ForwardRef # `typing` is a heavy import
# We need to resolve this string type into a real `ForwardRef` object,
# because otherwise we might end up with unsolvable annotations.
# For example:
# def __init__(self, d: collections.OrderedDict) -> None:
# We won't be able to resolve `collections.OrderedDict`
# with wrong `module` param, when placed in a different module. #45524
try:
f.type = ForwardRef(f.type, module=module, is_class=True)
except SyntaxError:
# We don't want to fail class creation
# when `ForwardRef` cannot be constructed.
pass

default_name = f'_dflt_{f.name}'
if f.default_factory is not MISSING:
if f.init:
Expand Down Expand Up @@ -527,7 +542,7 @@ def _init_param(f):


def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
self_name, globals, slots):
self_name, globals, slots, module):
# fields contains both real fields and InitVar pseudo-fields.

# Make sure we don't have fields without defaults following fields
Expand All @@ -554,7 +569,7 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,

body_lines = []
for f in fields:
line = _field_init(f, frozen, locals, self_name, slots)
line = _field_init(f, frozen, locals, self_name, slots, module)
# line is None means that this field doesn't require
# initialization (it's a pseudo-field). Just skip it.
if line:
Expand Down Expand Up @@ -906,6 +921,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
# we're iterating over them, see if any are frozen.
any_frozen_base = False
has_dataclass_bases = False

for b in cls.__mro__[-1:0:-1]:
# Only process classes that have been processed by our
# decorator. That is, they have a _FIELDS attribute.
Expand Down Expand Up @@ -1034,6 +1050,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
else 'self',
globals,
slots,
cls.__module__,
))

# Get the fields as a list, and include only real fields. This is
Expand Down
6 changes: 6 additions & 0 deletions Lib/test/dataclass_textanno.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,9 @@ class Foo:
@dataclasses.dataclass
class Bar:
foo: Foo


@dataclasses.dataclass(init=False)
class WithFutureInit(Bar):
def __init__(self, foo: Foo) -> None:
pass
28 changes: 28 additions & 0 deletions Lib/test/dataclass_textanno2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

import dataclasses
from test import dataclass_textanno # We need to be sure that `Foo` is not in scope


class Custom:
pass


@dataclasses.dataclass
class Child(dataclass_textanno.Bar):
custom: Custom


class Foo: # matching name with `dataclass_testanno.Foo`
pass


@dataclasses.dataclass
class WithMatchingNameOverride(dataclass_textanno.Bar):
foo: Foo # Existing `foo` annotation should be overridden


@dataclasses.dataclass(init=False)
class WithFutureInit(Child):
def __init__(self, foo: dataclass_textanno.Foo, custom: Custom) -> None:
pass
124 changes: 124 additions & 0 deletions Lib/test/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3128,6 +3128,130 @@ def test_text_annotations(self):
{'foo': dataclass_textanno.Foo,
'return': type(None)})

def test_dataclass_from_another_module(self):
# see bpo-45524
from test import dataclass_textanno
from dataclasses import dataclass

@dataclass
class Default(dataclass_textanno.Bar):
pass

@dataclass(init=False)
class WithInitFalse(dataclass_textanno.Bar):
pass

@dataclass(init=False)
class CustomInit(dataclass_textanno.Bar):
def __init__(self, foo: dataclass_textanno.Foo) -> None:
pass

@dataclass
class FutureInitChild(dataclass_textanno.WithFutureInit):
pass

classes = [
Default,
WithInitFalse,
CustomInit,
dataclass_textanno.WithFutureInit,
FutureInitChild,
]
for klass in classes:
with self.subTest(klass=klass):
self.assertEqual(
get_type_hints(klass),
{'foo': dataclass_textanno.Foo},
)
self.assertEqual(get_type_hints(klass.__new__), {})
self.assertEqual(
get_type_hints(klass.__init__),
{'foo': dataclass_textanno.Foo, 'return': type(None)},
)

def test_dataclass_from_proxy_module(self):
# see bpo-45524
from test import dataclass_textanno, dataclass_textanno2
from dataclasses import dataclass

@dataclass
class Default(dataclass_textanno2.Child):
pass

@dataclass(init=False)
class WithInitFalse(dataclass_textanno2.Child):
pass

@dataclass(init=False)
class CustomInit(dataclass_textanno2.Child):
def __init__(
self,
foo: dataclass_textanno.Foo,
custom: dataclass_textanno2.Custom,
) -> None:
pass

@dataclass
class FutureInitChild(dataclass_textanno2.WithFutureInit):
pass

classes = [
Default,
WithInitFalse,
CustomInit,
dataclass_textanno2.WithFutureInit,
FutureInitChild,
]
for klass in classes:
with self.subTest(klass=klass):
self.assertEqual(
get_type_hints(klass),
{
'foo': dataclass_textanno.Foo,
'custom': dataclass_textanno2.Custom,
},
)
self.assertEqual(get_type_hints(klass.__new__), {})
self.assertEqual(
get_type_hints(klass.__init__),
{
'foo': dataclass_textanno.Foo,
'custom': dataclass_textanno2.Custom,
'return': type(None),
},
)

def test_dataclass_proxy_modules_matching_name_override(self):
# see bpo-45524
from test import dataclass_textanno2
from dataclasses import dataclass

@dataclass
class Default(dataclass_textanno2.WithMatchingNameOverride):
pass

classes = [
Default,
dataclass_textanno2.WithMatchingNameOverride
]
for klass in classes:
with self.subTest(klass=klass):
self.assertEqual(
get_type_hints(klass),
{
'foo': dataclass_textanno2.Foo,
},
)
self.assertEqual(get_type_hints(klass.__new__), {})
self.assertEqual(
get_type_hints(klass.__init__),
{
'foo': dataclass_textanno2.Foo,
'return': type(None),
},
)



class TestMakeDataclass(unittest.TestCase):
def test_simple(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix ``get_type_hints()`` failure on ``@dataclass`` hierarchies in different
modules.