Skip to content

Commit 713f135

Browse files
committed
let's make it clear it's an sqlite3 db
1 parent 6b7c435 commit 713f135

File tree

7 files changed

+21
-36
lines changed

7 files changed

+21
-36
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def parse_args() -> Namespace:
4747
trace_optimize.add_argument(
4848
"--output",
4949
type=str,
50-
default="codeflash.trace",
51-
help="The file to save the trace to. Default is codeflash.trace.",
50+
default="codeflash.sqlite3",
51+
help="The file to save the trace to. Default is codeflash.sqlite3.",
5252
)
5353
trace_optimize.add_argument(
5454
"--config-file-path",

codeflash/optimization/optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def run_benchmarks(
8787
file_path_to_source_code[file] = f.read()
8888
try:
8989
instrument_codeflash_trace_decorator(file_to_funcs_to_optimize)
90-
trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace"
90+
trace_file = Path(self.args.benchmarks_root) / "benchmarks.sqlite3"
9191
if trace_file.exists():
9292
trace_file.unlink()
9393

codeflash/tracer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
def main(args: Namespace | None = None) -> ArgumentParser:
3535
parser = ArgumentParser(allow_abbrev=False)
36-
parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to <outfile>", default="codeflash.trace")
36+
parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to <outfile>", default="codeflash.sqlite3")
3737
parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None)
3838
parser.add_argument(
3939
"--max-function-count",
@@ -59,7 +59,7 @@ def main(args: Namespace | None = None) -> ArgumentParser:
5959

6060
if args is not None:
6161
parsed_args = args
62-
parsed_args.outfile = getattr(args, "output", "codeflash.trace")
62+
parsed_args.outfile = getattr(args, "output", "codeflash.sqlite3")
6363
parsed_args.only_functions = getattr(args, "only_functions", None)
6464
parsed_args.max_function_count = getattr(args, "max_function_count", 100)
6565
parsed_args.tracer_timeout = getattr(args, "timeout", None)

codeflash/tracing/tracing_new_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def __init__(
130130
test_file_path = get_test_file_path(
131131
test_dir=Path(config["tests_root"]), function_name=function_path, test_type="replay"
132132
)
133-
trace_filename = test_file_path.stem + ".trace"
133+
trace_filename = test_file_path.stem + ".sqlite3"
134134
self.output_file = test_file_path.parent / trace_filename
135135
self.result_pickle_file_path = result_pickle_file_path
136136

docs/optimizing-with-codeflash/trace-and-optimize.mdx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ Codeflash script optimizer can be used in three ways:
6161
```
6262

6363
The above command should suffice in most situations.
64-
To customize the trace file location you can specify it like `codeflash optimize -o trace_file_path.trace`. Otherwise, it defaults to `codeflash.trace` in the current working directory.
64+
To customize the trace file location you can specify it like `codeflash optimize -o trace_file_path.sqlite3`. Otherwise, it defaults to `codeflash.sqlite3` in the current working directory.
6565

6666
2. **Trace and optimize as two separate steps**
6767

@@ -70,7 +70,7 @@ Codeflash script optimizer can be used in three ways:
7070
To create just the trace file first, run
7171

7272
```bash
73-
codeflash optimize -o trace_file.trace --trace-only path/to/your/file.py --your_options
73+
codeflash optimize -o trace_file.sqlite3 --trace-only path/to/your/file.py --your_options
7474
```
7575

7676
This will create a replay test file. To optimize with the replay test, run the
@@ -89,7 +89,7 @@ Codeflash script optimizer can be used in three ways:
8989
```python
9090
from codeflash.tracer import Tracer
9191

92-
with Tracer(output="codeflash.trace"):
92+
with Tracer():
9393
model.predict() # Your code here
9494
```
9595

@@ -106,6 +106,6 @@ Codeflash script optimizer can be used in three ways:
106106
- `disable`: If set to `True`, the tracer will not trace the code. Default is `False`.
107107
- `max_function_count`: The maximum number of times to trace a single function. More calls to a function will not be traced. Default is 100.
108108
- `timeout`: The maximum time in seconds to trace the entire workflow. Default is indefinite. This is useful while tracing really long workflows, to not wait indefinitely.
109-
- `output`: The file to save the trace to. Default is `codeflash.trace`.
109+
Note: The trace file location is automatically determined and saved as a `.sqlite3` file.
110110
- `config_file_path`: The path to the `pyproject.toml` file which stores the Codeflash config. This is auto-discovered by default.
111111
You can also disable the tracer in the code by setting the `disable=True` option in the `Tracer` constructor.

tests/test_function_ranker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
@pytest.fixture
1111
def trace_file():
12-
return Path(__file__).parent.parent / "code_to_optimize/code_directories/simple_tracer_e2e/codeflash.trace"
12+
return Path(__file__).parent.parent / "code_to_optimize/code_directories/simple_tracer_e2e/codeflash.sqlite3"
1313

1414

1515
@pytest.fixture

tests/test_tracer.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def trace_config(self, tmp_path: Path) -> Generator[TraceConfig, None, None]:
7676
ignore-paths = []
7777
""", encoding="utf-8")
7878

79-
trace_path = tmp_path / "trace_file.trace"
79+
trace_path = tmp_path / "trace_file.sqlite3"
8080
replay_test_pkl_path = tmp_path / "replay_test.pkl"
8181
config, found_config_path = parse_config_file(config_path)
8282
trace_config = TraceConfig(
@@ -104,7 +104,6 @@ def test_tracer_disabled_by_environment(self, trace_config: TraceConfig) -> None
104104
"""Test that tracer is disabled when CODEFLASH_TRACER_DISABLE is set."""
105105
with patch.dict("os.environ", {"CODEFLASH_TRACER_DISABLE": "1"}):
106106
tracer = Tracer(
107-
output=str(trace_config.trace_file),
108107
config=trace_config.trace_config,
109108
project_root=trace_config.project_root,
110109
result_pickle_file_path=trace_config.result_pickle_file_path,
@@ -120,7 +119,6 @@ def dummy_profiler(_frame: object, _event: str, _arg: object) -> object:
120119
sys.setprofile(dummy_profiler)
121120
try:
122121
tracer = Tracer(
123-
output=str(trace_config.trace_file),
124122
config=trace_config.trace_config,
125123
project_root=trace_config.project_root,
126124
result_pickle_file_path=trace_config.result_pickle_file_path,
@@ -132,7 +130,6 @@ def dummy_profiler(_frame: object, _event: str, _arg: object) -> object:
132130
def test_tracer_initialization_normal(self, trace_config: TraceConfig) -> None:
133131
"""Test normal tracer initialization."""
134132
tracer = Tracer(
135-
output=str(trace_config.trace_file),
136133
config=trace_config.trace_config,
137134
project_root=trace_config.project_root,
138135
result_pickle_file_path=trace_config.result_pickle_file_path,
@@ -142,7 +139,7 @@ def test_tracer_initialization_normal(self, trace_config: TraceConfig) -> None:
142139
)
143140

144141
assert tracer.disable is False
145-
assert tracer.output_file == trace_config.trace_file.resolve()
142+
assert tracer.output_file.exists() or not tracer.disable # output_file is auto-generated
146143
assert tracer.functions == ["test_func"]
147144
assert tracer.max_function_count == 128
148145
assert tracer.timeout == 10
@@ -152,7 +149,6 @@ def test_tracer_initialization_normal(self, trace_config: TraceConfig) -> None:
152149
def test_tracer_timeout_validation(self, trace_config: TraceConfig) -> None:
153150
with pytest.raises(AssertionError):
154151
Tracer(
155-
output=str(trace_config.trace_file),
156152
config=trace_config.trace_config,
157153
project_root=trace_config.project_root,
158154
result_pickle_file_path=trace_config.result_pickle_file_path,
@@ -161,7 +157,6 @@ def test_tracer_timeout_validation(self, trace_config: TraceConfig) -> None:
161157

162158
with pytest.raises(AssertionError):
163159
Tracer(
164-
output=str(trace_config.trace_file),
165160
config=trace_config.trace_config,
166161
project_root=trace_config.project_root,
167162
result_pickle_file_path=trace_config.result_pickle_file_path,
@@ -170,7 +165,6 @@ def test_tracer_timeout_validation(self, trace_config: TraceConfig) -> None:
170165

171166
def test_tracer_context_manager_disabled(self, trace_config: TraceConfig) -> None:
172167
tracer = Tracer(
173-
output=str(trace_config.trace_file),
174168
config=trace_config.trace_config,
175169
project_root=trace_config.project_root,
176170
result_pickle_file_path=trace_config.result_pickle_file_path,
@@ -180,7 +174,8 @@ def test_tracer_context_manager_disabled(self, trace_config: TraceConfig) -> Non
180174
with tracer:
181175
pass
182176

183-
assert not trace_config.trace_file.exists()
177+
# When disabled, tracer should not create any files
178+
assert not tracer.output_file.exists() if hasattr(tracer, 'output_file') else True
184179

185180
def test_tracer_function_filtering(self, trace_config: TraceConfig) -> None:
186181
"""Test that tracer respects function filtering."""
@@ -194,7 +189,6 @@ def other_function() -> int:
194189
return 24
195190

196191
tracer = Tracer(
197-
output=str(trace_config.trace_file),
198192
config=trace_config.trace_config,
199193
project_root=trace_config.project_root,
200194
result_pickle_file_path=trace_config.result_pickle_file_path,
@@ -205,8 +199,8 @@ def other_function() -> int:
205199
test_function()
206200
other_function()
207201

208-
if trace_config.trace_file.exists():
209-
con = sqlite3.connect(trace_config.trace_file)
202+
if tracer.output_file.exists():
203+
con = sqlite3.connect(tracer.output_file)
210204
cursor = con.cursor()
211205

212206
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'")
@@ -224,7 +218,6 @@ def counting_function(n: int) -> int:
224218
return n * 2
225219

226220
tracer = Tracer(
227-
output=str(trace_config.trace_file),
228221
config=trace_config.trace_config,
229222
project_root=trace_config.project_root,
230223
result_pickle_file_path=trace_config.result_pickle_file_path,
@@ -243,7 +236,6 @@ def slow_function() -> str:
243236
return "done"
244237

245238
tracer = Tracer(
246-
output=str(trace_config.trace_file),
247239
config=trace_config.trace_config,
248240
project_root=trace_config.project_root,
249241
result_pickle_file_path=trace_config.result_pickle_file_path,
@@ -261,7 +253,6 @@ def thread_function(n: int) -> None:
261253
results.append(n * 2)
262254

263255
tracer = Tracer(
264-
output=str(trace_config.trace_file),
265256
config=trace_config.trace_config,
266257
project_root=trace_config.project_root,
267258
result_pickle_file_path=trace_config.result_pickle_file_path,
@@ -282,7 +273,6 @@ def thread_function(n: int) -> None:
282273
def test_simulate_call(self, trace_config: TraceConfig) -> None:
283274
"""Test the simulate_call method."""
284275
tracer = Tracer(
285-
output=str(trace_config.trace_file),
286276
config=trace_config.trace_config,
287277
project_root=trace_config.project_root,
288278
result_pickle_file_path=trace_config.result_pickle_file_path,
@@ -293,7 +283,6 @@ def test_simulate_call(self, trace_config: TraceConfig) -> None:
293283
def test_simulate_cmd_complete(self, trace_config: TraceConfig) -> None:
294284
"""Test the simulate_cmd_complete method."""
295285
tracer = Tracer(
296-
output=str(trace_config.trace_file),
297286
config=trace_config.trace_config,
298287
project_root=trace_config.project_root,
299288
result_pickle_file_path=trace_config.result_pickle_file_path,
@@ -305,7 +294,6 @@ def test_simulate_cmd_complete(self, trace_config: TraceConfig) -> None:
305294
def test_runctx_method(self, trace_config: TraceConfig) -> None:
306295
"""Test the runctx method for executing code with tracing."""
307296
tracer = Tracer(
308-
output=str(trace_config.trace_file),
309297
config=trace_config.trace_config,
310298
project_root=trace_config.project_root,
311299
result_pickle_file_path=trace_config.result_pickle_file_path,
@@ -338,7 +326,6 @@ def static_method() -> str:
338326
return "static"
339327

340328
tracer = Tracer(
341-
output=str(trace_config.trace_file),
342329
config=trace_config.trace_config,
343330
project_root=trace_config.project_root,
344331
result_pickle_file_path=trace_config.result_pickle_file_path,
@@ -350,8 +337,8 @@ def static_method() -> str:
350337
class_result = TestClass.class_method()
351338
static_result = TestClass.static_method()
352339

353-
if trace_config.trace_file.exists():
354-
con = sqlite3.connect(trace_config.trace_file)
340+
if tracer.output_file.exists():
341+
con = sqlite3.connect(tracer.output_file)
355342
cursor = con.cursor()
356343

357344
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'")
@@ -378,7 +365,6 @@ def failing_function() -> None:
378365
raise ValueError("Test exception")
379366

380367
tracer = Tracer(
381-
output=str(trace_config.trace_file),
382368
config=trace_config.trace_config,
383369
project_root=trace_config.project_root,
384370
result_pickle_file_path=trace_config.result_pickle_file_path,
@@ -394,7 +380,6 @@ def complex_function(
394380
return len(data_dict) + len(nested_list)
395381

396382
tracer = Tracer(
397-
output=str(trace_config.trace_file),
398383
config=trace_config.trace_config,
399384
project_root=trace_config.project_root,
400385
result_pickle_file_path=trace_config.result_pickle_file_path,
@@ -410,8 +395,8 @@ def complex_function(
410395
pickled = pickle.load(trace_config.result_pickle_file_path.open("rb"))
411396
assert pickled["replay_test_file_path"].exists()
412397

413-
if trace_config.trace_file.exists():
414-
con = sqlite3.connect(trace_config.trace_file)
398+
if tracer.output_file.exists():
399+
con = sqlite3.connect(tracer.output_file)
415400
cursor = con.cursor()
416401

417402
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'")

0 commit comments

Comments
 (0)