diff --git a/clause_perf_test.cpp b/clause_perf_test.cpp
new file mode 100644
index 00000000000..1ffbf36a257
--- /dev/null
+++ b/clause_perf_test.cpp
@@ -0,0 +1,183 @@
+#include <iostream>
+#include <vector>
+#include <chrono>
+#include <random>
+#include "src/sat/sat_clause.h"
+#include "src/sat/sat_types.h"
+
+using namespace sat;
+
+class clause_performance_test {  // stand-alone timing harness for clause lookups
+private:
+    clause_allocator allocator;
+    std::vector<clause*> test_clauses;
+    std::vector<literal> test_literals;
+    std::mt19937 rng;
+
+    // Call counters: incremented by the benchmarks below, never printed
+    unsigned long contains_literal_calls = 0;
+    unsigned long contains_var_calls = 0;
+    unsigned long satisfied_calls = 0;
+
+public:
+    clause_performance_test() : rng(42) {  // fixed seed: same random sequence on every run
+        // Two literals (second constructor argument false and true) per variable 0..999
+        for (unsigned v = 0; v < 1000; v++) {
+            test_literals.push_back(literal(v, false));
+            test_literals.push_back(literal(v, true));
+        }
+
+        // Generate test clauses of various sizes
+        create_test_clauses();
+    }
+
+    ~clause_performance_test() {
+        for (clause* c : test_clauses) {
+            allocator.del_clause(c);
+        }
+    }
+
+    void create_test_clauses() {
+        std::uniform_int_distribution<unsigned> size_dist(2, 20);
+        std::uniform_int_distribution<std::size_t> lit_dist(0, test_literals.size() - 1);
+
+        // Create 10000 test clauses of 2..20 randomly drawn literals (repeats possible)
+        for (int i = 0; i < 10000; i++) {
+            unsigned size = size_dist(rng);
+            std::vector<literal> clause_lits;
+
+            for (unsigned j = 0; j < size; j++) {
+                clause_lits.push_back(test_literals[lit_dist(rng)]);
+            }
+
+            clause* c = allocator.mk_clause(size, clause_lits.data(), i % 2 == 0);
+            test_clauses.push_back(c);
+        }
+
+        std::cout << "Created " << test_clauses.size() << " test clauses" << std::endl;
+
+        // Report how many clauses pass is_cache_aligned()
+        unsigned aligned_count = 0;
+        for (clause* c : test_clauses) {
+            if (c->is_cache_aligned()) aligned_count++;
+        }
+        std::cout << "Cache-aligned clauses: " << aligned_count << "/" << test_clauses.size()
+                  << " (" << (100.0 * aligned_count / test_clauses.size()) << "%)" << std::endl;
+    }
+
+    void benchmark_contains_literal() {
+        std::uniform_int_distribution<std::size_t> clause_dist(0, test_clauses.size() - 1);
+        std::uniform_int_distribution<std::size_t> lit_dist(0, test_literals.size() - 1);
+
+        const int iterations = 1000000;
+        unsigned hits = 0;
+
+        auto start = std::chrono::high_resolution_clock::now();
+
+        for (int i = 0; i < iterations; i++) {  // timed region also includes the two random draws
+            clause* c = test_clauses[clause_dist(rng)];
+            literal l = test_literals[lit_dist(rng)];
+            if (c->contains(l)) hits++;
+            contains_literal_calls++;
+        }
+
+        auto end = std::chrono::high_resolution_clock::now();
+        auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);
+
+        std::cout << "contains(literal) benchmark:" << std::endl;
+        std::cout << "  Iterations: " << iterations << std::endl;
+        std::cout << "  Hits: " << hits << " (" << (100.0 * hits / iterations) << "%)" << std::endl;
+        std::cout << "  Total time: " << duration.count() / 1e6 << " ms" << std::endl;
+        std::cout << "  Average time per call: " << (double)duration.count() / iterations << " ns" << std::endl;
+    }
+
+    void benchmark_contains_var() {
+        std::uniform_int_distribution<std::size_t> clause_dist(0, test_clauses.size() - 1);
+        std::uniform_int_distribution<unsigned> var_dist(0, 999);
+
+        const int iterations = 1000000;
+        unsigned hits = 0;
+
+        auto start = std::chrono::high_resolution_clock::now();
+
+        for (int i = 0; i < iterations; i++) {  // timed region also includes the two random draws
+            clause* c = test_clauses[clause_dist(rng)];
+            bool_var v = var_dist(rng);
+            if (c->contains(v)) hits++;
+            contains_var_calls++;
+        }
+
+        auto end = std::chrono::high_resolution_clock::now();
+        auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);
+
+        std::cout << "contains(bool_var) benchmark:" << std::endl;
+        std::cout << "  Iterations: " << iterations << std::endl;
+        std::cout << "  Hits: " << hits << " (" << (100.0 * hits / iterations) << "%)" << std::endl;
+        std::cout << "  Total time: " << duration.count() / 1e6 << " ms" << std::endl;
+        std::cout << "  Average time per call: " << (double)duration.count() / iterations << " ns" << std::endl;
+    }
+
+    void benchmark_satisfied_by() {
+        // Random model over variables 0..999
+        model m(1000);
+        std::uniform_int_distribution<int> bool_dist(0, 2); // 0=false, 1=true, 2=undef
+
+        for (unsigned v = 0; v < 1000; v++) {
+            int val = bool_dist(rng);
+            m[v] = (val == 0) ? l_false : (val == 1) ? l_true : l_undef;
+        }
+
+        std::uniform_int_distribution<std::size_t> clause_dist(0, test_clauses.size() - 1);
+
+        const int iterations = 500000;
+        unsigned satisfied = 0;
+
+        auto start = std::chrono::high_resolution_clock::now();
+
+        for (int i = 0; i < iterations; i++) {  // timed region also includes the random draw
+            clause* c = test_clauses[clause_dist(rng)];
+            if (c->satisfied_by(m)) satisfied++;
+            satisfied_calls++;
+        }
+
+        auto end = std::chrono::high_resolution_clock::now();
+        auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);
+
+        std::cout << "satisfied_by(model) benchmark:" << std::endl;
+        std::cout << "  Iterations: " << iterations << std::endl;
+        std::cout << "  Satisfied: " << satisfied << " (" << (100.0 * satisfied / iterations) << "%)" << std::endl;
+        std::cout << "  Total time: " << duration.count() / 1e6 << " ms" << std::endl;
+        std::cout << "  Average time per call: " << (double)duration.count() / iterations << " ns" << std::endl;
+    }
+
+    void run_all_benchmarks() {
+        std::cout << "=== Z3 Clause Management Performance Test ===" << std::endl;
+        std::cout << "Testing cache-friendly optimizations" << std::endl;
+        std::cout << std::endl;
+
+        benchmark_contains_literal();
+        std::cout << std::endl;
+
+        benchmark_contains_var();
+        std::cout << std::endl;
+
+        benchmark_satisfied_by();
+        std::cout << std::endl;
+
+        // Allocator footprint, total and averaged over the generated clauses
+        std::cout << "Memory usage:" << std::endl;
+        std::cout << "  Clause allocator size: " << allocator.get_allocation_size() << " bytes" << std::endl;
+        std::cout << "  Average clause size: " << (double)allocator.get_allocation_size() / test_clauses.size() << " bytes" << std::endl;
+    }
+};
+
+int main() {
+    try {
+        clause_performance_test test;
+        test.run_all_benchmarks();
+        return 0;
+    } catch (const std::exception& e) {
+        std::cerr << "Error: " << e.what() << std::endl;
+        return 1;
+    }
+}
\ No
newline at end of file
diff --git a/src/sat/sat_clause.cpp b/src/sat/sat_clause.cpp
index 351a802810e..25599019459 100644
--- a/src/sat/sat_clause.cpp
+++ b/src/sat/sat_clause.cpp
@@ -60,17 +60,59 @@ namespace sat {
     }
 
     bool clause::contains(literal l) const {
-        for (literal l2 : *this)
-            if (l2 == l)
-                return true;
-        return false;
+        // Read-prefetch hint for m_lits; compiled only where __GNUC__ is defined
+        #ifdef __GNUC__
+        if (m_size > 0) __builtin_prefetch(m_lits, 0, 3);
+        #endif
+
+        // Sizes 0-3 are answered with straight-line comparisons, no loop
+        switch (m_size) {
+        case 0: return false;
+        case 1: return m_lits[0] == l;
+        case 2: return m_lits[0] == l || m_lits[1] == l;
+        case 3: return m_lits[0] == l || m_lits[1] == l || m_lits[2] == l;
+        default:
+            // Size >= 4: four literals per iteration; the second loop covers the 0-3 left over
+            unsigned i = 0;
+            for (; i + 3 < m_size; i += 4) {
+                if (m_lits[i] == l || m_lits[i+1] == l ||
+                    m_lits[i+2] == l || m_lits[i+3] == l)
+                    return true;
+            }
+            for (; i < m_size; i++) {
+                if (m_lits[i] == l)
+                    return true;
+            }
+            return false;
+        }
     }
 
     bool clause::contains(bool_var v) const {
-        for (literal l : *this)
-            if (l.var() == v)
-                return true;
-        return false;
+        // Read-prefetch hint for m_lits; compiled only where __GNUC__ is defined
+        #ifdef __GNUC__
+        if (m_size > 0) __builtin_prefetch(m_lits, 0, 3);
+        #endif
+
+        // Sizes 0-3 are answered with straight-line var() comparisons, no loop
+        switch (m_size) {
+        case 0: return false;
+        case 1: return m_lits[0].var() == v;
+        case 2: return m_lits[0].var() == v || m_lits[1].var() == v;
+        case 3: return m_lits[0].var() == v || m_lits[1].var() == v || m_lits[2].var() == v;
+        default:
+            // Size >= 4: four literals per iteration; the second loop covers the 0-3 left over
+            unsigned i = 0;
+            for (; i + 3 < m_size; i += 4) {
+                if (m_lits[i].var() == v || m_lits[i+1].var() == v ||
+                    m_lits[i+2].var() == v || m_lits[i+3].var() == v)
+                    return true;
+            }
+            for (; i < m_size; i++) {
+                if (m_lits[i].var() == v)
+                    return true;
+            }
+            return false;
+        }
     }
 
     void clause::elim(literal l) {
@@ -101,17
+143,41 @@ namespace sat {
     }
 
     bool clause::satisfied_by(model const & m) const {
-        for (literal l : *this) {
-            if (l.sign()) {
-                if (m[l.var()] == l_false)
+        // Read-prefetch hint for m_lits; compiled only where __GNUC__ is defined
+        #ifdef __GNUC__
+        if (m_size > 0) __builtin_prefetch(m_lits, 0, 3);
+        #endif
+
+        // Sizes 0-2 are answered directly; size 3 and up take the loops under default
+        switch (m_size) {
+        case 0: return false;
+        case 1: {
+            literal l = m_lits[0];
+            return l.sign() ? (m[l.var()] == l_false) : (m[l.var()] == l_true);
+        }
+        case 2: {
+            literal l0 = m_lits[0], l1 = m_lits[1];
+            return (l0.sign() ? (m[l0.var()] == l_false) : (m[l0.var()] == l_true)) ||
+                   (l1.sign() ? (m[l1.var()] == l_false) : (m[l1.var()] == l_true));
+        }
+        default:
+            // Four literals per iteration; the second loop handles the 0-3 left over
+            unsigned i = 0;
+            for (; i + 3 < m_size; i += 4) {
+                literal l0 = m_lits[i], l1 = m_lits[i+1], l2 = m_lits[i+2], l3 = m_lits[i+3];
+                if ((l0.sign() ? (m[l0.var()] == l_false) : (m[l0.var()] == l_true)) ||
+                    (l1.sign() ? (m[l1.var()] == l_false) : (m[l1.var()] == l_true)) ||
+                    (l2.sign() ? (m[l2.var()] == l_false) : (m[l2.var()] == l_true)) ||
+                    (l3.sign() ? (m[l3.var()] == l_false) : (m[l3.var()] == l_true)))
                     return true;
             }
-            else {
-                if (m[l.var()] == l_true)
+            for (; i < m_size; i++) {
+                literal l = m_lits[i];
+                if (l.sign() ? (m[l.var()] == l_false) : (m[l.var()] == l_true))
                     return true;
             }
+            return false;
         }
-        return false;
     }
 
     clause_offset clause::get_new_offset() const {
@@ -224,19 +290,19 @@ namespace sat {
     }
 
     bool clause_wrapper::contains(literal l) const {
-        unsigned sz = size();
-        for (unsigned i = 0; i < sz; i++)
-            if (operator[](i) == l)
-                return true;
-        return false;
+        if (is_binary()) {  // binary wrapper: compare its two literals directly
+            return operator[](0) == l || operator[](1) == l;
+        } else {  // otherwise defer to clause::contains(literal)
+            return get_clause()->contains(l);
+        }
     }
 
     bool clause_wrapper::contains(bool_var v) const {
-        unsigned sz = size();
-        for (unsigned i = 0; i < sz; i++)
-            if (operator[](i).var() == v)
-                return true;
-        return false;
+        if (is_binary()) {  // binary wrapper: compare the variables of its two literals
+            return operator[](0).var() == v || operator[](1).var() == v;
+        } else {  // otherwise defer to clause::contains(bool_var)
+            return get_clause()->contains(v);
+        }
     }
 
     std::ostream & operator<<(std::ostream & out, clause_wrapper const & c) {
diff --git a/src/sat/sat_clause.h b/src/sat/sat_clause.h
index 0129febbf8e..e16089526af 100644
--- a/src/sat/sat_clause.h
+++ b/src/sat/sat_clause.h
@@ -37,7 +37,7 @@ namespace sat {
 
     std::ostream & operator<<(std::ostream & out, clause const & c);
 
-    class clause {
+    class alignas(16) clause {
         friend class clause_allocator;
         friend class tmp_clause;
         unsigned m_id;
@@ -103,6 +103,12 @@ namespace sat {
 
         bool on_reinit_stack() const { return m_reinit_stack; }
         void set_reinit_stack(bool f) { m_reinit_stack = f; }
+
+        // Alignment diagnostics. NOTE(review): confirm clause_allocator hands out storage aligned to alignof(clause)
+        static unsigned get_cache_line_size() { return 64; } // Assumed typical value, not queried from hardware
+        bool is_cache_aligned() const {  // true when this object sits on its declared alignas boundary
+            return reinterpret_cast<size_t>(this) % alignof(clause) == 0;
+        }
     };
 
     std::ostream & operator<<(std::ostream & out, clause_vector const & cs);