Skip to content

Commit eea30a9

Browse files
feat: update dict-to-schema codemod to be generic
Co-Authored-By: [email protected] <[email protected]>
1 parent dd95493 commit eea30a9

File tree

1 file changed

+206
-0
lines changed

1 file changed

+206
-0
lines changed
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
import codegen
2+
from codegen import Codebase
3+
from typing import Dict, Any, List, Union, get_type_hints
4+
from dataclasses import dataclass
5+
import sys
6+
import ast
7+
8+
9+
def infer_type(value) -> str:
    """Return the type-hint string that describes a runtime value.

    Args:
        value: Any Python object.

    Returns:
        One of ``"bool"``, ``"int"``, ``"float"``, ``"str"``, ``"List[Any]"``
        or ``"Dict[str, Any]"``; ``"Any"`` when none of those match.
    """
    # Order matters: bool is a subclass of int, so it has to be tested first.
    hint_table = (
        (bool, "bool"),
        (int, "int"),
        (float, "float"),
        (str, "str"),
        (list, "List[Any]"),
        (dict, "Dict[str, Any]"),
    )
    for python_type, hint in hint_table:
        if isinstance(value, python_type):
            return hint
    return "Any"
24+
25+
26+
def print_progress(current: int, total: int, width: int = 40) -> None:
    """Draw a single-line progress bar on stderr.

    The bar is redrawn in place (the line starts with a carriage return) and
    the line is terminated once ``current`` reaches ``total``.

    Args:
        current: Number of items completed so far.
        total: Total number of items. When it is zero or negative there is
            nothing to report, so nothing is printed.
        width: Width of the bar in characters.
    """
    if total <= 0:
        # Avoids a ZeroDivisionError for an empty workload.
        return

    # Clamp only for drawing so the bar never exceeds ``width`` or 100%;
    # the raw counter is still shown in the "(current/total)" suffix.
    done = min(max(current, 0), total)
    filled = width * done // total
    bar = "█" * filled + "░" * (width - filled)
    percent = 100 * done // total
    print(f"\r[{bar}] {percent}% ({current}/{total})", end="", file=sys.stderr)
    if current >= total:
        # Finish the line so later output starts on a fresh one.
        print(file=sys.stderr)
33+
34+
35+
def _assignment_rhs(source: str) -> Union[str, None]:
    """Return the stripped text right of the first ``=`` in *source*, or None."""
    _, separator, rhs = source.partition("=")
    return rhs.strip() if separator else None


def _parse_expression(text: str) -> Union[ast.expr, None]:
    """Parse *text* as one Python expression; return None when it is not valid."""
    try:
        return ast.parse(text, mode="eval").body
    except (SyntaxError, ValueError):
        return None


def _trailing_comments(text: str) -> dict:
    """Map 1-based line numbers of *text* to the comment found on that line."""
    import io
    import tokenize

    comments = {}
    try:
        for token in tokenize.generate_tokens(io.StringIO(text).readline):
            if token.type == tokenize.COMMENT:
                comments[token.start[0]] = token.string[1:].strip()
    except (tokenize.TokenError, SyntaxError):
        # Comments are cosmetic; keep whatever was collected before the error.
        pass
    return comments


def _infer_type_from_node(node: ast.expr) -> str:
    """Return a type-hint string for the value expression *node*."""
    if isinstance(node, ast.JoinedStr):
        # f-strings are not literals but always evaluate to str.
        return "str"
    try:
        value = ast.literal_eval(node)
    except (ValueError, TypeError):
        # Names, calls and other non-literal expressions.
        return "Any"
    return infer_type(value)


def _dict_fields(dict_node: ast.Dict, comments: dict) -> list:
    """Return ``(key, type_hint, comment)`` for every entry of *dict_node*.

    An empty list means the dictionary cannot become a dataclass: it is empty,
    uses ``**`` unpacking, or has a key that is not a string identifier.
    """
    import keyword

    # When several entries share a line, the comment belongs to the last one.
    last_on_line = {
        value.end_lineno: index for index, value in enumerate(dict_node.values)
    }
    fields = []
    for index, (key_node, value_node) in enumerate(
        zip(dict_node.keys, dict_node.values)
    ):
        key = key_node.value if isinstance(key_node, ast.Constant) else None
        if not isinstance(key, str) or not key.isidentifier() or keyword.iskeyword(key):
            return []
        line = value_node.end_lineno
        comment = comments.get(line) if last_on_line[line] == index else None
        fields.append((key, _infer_type_from_node(value_node), comment))
    return fields


def _render_model(class_name: str, fields: list) -> str:
    """Render dataclass source for *fields*; every field defaults to None."""
    lines = []
    for key, type_hint, comment in fields:
        line = f"    {key}: {type_hint} | None = None"
        if comment:
            line += f"  # {comment}"
        lines.append(line)
    return f"@dataclass\nclass {class_name}:\n" + "\n".join(lines)


def _dict_schema(source: str, name: str) -> tuple:
    """Build ``(model_def, init_code)`` for an assignment of a dict literal.

    Returns ``(None, None)`` when the right-hand side is not a convertible
    dictionary literal.
    """
    rhs = _assignment_rhs(source)
    if not rhs or not (rhs.startswith("{") and rhs.endswith("}")):
        return None, None
    node = _parse_expression(rhs)
    if not isinstance(node, ast.Dict):
        # Set literals and dict comprehensions also start with "{".
        return None, None
    fields = _dict_fields(node, _trailing_comments(rhs))
    class_name = name.title()
    if not fields or class_name == name:
        # An identical name would rebind the class to its own instance.
        return None, None
    return _render_model(class_name, fields), f"{name} = {class_name}(**{rhs})"


def _list_schema(source: str, name: str) -> tuple:
    """Build ``(model_def, init_code)`` for an assignment of a list of dicts.

    The model contains the union of the keys of all items (first occurrence
    wins), so every item can be unpacked into it. Returns ``(None, None)``
    when the right-hand side is not a non-empty list of convertible dict
    literals.
    """
    rhs = _assignment_rhs(source)
    if not rhs or not (rhs.startswith("[") and rhs.endswith("]")):
        return None, None
    node = _parse_expression(rhs)
    if not isinstance(node, ast.List) or not node.elts:
        return None, None
    if not all(isinstance(element, ast.Dict) for element in node.elts):
        return None, None

    comments = _trailing_comments(rhs)
    merged = {}
    for element in node.elts:
        fields = _dict_fields(element, comments)
        if not fields:
            return None, None
        for field in fields:
            merged.setdefault(field[0], field)

    # Drop one plural "s" only; rstrip("s") would turn "address" into "addre".
    singular = name[:-1] if len(name) > 1 and name.endswith("s") else name
    class_name = singular.title()
    if class_name == name:
        return None, None
    model_def = _render_model(class_name, list(merged.values()))
    return model_def, f"{name} = [{class_name}(**item) for item in {rhs}]"


def _print_conversion(heading: str, original: str, model_def: str, updated: str) -> None:
    """Print the before/after report for one converted assignment."""
    print("\n" + "═" * 60)
    print(heading)
    print("─" * 60)
    print("\033[1;34m📝 Original code:\033[0m")
    print(f" {original}")
    print("\n\033[1;35m✨ Generated schema:\033[0m")
    print(" " + model_def.replace("\n", "\n "))
    print("\n\033[1;32m✅ Updated code:\033[0m")
    print(f" {updated}")
    print("═" * 60)


@codegen.function('dict-to-schema')
def run(codebase: Codebase):
    """Convert dictionary literals to dataclasses with proper type hints.

    For every Python file in *codebase*, module-level variables and class
    attributes assigned a dict literal (or, for module-level variables, a list
    of dict literals) get a generated ``@dataclass`` and the assignment is
    rewritten to construct it. Progress and a final summary are printed.

    Args:
        codebase: The Codegen codebase to transform in place.
    """
    files_modified = 0
    models_created = 0

    # Process all Python files in the codebase
    python_files = [f for f in codebase.files if f.path.endswith('.py')]
    total_files = len(python_files)
    print("\n\033[1;36m📁 Scanning files for dictionary literals...\033[0m")
    print(f"Found {total_files} Python files to process")

    for i, file in enumerate(python_files, 1):
        file_modified = False

        print_progress(i, total_files)
        print(f"\n\033[1;34m🔍 Processing: {file.path}\033[0m")

        for global_var in file.global_vars:
            # Best effort: one unconvertible variable must not abort the run.
            try:
                rhs = _assignment_rhs(global_var.source)
                if rhs is None:
                    continue
                # Dispatch on what the value starts with; a substring test for
                # braces would send lists of dicts down the dict path.
                if rhs.startswith("{"):
                    model_def, init_code = _dict_schema(global_var.source, global_var.name)
                    heading = f"\033[1;32m🔄 Converting global variable '{global_var.name}' to schema\033[0m"
                elif rhs.startswith("["):
                    model_def, init_code = _list_schema(global_var.source, global_var.name)
                    heading = f"\033[1;32m🔄 Converting list items in '{global_var.name}' to schema\033[0m"
                else:
                    continue
                if not model_def:
                    continue

                _print_conversion(
                    heading,
                    f"{global_var.name} = {global_var.value.source}",
                    model_def,
                    init_code,
                )

                global_var.file.add_symbol_from_source(model_def + "\n")
                # NOTE(review): assumes edit() replaces the whole assignment
                # statement, so the full "name = ..." text is passed on every
                # path - confirm against the Codegen SDK.
                global_var.edit(init_code)
                models_created += 1
                file_modified = True
            except Exception as e:
                print(f"\n❌ Error processing global variable '{global_var.name}':")
                print(f" {str(e)}")
                print(" Skipping this variable and continuing...\n")

        # Process class attributes
        for cls in file.classes:
            for attr in cls.attributes:
                try:
                    model_def, init_code = _dict_schema(attr.source, attr.name)
                    if not model_def:
                        continue

                    cls.insert_before(model_def + "\n")
                    # NOTE(review): same assumption about edit() as for
                    # global variables above - confirm against the SDK.
                    attr.edit(init_code)
                    models_created += 1
                    file_modified = True
                except Exception as e:
                    print(f"\n❌ Error processing class attribute '{attr.name}':")
                    print(f" {str(e)}")
                    print(" Skipping this attribute and continuing...\n")

        if file_modified:
            # Added after both passes so attribute-only conversions get the
            # imports too.
            print(f" ➕ Adding dataclass imports to {file.path}")
            file.add_import_from_import_string("from dataclasses import dataclass")
            file.add_import_from_import_string("from typing import Any, Dict, List, Optional")
            print(f" ✅ Successfully modified {file.path}")
            files_modified += 1

    print("\n" + "═" * 60)
    print("\033[1;35m📊 Summary of Changes\033[0m")
    print("═" * 60)
    print(f"\033[1;32m✨ Files modified: {files_modified}\033[0m")
    print(f"\033[1;32m🔄 Schemas created: {models_created}\033[0m")
    print("═" * 60)

0 commit comments

Comments
 (0)