Skip to content

Conversation

@codeflash-ai
Copy link
Contributor

@codeflash-ai codeflash-ai bot commented Apr 3, 2025

⚡️ This pull request contains optimizations for PR #59

If you approve this dependent PR, these changes will be merged into the original PR branch codeflash-trace-decorator.

This PR will be automatically closed if the original PR is merged.


📄 212% (2.12x) speedup for CodeFlashBenchmarkPlugin.write_benchmark_timings in codeflash/benchmarking/plugin/plugin.py

⏱️ Runtime : 76.2 milliseconds 24.5 milliseconds (best of 97 runs)

📝 Explanation and details

Here is the optimized version of the provided Python program.

Optimizations made.

  1. Setting SQLite PRAGMAs journal_mode to WAL and synchronous to NORMAL when the connection is created. This can significantly speed up the write operations by using Write-Ahead Logging and reducing the synchronization overhead.

  2. Precompute the SQL INSERT query before the loop to avoid repetitively computing the same string during each execution.

  3. Use benchmark_timings.clear() method instead of reassigning to an empty list to clear the list. It can provide a slight performance benefit.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 145 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage
🌀 Generated Regression Tests Details
from __future__ import annotations

import os
import sqlite3

# imports
import pytest  # used for our unit tests
from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin

# unit tests

@pytest.fixture
def setup_database(tmp_path):
    db_path = tmp_path / "test_benchmark.db"
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute('''CREATE TABLE benchmark_timings (
                        benchmark_module_path TEXT,
                        benchmark_function_name TEXT,
                        benchmark_line_number INTEGER,
                        benchmark_time_ns INTEGER)''')
    conn.commit()
    conn.close()
    return db_path

def test_empty_benchmark_timings(setup_database):
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = setup_database
    plugin.write_benchmark_timings()
    conn = sqlite3.connect(plugin._trace_path)
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM benchmark_timings")
    data = cursor.fetchall()
    conn.close()

def test_single_entry_benchmark_timings(setup_database):
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = setup_database
    plugin.benchmark_timings = [("module1.py", "function1", 10, 123456)]
    plugin.write_benchmark_timings()
    conn = sqlite3.connect(plugin._trace_path)
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM benchmark_timings")
    data = cursor.fetchall()
    conn.close()

def test_multiple_entries_benchmark_timings(setup_database):
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = setup_database
    plugin.benchmark_timings = [
        ("module1.py", "function1", 10, 123456),
        ("module2.py", "function2", 20, 654321)
    ]
    plugin.write_benchmark_timings()
    conn = sqlite3.connect(plugin._trace_path)
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM benchmark_timings")
    data = cursor.fetchall()
    conn.close()

def test_connection_already_established(setup_database):
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = setup_database
    plugin._connection = sqlite3.connect(plugin._trace_path)
    plugin.benchmark_timings = [("module1.py", "function1", 10, 123456)]
    plugin.write_benchmark_timings()
    cursor = plugin._connection.cursor()
    cursor.execute("SELECT * FROM benchmark_timings")
    data = cursor.fetchall()
    plugin._connection.close()

def test_invalid_database_path():
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = "/invalid/path/to/database.db"
    plugin.benchmark_timings = [("module1.py", "function1", 10, 123456)]
    with pytest.raises(sqlite3.OperationalError):
        plugin.write_benchmark_timings()

def test_database_write_error(setup_database):
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = setup_database
    plugin.benchmark_timings = [("module1.py", "function1", 10, 123456)]
    conn = sqlite3.connect(plugin._trace_path)
    cursor = conn.cursor()
    cursor.execute("DROP TABLE benchmark_timings")
    conn.commit()
    conn.close()
    with pytest.raises(sqlite3.OperationalError):
        plugin.write_benchmark_timings()

def test_large_data_set(setup_database):
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = setup_database
    plugin.benchmark_timings = [("module1.py", "function1", 10, 123456)] * 1000
    plugin.write_benchmark_timings()
    conn = sqlite3.connect(plugin._trace_path)
    cursor = conn.cursor()
    cursor.execute("SELECT COUNT(*) FROM benchmark_timings")
    count = cursor.fetchone()[0]
    conn.close()

def test_edge_case_data_values(setup_database):
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = setup_database
    plugin.benchmark_timings = [
        ("module1.py", "function1", 10, 999999999999999999),
        ("module@#$.py", "function@#$.py", 0, 0)
    ]
    plugin.write_benchmark_timings()
    conn = sqlite3.connect(plugin._trace_path)
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM benchmark_timings")
    data = cursor.fetchall()
    conn.close()

def test_partial_data_write(setup_database):
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = setup_database
    plugin.benchmark_timings = [
        ("module1.py", "function1", 10, 123456),
        ("module2.py", "function2", 20, 654321)
    ]
    conn = sqlite3.connect(plugin._trace_path)
    cursor = conn.cursor()
    cursor.execute("DROP TABLE benchmark_timings")
    conn.commit()
    conn.close()
    with pytest.raises(sqlite3.OperationalError):
        plugin.write_benchmark_timings()
    conn = sqlite3.connect(plugin._trace_path)
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM benchmark_timings")
    data = cursor.fetchall()
    conn.close()
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

from __future__ import annotations

import os
import sqlite3

# imports
import pytest  # used for our unit tests
from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin

# unit tests

# Helper function to setup the database for testing
def setup_database(plugin, schema):
    if os.path.exists(plugin._trace_path):
        os.remove(plugin._trace_path)
    conn = sqlite3.connect(plugin._trace_path)
    conn.execute(schema)
    conn.close()

# Test single entry
def test_single_entry():
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = 'test_single_entry.db'
    setup_database(plugin, 'CREATE TABLE benchmark_timings (benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER, benchmark_time_ns INTEGER)')
    plugin.benchmark_timings = [('module1', 'func1', 10, 1000)]
    plugin.write_benchmark_timings()
    conn = sqlite3.connect(plugin._trace_path)
    cur = conn.cursor()
    cur.execute("SELECT * FROM benchmark_timings")
    rows = cur.fetchall()
    conn.close()

# Test multiple entries
def test_multiple_entries():
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = 'test_multiple_entries.db'
    setup_database(plugin, 'CREATE TABLE benchmark_timings (benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER, benchmark_time_ns INTEGER)')
    plugin.benchmark_timings = [
        ('module1', 'func1', 10, 1000),
        ('module2', 'func2', 20, 2000)
    ]
    plugin.write_benchmark_timings()
    conn = sqlite3.connect(plugin._trace_path)
    cur = conn.cursor()
    cur.execute("SELECT * FROM benchmark_timings")
    rows = cur.fetchall()
    conn.close()

# Test empty list
def test_empty_list():
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = 'test_empty_list.db'
    setup_database(plugin, 'CREATE TABLE benchmark_timings (benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER, benchmark_time_ns INTEGER)')
    plugin.benchmark_timings = []
    plugin.write_benchmark_timings()
    conn = sqlite3.connect(plugin._trace_path)
    cur = conn.cursor()
    cur.execute("SELECT * FROM benchmark_timings")
    rows = cur.fetchall()
    conn.close()

# Test no connection pre-established
def test_no_connection_preestablished():
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = 'test_no_connection_preestablished.db'
    setup_database(plugin, 'CREATE TABLE benchmark_timings (benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER, benchmark_time_ns INTEGER)')
    plugin.benchmark_timings = [('module1', 'func1', 10, 1000)]
    plugin.write_benchmark_timings()
    conn = sqlite3.connect(plugin._trace_path)
    cur = conn.cursor()
    cur.execute("SELECT * FROM benchmark_timings")
    rows = cur.fetchall()
    conn.close()

# Test connection already established
def test_connection_already_established():
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = 'test_connection_already_established.db'
    setup_database(plugin, 'CREATE TABLE benchmark_timings (benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER, benchmark_time_ns INTEGER)')
    plugin._connection = sqlite3.connect(plugin._trace_path)
    plugin.benchmark_timings = [('module1', 'func1', 10, 1000)]
    plugin.write_benchmark_timings()
    conn = sqlite3.connect(plugin._trace_path)
    cur = conn.cursor()
    cur.execute("SELECT * FROM benchmark_timings")
    rows = cur.fetchall()
    conn.close()

# Test invalid database path
def test_invalid_database_path():
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = '/invalid/path/test_invalid_database_path.db'
    plugin.benchmark_timings = [('module1', 'func1', 10, 1000)]
    with pytest.raises(Exception):
        plugin.write_benchmark_timings()

# Test invalid table schema
def test_invalid_table_schema():
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = 'test_invalid_table_schema.db'
    setup_database(plugin, 'CREATE TABLE benchmark_timings (id INTEGER PRIMARY KEY)')
    plugin.benchmark_timings = [('module1', 'func1', 10, 1000)]
    with pytest.raises(Exception):
        plugin.write_benchmark_timings()

# Test special characters in data
def test_special_characters_in_data():
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = 'test_special_characters_in_data.db'
    setup_database(plugin, 'CREATE TABLE benchmark_timings (benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER, benchmark_time_ns INTEGER)')
    plugin.benchmark_timings = [('module1', 'func1', 10, "1000' OR '1'='1")]
    plugin.write_benchmark_timings()
    conn = sqlite3.connect(plugin._trace_path)
    cur = conn.cursor()
    cur.execute("SELECT * FROM benchmark_timings")
    rows = cur.fetchall()
    conn.close()

# Test large data entries
def test_large_data_entries():
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = 'test_large_data_entries.db'
    setup_database(plugin, 'CREATE TABLE benchmark_timings (benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER, benchmark_time_ns INTEGER)')
    large_string = 'a' * 10000
    plugin.benchmark_timings = [('module1', 'func1', 10, large_string)]
    plugin.write_benchmark_timings()
    conn = sqlite3.connect(plugin._trace_path)
    cur = conn.cursor()
    cur.execute("SELECT * FROM benchmark_timings")
    rows = cur.fetchall()
    conn.close()

# Test simultaneous writes
def test_simultaneous_writes():
    plugin1 = CodeFlashBenchmarkPlugin()
    plugin2 = CodeFlashBenchmarkPlugin()
    plugin1._trace_path = 'test_simultaneous_writes.db'
    plugin2._trace_path = 'test_simultaneous_writes.db'
    setup_database(plugin1, 'CREATE TABLE benchmark_timings (benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER, benchmark_time_ns INTEGER)')
    plugin1.benchmark_timings = [('module1', 'func1', 10, 1000)]
    plugin2.benchmark_timings = [('module2', 'func2', 20, 2000)]
    from threading import Thread
    t1 = Thread(target=plugin1.write_benchmark_timings)
    t2 = Thread(target=plugin2.write_benchmark_timings)
    t1.start()
    t2.start()
    t1.join()
    t2.join()
    conn = sqlite3.connect(plugin1._trace_path)
    cur = conn.cursor()
    cur.execute("SELECT * FROM benchmark_timings")
    rows = cur.fetchall()
    conn.close()

# Test large number of entries
def test_large_number_of_entries():
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = 'test_large_number_of_entries.db'
    setup_database(plugin, 'CREATE TABLE benchmark_timings (benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER, benchmark_time_ns INTEGER)')
    plugin.benchmark_timings = [('module1', 'func1', i, i * 1000) for i in range(1000)]
    plugin.write_benchmark_timings()
    conn = sqlite3.connect(plugin._trace_path)
    cur = conn.cursor()
    cur.execute("SELECT * FROM benchmark_timings")
    rows = cur.fetchall()
    conn.close()

# Test high frequency writes
def test_high_frequency_writes():
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = 'test_high_frequency_writes.db'
    setup_database(plugin, 'CREATE TABLE benchmark_timings (benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER, benchmark_time_ns INTEGER)')
    for i in range(100):
        plugin.benchmark_timings = [('module1', 'func1', i, i * 1000)]
        plugin.write_benchmark_timings()
    conn = sqlite3.connect(plugin._trace_path)
    cur = conn.cursor()
    cur.execute("SELECT * FROM benchmark_timings")
    rows = cur.fetchall()
    conn.close()

# Test clear benchmark timings list
def test_clear_benchmark_timings_list():
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = 'test_clear_benchmark_timings_list.db'
    setup_database(plugin, 'CREATE TABLE benchmark_timings (benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER, benchmark_time_ns INTEGER)')
    plugin.benchmark_timings = [('module1', 'func1', 10, 1000)]
    plugin.write_benchmark_timings()

# Test rollback on failure
def test_rollback_on_failure():
    plugin = CodeFlashBenchmarkPlugin()
    plugin._trace_path = 'test_rollback_on_failure.db'
    setup_database(plugin, 'CREATE TABLE benchmark_timings (id INTEGER PRIMARY KEY)')
    plugin.benchmark_timings = [('module1', 'func1', 10, 1000)]
    with pytest.raises(Exception):
        plugin.write_benchmark_timings()
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-pr59-2025-04-03T00.01.26 and push.

Codeflash

… by 212% in PR #59 (`codeflash-trace-decorator`)

Here is the optimized version of the provided Python program.



### Optimizations made.
1. Setting SQLite PRAGMAs `journal_mode` to `WAL` and `synchronous` to `NORMAL` when the connection is created. This can significantly speed up the write operations by using Write-Ahead Logging and reducing the synchronization overhead.

2. Precompute the SQL INSERT query before the loop to avoid repetitively computing the same string during each execution.

3. Use `benchmark_timings.clear()` method instead of reassigning to an empty list to clear the list. It can provide a slight performance benefit.
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Apr 3, 2025
@codeflash-ai codeflash-ai bot mentioned this pull request Apr 3, 2025
@alvin-r alvin-r closed this Apr 3, 2025
@codeflash-ai codeflash-ai bot deleted the codeflash/optimize-pr59-2025-04-03T00.01.26 branch April 3, 2025 00:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants