Skip to content

Commit ca41912

Browse files
committed
fix unit test (hopefully)
1 parent 2cc3933 commit ca41912

File tree

1 file changed

+35
-94
lines changed

1 file changed

+35
-94
lines changed

tests/test_tracer.py

Lines changed: 35 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
from unittest.mock import patch
1111

1212
import pytest
13-
14-
from codeflash.tracer import FakeCode, FakeFrame, Tracer
13+
from codeflash.tracing.tracing_new_process import FakeCode, FakeFrame, Tracer
1514

1615

1716
class TestFakeCode:
@@ -54,7 +53,7 @@ def temp_config_file(self) -> Generator[Path, None, None]:
5453
temp_dir = Path(tempfile.mkdtemp())
5554
tests_dir = temp_dir / "tests"
5655
tests_dir.mkdir(exist_ok=True)
57-
56+
5857
# Use the current working directory as module root so test files are included
5958
current_dir = Path.cwd()
6059

@@ -69,6 +68,7 @@ def temp_config_file(self) -> Generator[Path, None, None]:
6968
config_path = Path(f.name)
7069
yield config_path
7170
import shutil
71+
7272
shutil.rmtree(temp_dir, ignore_errors=True)
7373

7474
@pytest.fixture
@@ -94,23 +94,18 @@ def reset_tracer_state(self) -> Generator[None, None, None]:
9494
def test_tracer_disabled_by_environment(self, temp_config_file: Path, temp_trace_file: Path) -> None:
9595
"""Test that tracer is disabled when CODEFLASH_TRACER_DISABLE is set."""
9696
with patch.dict("os.environ", {"CODEFLASH_TRACER_DISABLE": "1"}):
97-
tracer = Tracer(
98-
output=str(temp_trace_file),
99-
config_file_path=temp_config_file
100-
)
97+
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
10198
assert tracer.disable is True
10299

103100
def test_tracer_disabled_with_existing_profiler(self, temp_config_file: Path, temp_trace_file: Path) -> None:
104101
"""Test that tracer is disabled when another profiler is running."""
102+
105103
def dummy_profiler(_frame: object, _event: str, _arg: object) -> object:
106104
return dummy_profiler
107105

108106
sys.setprofile(dummy_profiler)
109107
try:
110-
tracer = Tracer(
111-
output=str(temp_trace_file),
112-
config_file_path=temp_config_file
113-
)
108+
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
114109
assert tracer.disable is True
115110
finally:
116111
sys.setprofile(None)
@@ -122,7 +117,7 @@ def test_tracer_initialization_normal(self, temp_config_file: Path, temp_trace_f
122117
functions=["test_func"],
123118
max_function_count=128,
124119
timeout=10,
125-
config_file_path=temp_config_file
120+
config_file_path=temp_config_file,
126121
)
127122

128123
assert tracer.disable is False
@@ -131,37 +126,23 @@ def test_tracer_initialization_normal(self, temp_config_file: Path, temp_trace_f
131126
assert tracer.max_function_count == 128
132127
assert tracer.timeout == 10
133128
assert hasattr(tracer, "_db_lock")
134-
assert getattr(tracer, "_db_lock") is not None
129+
assert tracer._db_lock is not None
135130

136131
def test_tracer_timeout_validation(self, temp_config_file: Path, temp_trace_file: Path) -> None:
137132
with pytest.raises(AssertionError):
138-
Tracer(
139-
output=str(temp_trace_file),
140-
timeout=0,
141-
config_file_path=temp_config_file
142-
)
133+
Tracer(output=str(temp_trace_file), timeout=0, config_file_path=temp_config_file)
143134

144135
with pytest.raises(AssertionError):
145-
Tracer(
146-
output=str(temp_trace_file),
147-
timeout=-5,
148-
config_file_path=temp_config_file
149-
)
136+
Tracer(output=str(temp_trace_file), timeout=-5, config_file_path=temp_config_file)
150137

151138
def test_tracer_context_manager_disabled(self, temp_config_file: Path, temp_trace_file: Path) -> None:
152-
tracer = Tracer(
153-
output=str(temp_trace_file),
154-
disable=True,
155-
config_file_path=temp_config_file
156-
)
139+
tracer = Tracer(output=str(temp_trace_file), disable=True, config_file_path=temp_config_file)
157140

158141
with tracer:
159142
pass
160143

161144
assert not temp_trace_file.exists()
162145

163-
164-
165146
def test_tracer_function_filtering(self, temp_config_file: Path, temp_trace_file: Path) -> None:
166147
"""Test that tracer respects function filtering."""
167148
if hasattr(Tracer, "used_once"):
@@ -173,11 +154,7 @@ def test_function() -> int:
173154
def other_function() -> int:
174155
return 24
175156

176-
tracer = Tracer(
177-
output=str(temp_trace_file),
178-
functions=["test_function"],
179-
config_file_path=temp_config_file
180-
)
157+
tracer = Tracer(output=str(temp_trace_file), functions=["test_function"], config_file_path=temp_config_file)
181158

182159
with tracer:
183160
test_function()
@@ -197,21 +174,16 @@ def other_function() -> int:
197174

198175
con.close()
199176

200-
201177
def test_tracer_max_function_count(self, temp_config_file: Path, temp_trace_file: Path) -> None:
202178
def counting_function(n: int) -> int:
203179
return n * 2
204180

205-
tracer = Tracer(
206-
output=str(temp_trace_file),
207-
max_function_count=3,
208-
config_file_path=temp_config_file
209-
)
181+
tracer = Tracer(output=str(temp_trace_file), max_function_count=3, config_file_path=temp_config_file)
210182

211183
with tracer:
212184
for i in range(5):
213185
counting_function(i)
214-
186+
215187
assert tracer.trace_count <= 3, "Tracer should limit the number of traced functions to max_function_count"
216188

217189
def test_tracer_timeout_functionality(self, temp_config_file: Path, temp_trace_file: Path) -> None:
@@ -222,7 +194,7 @@ def slow_function() -> str:
222194
tracer = Tracer(
223195
output=str(temp_trace_file),
224196
timeout=1, # 1 second timeout
225-
config_file_path=temp_config_file
197+
config_file_path=temp_config_file,
226198
)
227199

228200
with tracer:
@@ -235,10 +207,7 @@ def test_tracer_threading_safety(self, temp_config_file: Path, temp_trace_file:
235207
def thread_function(n: int) -> None:
236208
results.append(n * 2)
237209

238-
tracer = Tracer(
239-
output=str(temp_trace_file),
240-
config_file_path=temp_config_file
241-
)
210+
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
242211

243212
with tracer:
244213
threads = []
@@ -254,29 +223,20 @@ def thread_function(n: int) -> None:
254223

255224
def test_simulate_call(self, temp_config_file: Path, temp_trace_file: Path) -> None:
256225
"""Test the simulate_call method."""
257-
tracer = Tracer(
258-
output=str(temp_trace_file),
259-
config_file_path=temp_config_file
260-
)
226+
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
261227

262228
tracer.simulate_call("test_simulation")
263229

264230
def test_simulate_cmd_complete(self, temp_config_file: Path, temp_trace_file: Path) -> None:
265231
"""Test the simulate_cmd_complete method."""
266-
tracer = Tracer(
267-
output=str(temp_trace_file),
268-
config_file_path=temp_config_file
269-
)
232+
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
270233

271234
tracer.simulate_call("test")
272235
tracer.simulate_cmd_complete()
273236

274237
def test_runctx_method(self, temp_config_file: Path, temp_trace_file: Path) -> None:
275238
"""Test the runctx method for executing code with tracing."""
276-
tracer = Tracer(
277-
output=str(temp_trace_file),
278-
config_file_path=temp_config_file
279-
)
239+
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
280240

281241
global_vars = {"x": 10}
282242
local_vars = {}
@@ -291,7 +251,7 @@ def test_tracer_handles_class_methods(self, temp_config_file: Path, temp_trace_f
291251
# Ensure tracer hasn't been used yet in this test
292252
if hasattr(Tracer, "used_once"):
293253
delattr(Tracer, "used_once")
294-
254+
295255
class TestClass:
296256
def instance_method(self) -> str:
297257
return "instance"
@@ -304,32 +264,27 @@ def class_method(cls) -> str:
304264
def static_method() -> str:
305265
return "static"
306266

307-
tracer = Tracer(
308-
output=str(temp_trace_file),
309-
config_file_path=temp_config_file
310-
)
267+
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
311268

312269
with tracer:
313270
obj = TestClass()
314271
instance_result = obj.instance_method()
315272
class_result = TestClass.class_method()
316273
static_result = TestClass.static_method()
317-
318274

319-
320275
if temp_trace_file.exists():
321276
con = sqlite3.connect(temp_trace_file)
322277
cursor = con.cursor()
323-
278+
324279
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'")
325280
if cursor.fetchone():
326281
# Query for all function calls
327282
cursor.execute("SELECT function, classname FROM function_calls")
328283
calls = cursor.fetchall()
329-
284+
330285
function_names = [call[0] for call in calls]
331286
class_names = [call[1] for call in calls if call[1] is not None]
332-
287+
333288
assert "instance_method" in function_names
334289
assert "class_method" in function_names
335290
assert "static_method" in function_names
@@ -338,46 +293,31 @@ def static_method() -> str:
338293
pytest.fail("No function_calls table found in trace file")
339294
con.close()
340295

341-
342-
343-
344-
345296
def test_tracer_handles_exceptions_gracefully(self, temp_config_file: Path, temp_trace_file: Path) -> None:
346297
"""Test that tracer handles exceptions in traced code gracefully."""
298+
347299
def failing_function() -> None:
348300
raise ValueError("Test exception")
349301

350-
tracer = Tracer(
351-
output=str(temp_trace_file),
352-
config_file_path=temp_config_file
353-
)
302+
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
354303

355304
with tracer, contextlib.suppress(ValueError):
356305
failing_function()
357306

358-
359-
360-
361-
362307
def test_tracer_with_complex_arguments(self, temp_config_file: Path, temp_trace_file: Path) -> None:
363-
def complex_function(data_dict: dict[str, Any], nested_list: list[list[int]], func_arg: object = lambda x: x) -> int:
308+
def complex_function(
309+
data_dict: dict[str, Any], nested_list: list[list[int]], func_arg: object = lambda x: x
310+
) -> int:
364311
return len(data_dict) + len(nested_list)
365312

366-
tracer = Tracer(
367-
output=str(temp_trace_file),
368-
config_file_path=temp_config_file
369-
)
313+
tracer = Tracer(output=str(temp_trace_file), config_file_path=temp_config_file)
370314

371315
expected_dict = {"key": "value", "nested": {"inner": "data"}}
372316
expected_list = [[1, 2], [3, 4], [5, 6]]
373317
expected_func = lambda x: x * 2
374318

375319
with tracer:
376-
complex_function(
377-
expected_dict,
378-
expected_list,
379-
func_arg=expected_func
380-
)
320+
complex_function(expected_dict, expected_list, func_arg=expected_func)
381321

382322
if temp_trace_file.exists():
383323
con = sqlite3.connect(temp_trace_file)
@@ -388,15 +328,16 @@ def complex_function(data_dict: dict[str, Any], nested_list: list[list[int]], fu
388328
cursor.execute("SELECT args FROM function_calls WHERE function = 'complex_function'")
389329
result = cursor.fetchone()
390330
assert result is not None, "Function complex_function should be traced"
391-
331+
392332
# Deserialize the arguments
393333
import pickle
334+
394335
traced_args = pickle.loads(result[0])
395-
336+
396337
assert "data_dict" in traced_args
397338
assert "nested_list" in traced_args
398339
assert "func_arg" in traced_args
399-
340+
400341
assert traced_args["data_dict"] == expected_dict
401342
assert traced_args["nested_list"] == expected_list
402343
assert callable(traced_args["func_arg"])

0 commit comments

Comments
 (0)