Skip to content

Conversation

@codeflash-ai
Copy link
Contributor

@codeflash-ai codeflash-ai bot commented May 19, 2025

⚡️ This pull request contains optimizations for PR #217

If you approve this dependent PR, these changes will be merged into the original PR branch proper-cleanup.

This PR will be automatically closed if the original PR is merged.


📄 148% (1.48x) speedup for CodeFlashBenchmarkPlugin.write_benchmark_timings in codeflash/benchmarking/plugin/plugin.py

⏱️ Runtime : 25.9 milliseconds 10.5 milliseconds (best of 91 runs)

📝 Explanation and details

Below is an optimized version of your program with respect to the provided line profiler results.
The major bottleneck is self._connection.commit() and, to a lesser extent, cur.executemany(...).
We can greatly accelerate SQLite bulk inserts and commits by.

  • Disabling SQLite's default autocommit mode and wrapping the bulk inserts in a single explicit transaction.
  • Using with self._connection, which ensures a transaction/commit automatically or using begin/commit explicitly.
  • Setting SQLite's PRAGMA synchronous = OFF and PRAGMA journal_mode = MEMORY if durability is not absolutely required, since this will make writes much faster (you may enable this once per connection only).

Note:
– These changes keep the function return value and signature identical.
– Connection and PRAGMA are set only once per connection.
– All existing comments are preserved, and new comments only explain modifications.

Why this is faster.

  • Explicit transaction: self._connection.execute('BEGIN') + one commit() is far faster than relying on SQLite's default behavior.
  • PRAGMA tweaks: synchronous = OFF and journal_mode = MEMORY massively reduce disk sync/write overhead for benchmark data.
  • Batching: Still using executemany for the most efficient bulk insert.
  • Single cursor, closed immediately.

If your use case absolutely requires durability against power loss, remove the two PRAGMA settings (or use WAL and NORMAL modes). This code retains the exact logic and return values, but will be considerably faster in typical benchmarking scenarios.

Correctness verification report:

Test Status
⏪ Replay Tests 🔘 None Found
⚙️ Existing Unit Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
🌀 Generated Regression Tests 65 Passed
📊 Tests Coverage
🌀 Generated Regression Tests Details
from __future__ import annotations

import os
import sqlite3
import tempfile

# imports
import pytest  # used for our unit tests
from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin


# Helper function to create a test database with the required table
def create_test_db(path):
    conn = sqlite3.connect(path)
    cur = conn.cursor()
    cur.execute(
        """
        CREATE TABLE benchmark_timings (
            benchmark_module_path TEXT,
            benchmark_function_name TEXT,
            benchmark_line_number INTEGER,
            benchmark_time_ns INTEGER
        )
        """
    )
    conn.commit()
    conn.close()

# Helper function to read all rows from the benchmark_timings table
def read_benchmark_timings(path):
    conn = sqlite3.connect(path)
    cur = conn.cursor()
    cur.execute("SELECT benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns FROM benchmark_timings")
    rows = cur.fetchall()
    conn.close()
    return rows

# -------------------- UNIT TESTS --------------------

# 1. BASIC TEST CASES

def test_write_single_benchmark_entry(tmp_path):
    # Test writing a single benchmark timing entry
    db_path = tmp_path / "test1.db"
    create_test_db(str(db_path))
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = str(db_path)
    plugin.benchmark_timings = [
        ("module.py", "func", 10, 123456)
    ]
    plugin.write_benchmark_timings()
    # Verify the entry was written
    rows = read_benchmark_timings(str(db_path))

def test_write_multiple_benchmark_entries(tmp_path):
    # Test writing multiple entries at once
    db_path = tmp_path / "test2.db"
    create_test_db(str(db_path))
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = str(db_path)
    entries = [
        ("mod1.py", "f1", 1, 111),
        ("mod2.py", "f2", 2, 222),
        ("mod3.py", "f3", 3, 333),
    ]
    plugin.benchmark_timings = list(entries)
    plugin.write_benchmark_timings()
    rows = read_benchmark_timings(str(db_path))

def test_write_with_existing_connection(tmp_path):
    # Test that function works when _connection is already set
    db_path = tmp_path / "test3.db"
    create_test_db(str(db_path))
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = str(db_path)
    plugin._connection = sqlite3.connect(str(db_path))
    entries = [
        ("mod.py", "func", 42, 4242)
    ]
    plugin.benchmark_timings = list(entries)
    plugin.write_benchmark_timings()
    rows = read_benchmark_timings(str(db_path))
    plugin._connection.close()

def test_write_empty_benchmark_timings(tmp_path):
    # Test that nothing is written and no error is raised when benchmark_timings is empty
    db_path = tmp_path / "test4.db"
    create_test_db(str(db_path))
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = str(db_path)
    plugin.benchmark_timings = []
    plugin.write_benchmark_timings()
    rows = read_benchmark_timings(str(db_path))

# 2. EDGE TEST CASES

def test_write_with_none_trace_path_raises():
    # Should raise an error if _trace_path is None and there is data to write
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = None
    plugin.benchmark_timings = [("mod.py", "func", 1, 100)]
    with pytest.raises(TypeError):
        # sqlite3.connect(None) raises TypeError
        plugin.write_benchmark_timings()

def test_write_with_invalid_db_path(tmp_path):
    # Should raise an error if _trace_path is invalid
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = str(tmp_path / "nonexistent_dir" / "file.db")
    plugin.benchmark_timings = [("mod.py", "func", 1, 100)]
    with pytest.raises(sqlite3.OperationalError):
        plugin.write_benchmark_timings()

def test_write_with_invalid_benchmark_entry_type(tmp_path):
    # Should raise an error if the entry is not a tuple of the correct length
    db_path = tmp_path / "test5.db"
    create_test_db(str(db_path))
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = str(db_path)
    # Wrong type: string instead of tuple
    plugin.benchmark_timings = ["not a tuple"]
    with pytest.raises(sqlite3.InterfaceError):
        plugin.write_benchmark_timings()

def test_write_with_extra_fields_in_entry(tmp_path):
    # Should raise an error if the entry has too many fields
    db_path = tmp_path / "test6.db"
    create_test_db(str(db_path))
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = str(db_path)
    # 5 fields instead of 4
    plugin.benchmark_timings = [("a", "b", 1, 2, 3)]
    with pytest.raises(sqlite3.ProgrammingError):
        plugin.write_benchmark_timings()

def test_write_with_null_values(tmp_path):
    # Should work if some fields are None (SQLite allows NULL unless schema says otherwise)
    db_path = tmp_path / "test7.db"
    create_test_db(str(db_path))
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = str(db_path)
    plugin.benchmark_timings = [
        (None, None, None, None)
    ]
    plugin.write_benchmark_timings()
    rows = read_benchmark_timings(str(db_path))

def test_write_with_special_characters(tmp_path):
    # Should correctly write and read entries with special/unicode characters
    db_path = tmp_path / "test8.db"
    create_test_db(str(db_path))
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = str(db_path)
    special_entry = ("modüle.py", "fünc✨", 123, 987654321)
    plugin.benchmark_timings = [special_entry]
    plugin.write_benchmark_timings()
    rows = read_benchmark_timings(str(db_path))

def test_write_with_negative_and_zero_values(tmp_path):
    # Should handle zero and negative numbers
    db_path = tmp_path / "test9.db"
    create_test_db(str(db_path))
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = str(db_path)
    entries = [
        ("mod.py", "zero", 0, 0),
        ("mod.py", "neg", -1, -100)
    ]
    plugin.benchmark_timings = list(entries)
    plugin.write_benchmark_timings()
    rows = read_benchmark_timings(str(db_path))

def test_write_with_duplicate_entries(tmp_path):
    # Should allow duplicate entries (no unique constraint)
    db_path = tmp_path / "test10.db"
    create_test_db(str(db_path))
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = str(db_path)
    entry = ("dup.py", "func", 7, 77)
    plugin.benchmark_timings = [entry, entry]
    plugin.write_benchmark_timings()
    rows = read_benchmark_timings(str(db_path))

def test_write_with_long_strings(tmp_path):
    # Should handle very long strings in module/function name
    db_path = tmp_path / "test11.db"
    create_test_db(str(db_path))
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = str(db_path)
    long_str = "a" * 500
    entry = (long_str, long_str, 1, 2)
    plugin.benchmark_timings = [entry]
    plugin.write_benchmark_timings()
    rows = read_benchmark_timings(str(db_path))

# 3. LARGE SCALE TEST CASES

def test_write_large_number_of_entries(tmp_path):
    # Test writing a large number of entries (1000)
    db_path = tmp_path / "test12.db"
    create_test_db(str(db_path))
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = str(db_path)
    entries = [("mod.py", f"func{i}", i, i * 100) for i in range(1000)]
    plugin.benchmark_timings = list(entries)
    plugin.write_benchmark_timings()
    rows = read_benchmark_timings(str(db_path))

def test_write_multiple_batches(tmp_path):
    # Test writing in multiple batches (simulate repeated calls)
    db_path = tmp_path / "test13.db"
    create_test_db(str(db_path))
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = str(db_path)
    all_entries = []
    for batch in range(5):
        entries = [(f"mod{batch}.py", f"func{batch}_{i}", i, i + batch) for i in range(200)]
        plugin.benchmark_timings = list(entries)
        plugin.write_benchmark_timings()
        all_entries.extend(entries)
    rows = read_benchmark_timings(str(db_path))

def test_write_large_entry_values(tmp_path):
    # Test writing entries with very large integer values
    db_path = tmp_path / "test14.db"
    create_test_db(str(db_path))
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = str(db_path)
    big = 2**62
    entries = [
        ("mod.py", "func", big, big),
        ("mod.py", "func2", -big, -big)
    ]
    plugin.benchmark_timings = list(entries)
    plugin.write_benchmark_timings()
    rows = read_benchmark_timings(str(db_path))
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

from __future__ import annotations

import os
import sqlite3
import tempfile

# imports
import pytest  # used for our unit tests
from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin


# Helper function to create a temporary SQLite database with the required table
def create_temp_db():
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.db')
    conn = sqlite3.connect(tmp_file.name)
    cur = conn.cursor()
    # Create the benchmark_timings table
    cur.execute("""
        CREATE TABLE benchmark_timings (
            benchmark_module_path TEXT,
            benchmark_function_name TEXT,
            benchmark_line_number INTEGER,
            benchmark_time_ns INTEGER
        )
    """)
    conn.commit()
    conn.close()
    return tmp_file.name

# Helper function to read all rows from the benchmark_timings table
def read_all_timings(db_path):
    conn = sqlite3.connect(db_path)
    cur = conn.cursor()
    cur.execute("SELECT benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns FROM benchmark_timings")
    rows = cur.fetchall()
    conn.close()
    return rows

# ---------------- Basic Test Cases ----------------

def test_write_single_benchmark_timing():
    # Test writing a single timing entry
    db_path = create_temp_db()
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = db_path
    plugin.benchmark_timings = [("module.py", "func", 42, 123456)]
    plugin.write_benchmark_timings()
    # Check that the entry was written
    rows = read_all_timings(db_path)
    os.unlink(db_path)

def test_write_multiple_benchmark_timings():
    # Test writing multiple timing entries at once
    db_path = create_temp_db()
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = db_path
    plugin.benchmark_timings = [
        ("mod1.py", "foo", 10, 100),
        ("mod2.py", "bar", 20, 200),
        ("mod3.py", "baz", 30, 300),
    ]
    plugin.write_benchmark_timings()
    rows = read_all_timings(db_path)
    os.unlink(db_path)

def test_write_no_benchmark_timings():
    # Test that nothing is written if benchmark_timings is empty
    db_path = create_temp_db()
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = db_path
    plugin.benchmark_timings = []
    plugin.write_benchmark_timings()
    rows = read_all_timings(db_path)
    os.unlink(db_path)

# ---------------- Edge Test Cases ----------------

def test_write_with_preexisting_connection():
    # Test that function works when _connection is already set
    db_path = create_temp_db()
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = db_path
    plugin._connection = sqlite3.connect(db_path)
    plugin.benchmark_timings = [("edge.py", "edge_func", 99, 9999)]
    plugin.write_benchmark_timings()
    rows = read_all_timings(db_path)
    plugin._connection.close()
    os.unlink(db_path)

def test_write_with_special_characters():
    # Test that special characters in strings are handled correctly
    db_path = create_temp_db()
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = db_path
    plugin.benchmark_timings = [
        ("weird'\"\\mod.py", "fünc\n\t", -1, 0),
    ]
    plugin.write_benchmark_timings()
    rows = read_all_timings(db_path)
    os.unlink(db_path)

def test_write_with_large_and_negative_numbers():
    # Test writing entries with very large and negative numbers
    db_path = create_temp_db()
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = db_path
    plugin.benchmark_timings = [
        ("mod.py", "func", -999999999, -1234567890123456789),
        ("mod.py", "func", 999999999, 1234567890123456789),
    ]
    plugin.write_benchmark_timings()
    rows = read_all_timings(db_path)
    os.unlink(db_path)

def test_write_raises_on_missing_table():
    # Test that an exception is raised if the table does not exist
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.db')
    db_path = tmp_file.name
    tmp_file.close()
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = db_path
    plugin.benchmark_timings = [("mod.py", "func", 1, 1)]
    with pytest.raises(sqlite3.OperationalError):
        plugin.write_benchmark_timings()
    os.unlink(db_path)

def test_write_raises_on_wrong_tuple_length():
    # Test that an exception is raised if a tuple has the wrong length
    db_path = create_temp_db()
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = db_path
    # Only 3 elements instead of 4
    plugin.benchmark_timings = [("mod.py", "func", 1)]
    with pytest.raises(sqlite3.ProgrammingError):
        plugin.write_benchmark_timings()
    os.unlink(db_path)

def test_write_raises_on_wrong_types():
    # Test that an exception is raised if the types are wrong (e.g., passing a list instead of tuple)
    db_path = create_temp_db()
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = db_path
    # List instead of tuple
    plugin.benchmark_timings = [["mod.py", "func", 1, 1]]
    with pytest.raises(sqlite3.ProgrammingError):
        plugin.write_benchmark_timings()
    os.unlink(db_path)

def test_benchmark_timings_cleared_on_success():
    # Test that benchmark_timings is cleared after a successful write
    db_path = create_temp_db()
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = db_path
    plugin.benchmark_timings = [("mod.py", "func", 1, 1)]
    plugin.write_benchmark_timings()
    os.unlink(db_path)

def test_benchmark_timings_not_cleared_on_failure():
    # Test that benchmark_timings is NOT cleared if an exception is raised
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.db')
    db_path = tmp_file.name
    tmp_file.close()
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = db_path
    plugin.benchmark_timings = [("mod.py", "func", 1, 1)]
    with pytest.raises(sqlite3.OperationalError):
        plugin.write_benchmark_timings()
    os.unlink(db_path)

# ---------------- Large Scale Test Cases ----------------

def test_write_large_number_of_benchmark_timings():
    # Test writing a large number of entries (performance/scalability)
    db_path = create_temp_db()
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = db_path
    N = 1000  # upper bound for large scale within instructions
    plugin.benchmark_timings = [
        (f"mod{i}.py", f"func{i}", i, i * 1000) for i in range(N)
    ]
    plugin.write_benchmark_timings()
    rows = read_all_timings(db_path)
    os.unlink(db_path)

def test_write_large_strings():
    # Test writing entries with very large string fields
    db_path = create_temp_db()
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = db_path
    long_str = "x" * 1000
    plugin.benchmark_timings = [
        (long_str, long_str, 1, 1),
        ("short", long_str, 2, 2),
        (long_str, "short", 3, 3),
    ]
    plugin.write_benchmark_timings()
    rows = read_all_timings(db_path)
    os.unlink(db_path)

def test_write_multiple_times_accumulates():
    # Test that repeated calls accumulate data in the database
    db_path = create_temp_db()
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = db_path
    # First batch
    plugin.benchmark_timings = [("modA.py", "funcA", 1, 10)]
    plugin.write_benchmark_timings()
    # Second batch
    plugin.benchmark_timings = [("modB.py", "funcB", 2, 20)]
    plugin.write_benchmark_timings()
    rows = read_all_timings(db_path)
    os.unlink(db_path)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-pr217-2025-05-19T19.52.21 and push.

Codeflash

… by 148% in PR #217 (`proper-cleanup`)

Below is an optimized version of your program with respect to the provided line profiler results.  
The major bottleneck is `self._connection.commit()` and, to a lesser extent, `cur.executemany(...)`.  
We can **greatly accelerate SQLite bulk inserts** and commits by.

- Disabling SQLite's default autocommit mode and wrapping the bulk inserts in a single explicit transaction.
- Using `with self._connection`, which ensures a transaction/commit automatically or using `begin`/`commit` explicitly.
- Setting SQLite's `PRAGMA synchronous = OFF` and `PRAGMA journal_mode = MEMORY` if durability is not absolutely required, since this will make writes much faster (you may enable this once per connection only).

**Note:**  
– These changes keep the function return value and signature identical.  
– Connection and PRAGMA are set only once per connection.  
– All existing comments are preserved, and new comments only explain modifications.




### Why this is faster.
- **Explicit transaction**: `self._connection.execute('BEGIN')` + one `commit()` is far faster than relying on SQLite's default behavior.
- **PRAGMA tweaks**: `synchronous = OFF` and `journal_mode = MEMORY` massively reduce disk sync/write overhead for benchmark data.
- **Batching**: Still using `executemany` for the most efficient bulk insert.
- **Single cursor, closed immediately**.

*If your use case absolutely requires durability against power loss, remove the two PRAGMA settings (or use `WAL` and `NORMAL` modes). This code retains the exact logic and return values, but will be considerably faster in typical benchmarking scenarios.*
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label May 19, 2025
@KRRT7 KRRT7 force-pushed the proper-cleanup branch from 43d1056 to f513763 Compare May 20, 2025 00:55
@KRRT7 KRRT7 closed this May 20, 2025
@codeflash-ai codeflash-ai bot deleted the codeflash/optimize-pr217-2025-05-19T19.52.21 branch May 20, 2025 05:11
@KRRT7 KRRT7 restored the codeflash/optimize-pr217-2025-05-19T19.52.21 branch May 21, 2025 05:54
@KRRT7 KRRT7 deleted the codeflash/optimize-pr217-2025-05-19T19.52.21 branch May 21, 2025 05:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant