Skip to content

Commit a651fbf

Browse files
authored
Merge pull request #48 from codeflash-ai/fix-helper-context-for-init
fixed a bug where helper functions called in the __init__ was not bei…
2 parents 4ce6379 + 26e5923 commit a651fbf

File tree

2 files changed

+65
-5
lines changed

2 files changed

+65
-5
lines changed

codeflash/context/code_context_extractor.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,25 @@ def get_code_optimization_context(
3030
) -> CodeOptimizationContext:
3131
# Get FunctionSource representation of helpers of FTO
3232
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi({function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path)
33+
34+
# Add function to optimize into helpers of FTO dict, as they'll be processed together
35+
fto_as_function_source = get_function_to_optimize_as_function_source(function_to_optimize, project_root_path)
36+
helpers_of_fto_dict[function_to_optimize.file_path].add(fto_as_function_source)
37+
38+
# Format data to search for helpers of helpers using get_function_sources_from_jedi
3339
helpers_of_fto_qualified_names_dict = {
3440
file_path: {source.qualified_name for source in sources}
3541
for file_path, sources in helpers_of_fto_dict.items()
3642
}
3743

44+
# __init__ functions are automatically considered as helpers of FTO, so we add them to the dict (regardless of whether they exist)
45+
# This helps us to search for helpers of __init__ functions of classes that contain helpers of FTO
46+
for qualified_names in helpers_of_fto_qualified_names_dict.values():
47+
qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if '.' in qn})
48+
3849
# Get FunctionSource representation of helpers of helpers of FTO
3950
helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi(helpers_of_fto_qualified_names_dict, project_root_path)
4051

41-
# Add function to optimize into helpers of FTO dict, as they'll be processed together
42-
fto_as_function_source = get_function_to_optimize_as_function_source(function_to_optimize, project_root_path)
43-
helpers_of_fto_dict[function_to_optimize.file_path].add(fto_as_function_source)
44-
4552
# Extract code context for optimization
4653
final_read_writable_code = extract_code_string_context_from_files(helpers_of_fto_dict,{}, project_root_path, remove_docstrings=False, code_context_type=CodeContextType.READ_WRITABLE).code
4754
read_only_code_markdown = extract_code_markdown_context_from_files(

tests/test_code_context_extractor.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def main_method(self):
9999
assert read_write_context.strip() == expected_read_write_context.strip()
100100
assert read_only_context.strip() == expected_read_only_context.strip()
101101

102-
103102
def test_class_method_dependencies() -> None:
104103
file_path = Path(__file__).resolve()
105104

@@ -1260,3 +1259,57 @@ def __repr__(self) -> str:
12601259

12611260
assert read_write_context.strip() == expected_read_write_context.strip()
12621261
assert read_only_context.strip() == expected_read_only_context.strip()
1262+
1263+
def test_indirect_init_helper() -> None:
1264+
code = """
1265+
class MyClass:
1266+
def __init__(self):
1267+
self.x = 1
1268+
self.y = outside_method()
1269+
def target_method(self):
1270+
return self.x + self.y
1271+
1272+
def outside_method():
1273+
return 1
1274+
"""
1275+
with tempfile.NamedTemporaryFile(mode="w") as f:
1276+
f.write(code)
1277+
f.flush()
1278+
file_path = Path(f.name).resolve()
1279+
opt = Optimizer(
1280+
Namespace(
1281+
project_root=file_path.parent.resolve(),
1282+
disable_telemetry=True,
1283+
tests_root="tests",
1284+
test_framework="pytest",
1285+
pytest_cmd="pytest",
1286+
experiment_id=None,
1287+
test_project_root=Path().resolve(),
1288+
)
1289+
)
1290+
function_to_optimize = FunctionToOptimize(
1291+
function_name="target_method",
1292+
file_path=file_path,
1293+
parents=[FunctionParent(name="MyClass", type="ClassDef")],
1294+
starting_line=None,
1295+
ending_line=None,
1296+
)
1297+
1298+
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
1299+
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
1300+
expected_read_write_context = """
1301+
class MyClass:
1302+
def __init__(self):
1303+
self.x = 1
1304+
self.y = outside_method()
1305+
def target_method(self):
1306+
return self.x + self.y
1307+
"""
1308+
expected_read_only_context = f"""
1309+
```python:{file_path.relative_to(opt.args.project_root)}
1310+
def outside_method():
1311+
return 1
1312+
```
1313+
"""
1314+
assert read_write_context.strip() == expected_read_write_context.strip()
1315+
assert read_only_context.strip() == expected_read_only_context.strip()

0 commit comments

Comments
 (0)