Skip to content

Commit 9c8ffe0

Browse files
committed
disallow default_factory for dataclasses without __init__
1 parent 786cac0 commit 9c8ffe0

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

Lib/dataclasses.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,10 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
10091009
# Otherwise it's a field of some type.
10101010
cls_fields.append(_get_field(cls, name, type, kw_only))
10111011

1012+
# Test whether '__init__' is to be auto-generated or if
1013+
# it is provided explicitly by the user.
1014+
has_init_method = init or '__init__' in cls.__dict__
1015+
10121016
for f in cls_fields:
10131017
fields[f.name] = f
10141018

@@ -1018,6 +1022,15 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
10181022
# sees a real default value, not a Field.
10191023
if isinstance(getattr(cls, f.name, None), Field):
10201024
if f.default is MISSING:
1025+
# https://github.com/python/cpython/issues/89529
1026+
if f.default_factory is not MISSING and not has_init_method:
1027+
raise ValueError(
1028+
f'specifying default_factory for {f.name!r}'
1029+
f' requires the @dataclass decorator to be'
1030+
f' called with init=True or to implement'
1031+
f' an __init__ method'
1032+
)
1033+
10211034
# If there's no default, delete the class attribute.
10221035
# This happens if we specify field(repr=False), for
10231036
# example (that is, we specified a field object, but

Lib/test/test_dataclasses/__init__.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pickle
1010
import inspect
1111
import builtins
12+
import re
1213
import types
1314
import weakref
1415
import traceback
@@ -18,6 +19,7 @@
1819
from typing import get_type_hints
1920
from collections import deque, OrderedDict, namedtuple, defaultdict
2021
from functools import total_ordering
22+
from itertools import product
2123

2224
import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
2325
import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
@@ -1411,6 +1413,61 @@ class C:
14111413
C().x
14121414
self.assertEqual(factory.call_count, 2)
14131415

1416+
def test_default_factory_with_no_init_method(self):
1417+
# See https://github.com/python/cpython/issues/89529.
1418+
1419+
@dataclass
1420+
class BaseWithInit:
1421+
x: list
1422+
1423+
@dataclass(slots=True)
1424+
class BaseWithSlots:
1425+
x: list
1426+
1427+
@dataclass(init=False)
1428+
class BaseWithOutInit:
1429+
x: list
1430+
1431+
@dataclass(init=False, slots=True)
1432+
class BaseWithOutInitWithSlots:
1433+
x: list
1434+
1435+
err = re.escape(
1436+
"specifying default_factory for 'x' requires the "
1437+
"@dataclass decorator to be called with init=True "
1438+
"or to implement an __init__ method"
1439+
)
1440+
1441+
for base_class, slots, field_init in product(
1442+
(object, BaseWithInit, BaseWithSlots,
1443+
BaseWithOutInit, BaseWithOutInitWithSlots),
1444+
(True, False),
1445+
(True, False),
1446+
):
1447+
with self.subTest('generated __init__', base_class=base_class,
1448+
init=True, slots=slots, field_init=field_init):
1449+
@dataclass(init=True, slots=slots)
1450+
class C(base_class):
1451+
x: list = field(init=field_init, default_factory=list)
1452+
self.assertListEqual(C().x, [])
1453+
1454+
with self.subTest('user-defined __init__', base_class=base_class,
1455+
init=False, slots=slots, field_init=field_init):
1456+
@dataclass(init=False, slots=slots)
1457+
class C(base_class):
1458+
x: list = field(init=field_init, default_factory=list)
1459+
def __init__(self, *a, **kw):
1460+
# deliberately use something else
1461+
self.x = 'hello'
1462+
self.assertEqual(C().x, 'hello')
1463+
1464+
with self.subTest('no generated __init__', base_class=base_class,
1465+
init=False, slots=slots, field_init=field_init):
1466+
with self.assertRaisesRegex(ValueError, err):
1467+
@dataclass(init=False, slots=slots)
1468+
class C(base_class):
1469+
x: list = field(init=field_init, default_factory=list)
1470+
14141471
def test_default_factory_not_called_if_value_given(self):
14151472
# We need a factory that we can test if it's been called.
14161473
factory = Mock()

0 commit comments

Comments
 (0)