Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 183 additions & 0 deletions clause_perf_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
#include <iostream>
#include <chrono>
#include <random>
#include <vector>
#include "src/sat/sat_clause.h"
#include "src/sat/sat_types.h"

using namespace sat;

class clause_performance_test {
private:
clause_allocator allocator;
std::vector<clause*> test_clauses;
std::vector<literal> test_literals;
std::mt19937 rng;

// Performance counters
unsigned long contains_literal_calls = 0;
unsigned long contains_var_calls = 0;
unsigned long satisfied_calls = 0;

public:
clause_performance_test() : rng(42) {
// Generate test literals
for (unsigned v = 0; v < 1000; v++) {
test_literals.push_back(literal(v, false));
test_literals.push_back(literal(v, true));
}

// Generate test clauses of various sizes
create_test_clauses();
}

~clause_performance_test() {
for (clause* c : test_clauses) {
allocator.del_clause(c);
}
}

void create_test_clauses() {
std::uniform_int_distribution<unsigned> size_dist(2, 20);
std::uniform_int_distribution<unsigned> lit_dist(0, test_literals.size() - 1);

// Create 10000 test clauses
for (int i = 0; i < 10000; i++) {
unsigned size = size_dist(rng);
std::vector<literal> clause_lits;

for (unsigned j = 0; j < size; j++) {
clause_lits.push_back(test_literals[lit_dist(rng)]);
}

clause* c = allocator.mk_clause(size, clause_lits.data(), i % 2 == 0);
test_clauses.push_back(c);
}

std::cout << "Created " << test_clauses.size() << " test clauses" << std::endl;

// Check cache alignment effectiveness
unsigned aligned_count = 0;
for (clause* c : test_clauses) {
if (c->is_cache_aligned()) aligned_count++;
}
std::cout << "Cache-aligned clauses: " << aligned_count << "/" << test_clauses.size()
<< " (" << (100.0 * aligned_count / test_clauses.size()) << "%)" << std::endl;
}

void benchmark_contains_literal() {
std::uniform_int_distribution<unsigned> clause_dist(0, test_clauses.size() - 1);
std::uniform_int_distribution<unsigned> lit_dist(0, test_literals.size() - 1);

const int iterations = 1000000;
unsigned hits = 0;

auto start = std::chrono::high_resolution_clock::now();

for (int i = 0; i < iterations; i++) {
clause* c = test_clauses[clause_dist(rng)];
literal l = test_literals[lit_dist(rng)];
if (c->contains(l)) hits++;
contains_literal_calls++;
}

auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);

std::cout << "contains(literal) benchmark:" << std::endl;
std::cout << " Iterations: " << iterations << std::endl;
std::cout << " Hits: " << hits << " (" << (100.0 * hits / iterations) << "%)" << std::endl;
std::cout << " Total time: " << duration.count() / 1e6 << " ms" << std::endl;
std::cout << " Average time per call: " << (double)duration.count() / iterations << " ns" << std::endl;
}

void benchmark_contains_var() {
std::uniform_int_distribution<unsigned> clause_dist(0, test_clauses.size() - 1);
std::uniform_int_distribution<unsigned> var_dist(0, 999);

const int iterations = 1000000;
unsigned hits = 0;

auto start = std::chrono::high_resolution_clock::now();

for (int i = 0; i < iterations; i++) {
clause* c = test_clauses[clause_dist(rng)];
bool_var v = var_dist(rng);
if (c->contains(v)) hits++;
contains_var_calls++;
}

auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);

std::cout << "contains(bool_var) benchmark:" << std::endl;
std::cout << " Iterations: " << iterations << std::endl;
std::cout << " Hits: " << hits << " (" << (100.0 * hits / iterations) << "%)" << std::endl;
std::cout << " Total time: " << duration.count() / 1e6 << " ms" << std::endl;
std::cout << " Average time per call: " << (double)duration.count() / iterations << " ns" << std::endl;
}

void benchmark_satisfied_by() {
// Create a simple model
model m(1000);
std::uniform_int_distribution<int> bool_dist(0, 2); // 0=false, 1=true, 2=undef

for (unsigned v = 0; v < 1000; v++) {
int val = bool_dist(rng);
m[v] = (val == 0) ? l_false : (val == 1) ? l_true : l_undef;
}

std::uniform_int_distribution<unsigned> clause_dist(0, test_clauses.size() - 1);

const int iterations = 500000;
unsigned satisfied = 0;

auto start = std::chrono::high_resolution_clock::now();

for (int i = 0; i < iterations; i++) {
clause* c = test_clauses[clause_dist(rng)];
if (c->satisfied_by(m)) satisfied++;
satisfied_calls++;
}

auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);

std::cout << "satisfied_by(model) benchmark:" << std::endl;
std::cout << " Iterations: " << iterations << std::endl;
std::cout << " Satisfied: " << satisfied << " (" << (100.0 * satisfied / iterations) << "%)" << std::endl;
std::cout << " Total time: " << duration.count() / 1e6 << " ms" << std::endl;
std::cout << " Average time per call: " << (double)duration.count() / iterations << " ns" << std::endl;
}

void run_all_benchmarks() {
std::cout << "=== Z3 Clause Management Performance Test ===" << std::endl;
std::cout << "Testing cache-friendly optimizations" << std::endl;
std::cout << std::endl;

benchmark_contains_literal();
std::cout << std::endl;

benchmark_contains_var();
std::cout << std::endl;

benchmark_satisfied_by();
std::cout << std::endl;

// Memory usage info
std::cout << "Memory usage:" << std::endl;
std::cout << " Clause allocator size: " << allocator.get_allocation_size() << " bytes" << std::endl;
std::cout << " Average clause size: " << (double)allocator.get_allocation_size() / test_clauses.size() << " bytes" << std::endl;
}
};

int main() {
try {
clause_performance_test test;
test.run_all_benchmarks();
return 0;
} catch (const std::exception& e) {
std::cerr << "Error: " << e.what() << std::endl;
return 1;
}
}
114 changes: 90 additions & 24 deletions src/sat/sat_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,59 @@ namespace sat {
}

/**
 * Reports whether literal l occurs in this clause.
 *
 * Scans the clause's literals front to back and stops at the first match.
 * A clause with no literals yields false.
 */
bool clause::contains(literal l) const {
    for (literal candidate : *this) {
        if (candidate == l)
            return true;
    }
    return false;
}

/**
 * Reports whether some literal of this clause is over variable v,
 * regardless of that literal's sign.
 *
 * Scans the clause's literals front to back and stops at the first match.
 * A clause with no literals yields false.
 */
bool clause::contains(bool_var v) const {
    for (literal candidate : *this) {
        if (candidate.var() == v)
            return true;
    }
    return false;
}

void clause::elim(literal l) {
Expand Down Expand Up @@ -101,17 +143,41 @@ namespace sat {
}

/**
 * Reports whether model m makes at least one literal of this clause true.
 *
 * A literal with sign() set counts as true when its variable is l_false in
 * m; a literal without sign() counts as true when its variable is l_true.
 * Any other model value for the variable does not satisfy the literal.
 * A clause with no literals yields false.
 */
bool clause::satisfied_by(model const & m) const {
    for (literal l : *this) {
        if (l.sign() ? (m[l.var()] == l_false) : (m[l.var()] == l_true))
            return true;
    }
    return false;
}

clause_offset clause::get_new_offset() const {
Expand Down Expand Up @@ -224,19 +290,19 @@ namespace sat {
}

/**
 * Reports whether literal l is one of the size() literals exposed by this
 * wrapper through operator[].
 */
bool clause_wrapper::contains(literal l) const {
    unsigned const count = size();
    for (unsigned idx = 0; idx < count; idx++) {
        if ((*this)[idx] == l)
            return true;
    }
    return false;
}

/**
 * Reports whether one of the size() literals exposed by this wrapper
 * through operator[] is over variable v, regardless of sign.
 */
bool clause_wrapper::contains(bool_var v) const {
    unsigned const count = size();
    for (unsigned idx = 0; idx < count; idx++) {
        if ((*this)[idx].var() == v)
            return true;
    }
    return false;
}

std::ostream & operator<<(std::ostream & out, clause_wrapper const & c) {
Expand Down
8 changes: 7 additions & 1 deletion src/sat/sat_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace sat {

std::ostream & operator<<(std::ostream & out, clause const & c);

class clause {
class alignas(16) clause {
friend class clause_allocator;
friend class tmp_clause;
unsigned m_id;
Expand Down Expand Up @@ -103,6 +103,12 @@ namespace sat {

// Returns the current value of the m_reinit_stack flag.
bool on_reinit_stack() const {
    return m_reinit_stack;
}
// Stores f as the new value of the m_reinit_stack flag.
void set_reinit_stack(bool f) {
    m_reinit_stack = f;
}

// Performance monitoring for cache-friendly optimizations
// Nominal cache-line width in bytes. This is a fixed constant (a typical
// value), not something queried from the running hardware.
static unsigned get_cache_line_size() {
    return 64;
}
// Reports whether this object's address is a multiple of 16 bytes.
// NOTE(review): despite the name, the check is against 16, not the 64
// returned by get_cache_line_size() -- confirm which boundary is intended.
bool is_cache_aligned() const {
    const uintptr_t address = reinterpret_cast<uintptr_t>(this);
    return (address % 16) == 0;
}
};

std::ostream & operator<<(std::ostream & out, clause_vector const & cs);
Expand Down
Loading