Skip to content

Commit 3158f9c

Browse files
committed
end to end test that proves picklepatcher works. example shown is a socket (which is unpickleable) that's used or not used
1 parent 40e416e commit 3158f9c

File tree

13 files changed

+349
-191
lines changed

13 files changed

+349
-191
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
2+
from codeflash.benchmarking.codeflash_trace import codeflash_trace
3+
4+
5+
@codeflash_trace
6+
def bubble_sort_with_unused_socket(data_container):
7+
# Extract the list to sort, leaving the socket untouched
8+
numbers = data_container.get('numbers', []).copy()
9+
10+
return sorted(numbers)
11+
12+
@codeflash_trace
13+
def bubble_sort_with_used_socket(data_container):
14+
# Extract the list to sort, leaving the socket untouched
15+
numbers = data_container.get('numbers', []).copy()
16+
socket = data_container.get('socket')
17+
socket.send("Hello from the optimized function!")
18+
return sorted(numbers)

code_to_optimize/bubble_sort_picklepatch.py renamed to code_to_optimize/bubble_sort_picklepatch_test_used_socket.py

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,6 @@
1-
def bubble_sort_with_unused_socket(data_container):
2-
"""
3-
Performs a bubble sort on a list within the data_container. The data container has the following schema:
4-
- 'numbers' (list): The list to be sorted.
5-
- 'socket' (socket): A socket
6-
7-
Args:
8-
data_container: A dictionary with at least 'numbers' (list) and 'socket' keys
9-
10-
Returns:
11-
list: The sorted list of numbers
12-
"""
13-
# Extract the list to sort, leaving the socket untouched
14-
numbers = data_container.get('numbers', []).copy()
15-
16-
# Classic bubble sort implementation
17-
n = len(numbers)
18-
for i in range(n):
19-
# Flag to optimize by detecting if no swaps occurred
20-
swapped = False
21-
22-
# Last i elements are already in place
23-
for j in range(0, n - i - 1):
24-
# Swap if the element is greater than the next element
25-
if numbers[j] > numbers[j + 1]:
26-
numbers[j], numbers[j + 1] = numbers[j + 1], numbers[j]
27-
swapped = True
28-
29-
# If no swapping occurred in this pass, the list is sorted
30-
if not swapped:
31-
break
32-
33-
return numbers
34-
1+
from codeflash.benchmarking.codeflash_trace import codeflash_trace
352

3+
@codeflash_trace
364
def bubble_sort_with_used_socket(data_container):
375
"""
386
Performs a bubble sort on a list within the data_container. The data container has the following schema:
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import socket
2+
3+
from code_to_optimize.bubble_sort_picklepatch_test_unused_socket import bubble_sort_with_unused_socket
4+
from code_to_optimize.bubble_sort_picklepatch_test_used_socket import bubble_sort_with_used_socket
5+
6+
def test_socket_picklepatch(benchmark):
7+
s1, s2 = socket.socketpair()
8+
data = {
9+
"numbers": list(reversed(range(500))),
10+
"socket": s1
11+
}
12+
benchmark(bubble_sort_with_unused_socket, data)
13+
14+
def test_used_socket_picklepatch(benchmark):
15+
s1, s2 = socket.socketpair()
16+
data = {
17+
"numbers": list(reversed(range(500))),
18+
"socket": s1
19+
}
20+
benchmark(bubble_sort_with_used_socket, data)

code_to_optimize/tests/pytest/test_bubble_sort_picklepatch.py

Lines changed: 0 additions & 34 deletions
This file was deleted.

codeflash/benchmarking/codeflash_trace.py

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
import os
33
import pickle
44
import sqlite3
5-
import sys
65
import threading
76
import time
87
from typing import Callable
98

10-
import dill
9+
from codeflash.picklepatch.pickle_patcher import PicklePatcher
1110

1211

1312
class CodeflashTrace:
@@ -147,34 +146,20 @@ def wrapper(*args, **kwargs):
147146
return result
148147

149148
try:
150-
original_recursion_limit = sys.getrecursionlimit()
151-
sys.setrecursionlimit(10000)
152-
# args = dict(args.items())
153-
# if class_name and func.__name__ == "__init__" and "self" in args:
154-
# del args["self"]
155149
# Pickle the arguments
156-
pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
157-
pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
158-
sys.setrecursionlimit(original_recursion_limit)
159-
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
160-
# Retry with dill if pickle fails. It's slower but more comprehensive
161-
try:
162-
pickled_args = dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
163-
pickled_kwargs = dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
164-
sys.setrecursionlimit(original_recursion_limit)
165-
166-
except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e:
167-
print(f"Error pickling arguments for function {func.__name__}: {e}")
168-
# Add to the list of function calls without pickled args. Used for timing info only
169-
self._thread_local.active_functions.remove(func_id)
170-
overhead_time = time.thread_time_ns() - end_time
171-
self.function_calls_data.append(
172-
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
173-
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
174-
overhead_time, None, None)
175-
)
176-
return result
177-
150+
pickled_args = PicklePatcher.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
151+
pickled_kwargs = PicklePatcher.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
152+
except Exception as e:
153+
print(f"Error pickling arguments for function {func.__name__}: {e}")
154+
# Add to the list of function calls without pickled args. Used for timing info only
155+
self._thread_local.active_functions.remove(func_id)
156+
overhead_time = time.thread_time_ns() - end_time
157+
self.function_calls_data.append(
158+
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
159+
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
160+
overhead_time, None, None)
161+
)
162+
return result
178163
# Flush to database every 1000 calls
179164
if len(self.function_calls_data) > 1000:
180165
self.write_function_timings()

codeflash/benchmarking/plugin/plugin.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,6 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
175175
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
176176
# Subtract overhead from total time
177177
overhead = overhead_by_benchmark.get(benchmark_key, 0)
178-
print("benchmark_func:", benchmark_func, "Total time:", time_ns, "Overhead:", overhead, "Result:", time_ns - overhead)
179178
result[benchmark_key] = time_ns - overhead
180179

181180
finally:
@@ -267,9 +266,9 @@ def _run_benchmark(self, func, *args, **kwargs):
267266
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
268267
os.environ["CODEFLASH_BENCHMARKING"] = "True"
269268
# Run the function
270-
start = time.thread_time_ns()
269+
start = time.time_ns()
271270
result = func(*args, **kwargs)
272-
end = time.thread_time_ns()
271+
end = time.time_ns()
273272
# Reset the environment variable
274273
os.environ["CODEFLASH_BENCHMARKING"] = "False"
275274

codeflash/benchmarking/replay_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def create_trace_replay_test_code(
6262
assert test_framework in ["pytest", "unittest"]
6363

6464
# Create Imports
65-
imports = f"""import dill as pickle
65+
imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
6666
{"import unittest" if test_framework == "unittest" else ""}
6767
from codeflash.benchmarking.replay_test import get_next_arg_and_return
6868
"""

codeflash/models/models.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from enum import Enum, IntEnum
1717
from pathlib import Path
1818
from re import Pattern
19-
from typing import Annotated, Any, Optional, Union, cast
19+
from typing import Annotated, Optional, cast
2020

2121
from jedi.api.classes import Name
2222
from pydantic import AfterValidator, BaseModel, ConfigDict, Field
@@ -362,6 +362,7 @@ class FunctionCoverage:
362362
class TestingMode(enum.Enum):
363363
BEHAVIOR = "behavior"
364364
PERFORMANCE = "performance"
365+
LINE_PROFILE = "line_profile"
365366

366367

367368
class VerificationType(str, Enum):
@@ -533,7 +534,7 @@ def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree:
533534
tree.add(
534535
f"{test_type.to_name()} - Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']}"
535536
)
536-
return
537+
return tree
537538

538539
def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]:
539540

@@ -606,4 +607,4 @@ def __eq__(self, other: object) -> bool:
606607
sys.setrecursionlimit(original_recursion_limit)
607608
return False
608609
sys.setrecursionlimit(original_recursion_limit)
609-
return True
610+
return True

codeflash/picklepatch/pickle_placeholder.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
class PicklePlaceholderAccessError(Exception):
2+
"""Custom exception raised when attempting to access an unpicklable object."""
3+
4+
5+
16
class PicklePlaceholder:
27
"""A placeholder for an object that couldn't be pickled.
38
@@ -22,22 +27,22 @@ def __init__(self, obj_type, obj_str, error_msg, path=None):
2227
self.__dict__["path"] = path if path is not None else []
2328

2429
def __getattr__(self, name):
25-
"""Raise an error when any attribute is accessed."""
30+
"""Raise a custom error when any attribute is accessed."""
2631
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
27-
raise AttributeError(
28-
f"Cannot access attribute '{name}' on unpicklable object at {path_str}. "
32+
raise PicklePlaceholderAccessError(
33+
f"Attempt to access unpickleable object: Cannot access attribute '{name}' on unpicklable object at {path_str}. "
2934
f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}"
3035
)
3136

3237
def __setattr__(self, name, value):
3338
"""Prevent setting attributes."""
34-
self.__getattr__(name) # This will raise an AttributeError
39+
self.__getattr__(name) # This will raise our custom error
3540

3641
def __call__(self, *args, **kwargs):
37-
"""Raise an error when the object is called."""
42+
"""Raise a custom error when the object is called."""
3843
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
39-
raise TypeError(
40-
f"Cannot call unpicklable object at {path_str}. "
44+
raise PicklePlaceholderAccessError(
45+
f"Attempt to access unpickleable object: Cannot call unpicklable object at {path_str}. "
4146
f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}"
4247
)
4348

codeflash/verification/comparator.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import sentry_sdk
1111

1212
from codeflash.cli_cmds.console import logger
13+
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError
1314

1415
try:
1516
import numpy as np
@@ -64,7 +65,11 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
6465
if len(orig) != len(new):
6566
return False
6667
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
67-
68+
if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError):
69+
# If this error was raised, there was an attempt to access the PicklePlaceholder, which represents an unpickleable object.
70+
# The test results should be rejected as the behavior of the unpickleable object is unknown.
71+
logger.debug("Unable to verify behavior of unpickleable object in replay test")
72+
return False
6873
if isinstance(
6974
orig,
7075
(

0 commit comments

Comments
 (0)