Skip to content

Commit 4d53330

Browse files
committed
Merge branch 'jedi_ctx_fix' into codeflash-trace-decorator
2 parents a8d4fda + 64f7927 commit 4d53330

File tree

6 files changed

+14
-5
lines changed

6 files changed

+14
-5
lines changed

code_to_optimize/code_directories/simple_tracer_e2e/workload.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from concurrent.futures import ThreadPoolExecutor
2+
3+
24
def funcA(number):
35
k = 0
46
for i in range(number * 100):
@@ -9,6 +11,7 @@ def funcA(number):
911
# Use a generator expression directly in join for more efficiency
1012
return " ".join(str(i) for i in range(number))
1113

14+
1215
def test_threadpool() -> None:
1316
pool = ThreadPoolExecutor(max_workers=3)
1417
args = list(range(10, 31, 10))
@@ -19,4 +22,4 @@ def test_threadpool() -> None:
1922

2023

2124
if __name__ == "__main__":
22-
test_threadpool()
25+
test_threadpool()

codeflash/context/code_context_extractor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ def get_function_to_optimize_as_function_source(
356356
name.type == "function"
357357
and name.full_name
358358
and name.name == function_to_optimize.function_name
359+
and name.full_name.startswith(name.module_name)
359360
and get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name
360361
):
361362
function_source = FunctionSource(
@@ -410,6 +411,7 @@ def get_function_sources_from_jedi(
410411
and definition.full_name
411412
and definition.type == "function"
412413
and not belongs_to_function_qualified(definition, qualified_function_name)
414+
and definition.full_name.startswith(definition.module_name)
413415
# Avoid nested functions or classes. Only class.function is allowed
414416
and len((qualified_name := get_qualified_name(definition.module_name, definition.full_name)).split(".")) <= 2
415417
):

codeflash/optimization/function_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def belongs_to_class(name: Name, class_name: str) -> bool:
3131
def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> bool:
3232
"""Check if the given jedi Name is a direct child of the specified function, matched by qualified function name."""
3333
try:
34-
if get_qualified_name(name.module_name, name.full_name) == qualified_function_name:
34+
if name.full_name.startswith(name.module_name) and get_qualified_name(name.module_name, name.full_name) == qualified_function_name:
3535
# Handles function definition and recursive function calls
3636
return False
3737
if name := name.parent():

codeflash/tracer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,14 +247,18 @@ def tracer_logic(self, frame: FrameType, event: str) -> None:
247247
return
248248
if self.timeout is not None and (time.time() - self.start_time) > self.timeout:
249249
sys.setprofile(None)
250+
threading.setprofile(None)
250251
console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.")
251252
return
252253
code = frame.f_code
254+
253255
file_name = Path(code.co_filename).resolve()
254256
# TODO : It currently doesn't log the last return call from the first function
255257

256258
if code.co_name in self.ignored_functions:
257259
return
260+
if not file_name.is_relative_to(self.project_root):
261+
return
258262
if not file_name.exists():
259263
return
260264
if self.functions and code.co_name not in self.functions:

tests/scripts/end_to_end_test_tracer_replay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def run_test(expected_improvement_pct: int) -> bool:
1010
min_improvement_x=0.1,
1111
expected_unit_tests=1,
1212
coverage_expectations=[
13-
CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[3, 4, 5, 7, 10]),
13+
CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[5, 6, 7, 9, 12]),
1414
],
1515
)
1616
cwd = (

tests/scripts/end_to_end_test_utilities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,8 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p
204204
return False
205205

206206
functions_traced = re.search(r"Traced (\d+) function calls successfully and replay test created at - (.*)$", stdout)
207-
if not functions_traced or int(functions_traced.group(1)) != 5:
208-
logging.error("Expected 5 traced functions")
207+
if not functions_traced or int(functions_traced.group(1)) != 4:
208+
logging.error("Expected 4 traced functions")
209209
return False
210210

211211
replay_test_path = pathlib.Path(functions_traced.group(2))

0 commit comments

Comments
 (0)