Skip to content

Commit 8e02cdd

Browse files
authored
fix-dataclass-defaults can handle populated defaults (#479)
1 parent d44dc34 commit 8e02cdd

File tree

4 files changed

+113
-13
lines changed

4 files changed

+113
-13
lines changed

src/codemodder/codemods/base_visitor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,15 @@ def node_position(self, node):
6666
def lineno_for_node(self, node):
6767
return self.node_position(node).start.line
6868

69+
def code(self, node: cst.CSTNode) -> str:
70+
"""
71+
Only a cst.Module node has a `code` attribute which converts the node
72+
back to the original code as a str. To get the code for any node,
73+
the suggested approach is to wrap this node in a `cst.Module` node.
74+
"""
75+
module = cst.Module(body=[cst.SimpleStatementLine(body=[cst.Expr(value=node)])])
76+
return module.code
77+
6978

7079
class BaseTransformer(VisitorBasedCodemodCommand, UtilsMixin):
7180
def __init__(

src/core_codemods/fix_dataclass_defaults.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,23 +36,44 @@ def leave_AnnAssign(
3636
return updated_node
3737

3838
match original_node.value:
39-
# TODO: add support for populated elements
4039
case cst.List(elements=[]) | cst.Dict(elements=[]) | cst.Tuple(elements=[]):
41-
self.add_needed_import("dataclasses", "field")
42-
self.report_change(original_node)
43-
return updated_node.with_changes(
44-
value=cst.parse_expression(
45-
f"field(default_factory={ type(original_node.value).__name__.lower()})"
46-
)
40+
return self.field_with_default_factory(original_node, updated_node)
41+
case (
42+
cst.List(elements=[_, *_])
43+
| cst.Dict(elements=[_, *_])
44+
| cst.Tuple(elements=[_, *_])
45+
):
46+
return self.field_with_default_factory(
47+
original_node, updated_node, empty=False
4748
)
4849
case cst.Call(func=cst.Name(value="set"), args=[]):
49-
self.add_needed_import("dataclasses", "field")
50-
self.report_change(original_node)
51-
return updated_node.with_changes(
52-
value=cst.parse_expression("field(default_factory=set)")
50+
return self.field_with_default_factory(original_node, updated_node)
51+
case cst.Call(func=cst.Name(value="set"), args=[_, *_]):
52+
return self.field_with_default_factory(
53+
original_node, updated_node, empty=False
5354
)
5455
return updated_node
5556

57+
def field_with_default_factory(
58+
self,
59+
original_node: cst.List | cst.Tuple | cst.Dict | cst.Call,
60+
updated_node: cst.List | cst.Tuple | cst.Dict | cst.Call,
61+
empty=True,
62+
):
63+
self.add_needed_import("dataclasses", "field")
64+
self.report_change(original_node)
65+
value = original_node.value
66+
if empty:
67+
expr = (
68+
"field(default_factory=set)"
69+
if isinstance(value, cst.Call)
70+
else f"field(default_factory={type(value).__name__.lower()})"
71+
)
72+
return updated_node.with_changes(value=cst.parse_expression(expr))
73+
74+
expr = f"field(default_factory=lambda: {self.code(value).strip()})"
75+
return updated_node.with_changes(value=cst.parse_expression(expr))
76+
5677
def _has_dataclass_decorator(self, node: cst.ClassDef) -> bool:
5778
for decorator in node.decorators:
5879
if self.find_base_name(decorator.decorator) == "dataclasses.dataclass":

tests/codemods/test_base_visitor.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,34 @@ def leave_FunctionDef(
5555
return updated_node
5656

5757

58+
class AssertNodeCode(BaseTransformer):
59+
def __init__(
60+
self,
61+
context,
62+
results,
63+
expected_code,
64+
node_type,
65+
line_exclude=None,
66+
line_include=None,
67+
):
68+
BaseTransformer.__init__(
69+
self, context, results, line_include or [], line_exclude or []
70+
)
71+
self.expected_code = expected_code
72+
self.node_type = node_type
73+
self.checked = False
74+
75+
def on_leave(
76+
self, original_node: cst.CSTNode, updated_node: cst.CSTNode
77+
) -> cst.CSTNode:
78+
match original_node:
79+
case self.node_type():
80+
assert self.code(original_node).strip() == self.expected_code.strip()
81+
self.checked = True
82+
83+
return updated_node
84+
85+
5886
class TestBaseVisitor:
5987
def run_and_assert(self, input_code, expected, line_exclude, line_include):
6088
input_tree = cst.parse_module(input_code)
@@ -125,3 +153,35 @@ def hello(one, *args, **kwargs):
125153
start=CodePosition(line=2, column=0), end=CodePosition(line=2, column=31)
126154
)
127155
self.run_and_assert(input_code, expected_pos)
156+
157+
158+
class TestCodeForNode:
159+
def run_and_assert(self, input_code, expected_code, node_type):
160+
input_tree = cst.parse_module(dedent(input_code))
161+
command_instance = AssertNodeCode(
162+
CodemodContext(), defaultdict(list), expected_code, node_type
163+
)
164+
command_instance.transform_module(input_tree)
165+
if not command_instance.checked:
166+
raise Exception(f"Input code does not contain a {node_type}")
167+
168+
def test_annAssign(self):
169+
input_code = "phones: list = [1, 2, 3]"
170+
self.run_and_assert(input_code, input_code, cst.AnnAssign)
171+
172+
def test_list(self):
173+
input_code = "[1, 'two', Exception]"
174+
self.run_and_assert(input_code, input_code, cst.List)
175+
176+
def test_dict(self):
177+
self.run_and_assert('dict = {"friend": "one"}', '{"friend": "one"}', cst.Dict)
178+
179+
def test_module(self):
180+
input_code = dedent(
181+
"""
182+
# comment
183+
var = 1
184+
print(var)
185+
"""
186+
)
187+
self.run_and_assert(input_code, input_code, cst.Module)

tests/codemods/test_fix_dataclass_defaults.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ class Test:
5858
self.run_and_assert(tmpdir, input_code, expected, num_changes=3)
5959

6060
def test_populated_defaults(self, tmpdir):
61-
# TODO: support later using lambda.
6261
input_code = """
6362
import dataclasses
6463
@@ -69,7 +68,18 @@ class Test:
6968
friends: dict = {"friend": "one"}
7069
family: set = set((1, 2, 3))
7170
"""
72-
self.run_and_assert(tmpdir, input_code, input_code)
71+
expected = """
72+
import dataclasses
73+
from dataclasses import field
74+
75+
@dataclasses.dataclass
76+
class Test:
77+
name: str = ""
78+
phones: list = field(default_factory=lambda: [1, 2, 3])
79+
friends: dict = field(default_factory=lambda: {"friend": "one"})
80+
family: set = field(default_factory=lambda: set((1, 2, 3)))
81+
"""
82+
self.run_and_assert(tmpdir, input_code, expected, num_changes=3)
7383

7484
@pytest.mark.parametrize(
7585
"code",

0 commit comments

Comments
 (0)