diff --git a/.gitignore b/.gitignore index d59c8e9c6..dab930a6a 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ aries_fzn/share/aries_fzn __pycache__/ *.profraw lcov.info +profile.json.gz diff --git a/solver/src/collections/ref_store.rs b/solver/src/collections/ref_store.rs index 9e0fe0534..2f32c4790 100644 --- a/solver/src/collections/ref_store.rs +++ b/solver/src/collections/ref_store.rs @@ -399,6 +399,9 @@ impl Default for RefMap { impl RefMap { pub fn insert(&mut self, k: K, v: V) { let index = k.into(); + if index > self.entries.len() { + self.entries.reserve(index - self.entries.len()); + } while self.entries.len() <= index { self.entries.push(None); } @@ -454,6 +457,7 @@ impl RefMap { self.entries[index].as_mut() } } + pub fn get_or_insert(&mut self, k: K, default: impl FnOnce() -> V) -> &V { if !self.contains(k) { self.insert(k, default()) @@ -562,6 +566,11 @@ impl IterableRefMap { self.map.insert(k, v) } + pub fn remove(&mut self, k: K) { + self.map.remove(k); + self.keys.retain(|e| *e != k); + } + /// Removes all elements from the Map. #[inline(never)] pub fn clear(&mut self) { diff --git a/solver/src/collections/set.rs b/solver/src/collections/set.rs index c69fd90d8..6b12dc8e8 100644 --- a/solver/src/collections/set.rs +++ b/solver/src/collections/set.rs @@ -61,6 +61,16 @@ impl Default for RefSet { } } +impl FromIterator for IterableRefSet { + fn from_iter>(iter: T) -> Self { + let mut set = Self::new(); + for i in iter { + set.insert(i); + } + set + } +} + /// A set of values that can be converted into small unsigned integers. /// This extends `RefSet` with a vector of all elements of the set, allowing for fast iteration /// and clearing. @@ -89,6 +99,10 @@ impl IterableRefSet { self.set.insert(k, ()); } + pub fn remove(&mut self, k: K) { + self.set.remove(k); + } + pub fn clear(&mut self) { self.set.clear() } diff --git a/solver/src/core/state/domain.rs b/solver/src/core/state/domain.rs index 27d79721f..1a486b353 100644 --- a/solver/src/core/state/domain.rs +++ b/solver/src/core/state/domain.rs @@ -1,5 +1,5 @@ use crate::{ - core::{cst_int_to_long, IntCst, LongCst, INT_CST_MAX, INT_CST_MIN}, + core::{cst_int_to_long, IntCst, Lit, LongCst, INT_CST_MAX, INT_CST_MIN}, model::lang::Rational, }; use std::fmt::{Display, Formatter}; @@ -53,6 +53,11 @@ impl IntDomain { pub fn disjoint(&self, other: &IntDomain) -> bool { self.ub < other.lb || other.ub < self.lb } + + pub fn entails(&self, literal: Lit) -> bool { + literal.svar().is_plus() && literal.variable().leq(self.ub).entails(literal) + || literal.svar().is_minus() && literal.variable().geq(self.lb).entails(literal) + } } impl std::ops::Mul for IntDomain { diff --git a/solver/src/core/state/snapshot.rs b/solver/src/core/state/snapshot.rs index 0986878c9..859fdde22 100644 --- a/solver/src/core/state/snapshot.rs +++ b/solver/src/core/state/snapshot.rs @@ -2,6 +2,8 @@ use crate::backtrack::{DecLvl, EventIndex}; use crate::core::state::{Domains, Event, Term}; use crate::core::{IntCst, Lit, SignedVar}; +use super::IntDomain; + /// View of the domains at a given point in time. /// /// This is primarily intended to query the state as it was when a literal was inferred. @@ -60,6 +62,11 @@ impl<'a> DomainsSnapshot<'a> { -self.ub(-var.into()) } + pub fn int_domain(&self, var: impl Into) -> IntDomain { + let (lb, ub) = self.bounds(var.into()); + IntDomain::new(lb, ub) + } + pub fn bounds(&self, var: impl Into) -> (IntCst, IntCst) { let var = var.into(); (self.lb(var), self.ub(var)) diff --git a/solver/src/reasoners/eq_alt/constraints.rs b/solver/src/reasoners/eq_alt/constraints.rs new file mode 100644 index 000000000..cff9b5ab2 --- /dev/null +++ b/solver/src/reasoners/eq_alt/constraints.rs @@ -0,0 +1,151 @@ +use hashbrown::HashMap; +use std::fmt::Debug; + +use crate::{ + collections::ref_store::RefVec, + core::{literals::Watches, Lit}, + create_ref_type, +}; + +use super::{node::Node, relation::EqRelation}; + +// TODO: Identical to STN, maybe identify some other common logic (such as identity) and bump up to reasoner module + +/// Enabling information for a propagator. +/// A propagator should be enabled iff both literals `active` and `valid` are true. +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +pub struct Enabler { + /// A literal that is true (but not necessarily present) when the propagator must be active if present + pub active: Lit, + /// A literal that is true when the propagator is within its validity scope, i.e., + /// when is known to be sound to propagate a change from the source to the target. + /// + /// In the simplest case, we have `valid = presence(active)` since by construction + /// `presence(active)` is true iff both variables of the constraint are present. + /// + /// `valid` might a more specific literal but always with the constraints that + /// `presence(active) => valid` + pub valid: Lit, +} + +impl Enabler { + pub fn new(active: Lit, valid: Lit) -> Enabler { + Enabler { active, valid } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct ActivationEvent { + /// the edge to enable + pub constraint_id: ConstraintId, +} + +impl ActivationEvent { + pub fn new(constraint_id: ConstraintId) -> Self { + Self { constraint_id } + } +} + +create_ref_type!(ConstraintId); + +impl Debug for ConstraintId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Propagator {}", self.to_u32()) + } +} + +/// One direction of a semi-reified eq or neq constraint. +/// +/// Formally enabler.active => a (relation) b +/// with enabler.valid = presence(b) => presence(a) +#[derive(Clone, Hash, Debug, PartialEq, Eq)] +pub struct Constraint { + pub a: Node, + pub b: Node, + pub relation: EqRelation, + pub enabler: Enabler, +} + +impl Constraint { + pub fn new(a: Node, b: Node, relation: EqRelation, active: Lit, valid: Lit) -> Self { + Self { + a, + b, + relation, + enabler: Enabler::new(active, valid), + } + } + + pub fn new_pair(a: Node, b: Node, relation: EqRelation, active: Lit, ab_valid: Lit, ba_valid: Lit) -> (Self, Self) { + ( + Self::new(a, b, relation, active, ab_valid), + Self::new(b, a, relation, active, ba_valid), + ) + } +} + +// #[derive(Debug, Clone, Copy)] +// enum Event { +// PropagatorAdded, +// WatchAdded(ConstraintId, Lit), +// } + +/// Data structures to store propagators. +#[derive(Clone, Default)] +pub struct ConstraintStore { + constraints: RefVec, + // constraint_lookup: HashMap<(Node, Node), Vec>, + in_constraints: HashMap>, + out_constraints: HashMap>, + watches: Watches<(Enabler, ConstraintId)>, + // trail: Trail, +} + +impl ConstraintStore { + pub fn add_constraint(&mut self, constraint: Constraint) -> ConstraintId { + // assert_eq!(self.current_decision_level(), DecLvl::ROOT); + // self.trail.push(Event::PropagatorAdded); + let id = self.constraints.len().into(); + self.constraints.push(constraint.clone()); + self.out_constraints + .entry(constraint.a) + .and_modify(|v| v.push(id)) + .or_insert(vec![id]); + self.in_constraints + .entry(constraint.b) + .and_modify(|v| v.push(id)) + .or_insert(vec![id]); + id + } + + pub fn add_watch(&mut self, id: ConstraintId, literal: Lit) { + let enabler = self.constraints[id].enabler; + self.watches.add_watch((enabler, id), literal); + // self.trail.push(Event::WatchAdded(id, literal)); + } + + pub fn get_constraint(&self, constraint_id: ConstraintId) -> &Constraint { + &self.constraints[constraint_id] + } + + // Get valid propagators by source and target + // pub fn get_constraints_between(&self, source: Node, target: Node) -> Vec { + // self.constraint_lookup.get(&(source, target)).cloned().unwrap_or(vec![]) + // } + + pub fn get_out_constraints(&self, source: Node) -> Vec { + self.out_constraints.get(&source).cloned().unwrap_or_default() + } + + pub fn get_in_constraints(&self, source: Node) -> Vec { + self.in_constraints.get(&source).cloned().unwrap_or_default() + } + + pub fn enabled_by(&self, literal: Lit) -> impl Iterator + '_ { + self.watches.watches_on(literal) + } + + pub fn iter(&self) -> impl Iterator + use<'_> { + self.constraints.entries() + } +} diff --git a/solver/src/reasoners/eq_alt/graph/adj_list.rs b/solver/src/reasoners/eq_alt/graph/adj_list.rs new file mode 100644 index 000000000..ae0941a8c --- /dev/null +++ b/solver/src/reasoners/eq_alt/graph/adj_list.rs @@ -0,0 +1,68 @@ +use std::fmt::{Debug, Formatter}; + +use crate::collections::ref_store::IterableRefMap; + +use super::{Edge, NodeId}; + +#[derive(Default, Clone)] +pub struct EqAdjList(IterableRefMap>); + +impl Debug for EqAdjList { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + writeln!(f)?; + for (node, edges) in self.0.entries() { + if !edges.is_empty() { + writeln!(f, "{:?}:", node)?; + for edge in edges { + writeln!(f, " -> {:?} {:?}", edge.target, edge)?; + } + } + } + Ok(()) + } +} + +impl EqAdjList { + /// Insert a node if not present + fn insert_node(&mut self, node: NodeId) { + if !self.0.contains(node) { + self.0.insert(node, Default::default()); + } + } + + /// Possibly insert an edge and both nodes + /// Returns true if edge was inserted + pub fn insert_edge(&mut self, edge: Edge) -> bool { + self.insert_node(edge.source); + self.insert_node(edge.target); + let edges = self.get_edges_mut(edge.source).unwrap(); + if edges.contains(&edge) { + false + } else { + edges.push(edge); + true + } + } + + pub fn iter_edges(&self, node: NodeId) -> impl Iterator { + self.0.get(node).into_iter().flat_map(|v| v.iter()) + } + + pub fn get_edges_mut(&mut self, node: NodeId) -> Option<&mut Vec> { + self.0.get_mut(node) + } + + pub fn iter_all_edges(&self) -> impl Iterator + use<'_> { + self.0.entries().flat_map(|(_, e)| e.iter().cloned()) + } + + pub fn iter_nodes(&self) -> impl Iterator + use<'_> { + self.0.entries().map(|(n, _)| n) + } + + pub fn remove_edge(&mut self, edge: Edge) { + if let Some(set) = self.0.get_mut(edge.source) { + set.retain(|e| *e != edge) + } + } +} diff --git a/solver/src/reasoners/eq_alt/graph/mod.rs b/solver/src/reasoners/eq_alt/graph/mod.rs new file mode 100644 index 000000000..66029a8c8 --- /dev/null +++ b/solver/src/reasoners/eq_alt/graph/mod.rs @@ -0,0 +1,827 @@ +/// This module exports an adjacency list graph of the active constraints, +/// methods to transform and traverse it, and the method paths_requiring(edge). +use std::array; +use std::cell::{RefCell, RefMut}; +use std::fmt::{Debug, Display}; +use std::hash::Hash; + +use hashbrown::HashSet; +use itertools::Itertools; +use node_store::NodeStore; +use transforms::{EqExt, EqNeqExt, EqNode, FilterExt}; +use traversal::{Edge as _, Graph, Scratch}; + +use crate::backtrack::{Backtrack, DecLvl, Trail}; +use crate::core::Lit; +use crate::create_ref_type; +use crate::reasoners::eq_alt::graph::adj_list::EqAdjList; + +use super::constraints::Constraint; +use super::node::Node; +use super::relation::EqRelation; + +mod adj_list; +mod node_store; +pub mod transforms; +pub mod traversal; + +create_ref_type!(NodeId); +pub use node_store::GroupId; + +impl Display for NodeId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Debug::fmt(&self, f) + } +} + +/// A directed edge between two nodes (identified by ids) +/// with an associated relation and activity literal. +#[derive(PartialEq, Eq, Copy, Clone, Debug, Hash)] +pub struct Edge { + pub source: NodeId, + pub target: NodeId, + pub active: Lit, + pub relation: EqRelation, +} + +impl Edge { + fn new(source: NodeId, target: NodeId, active: Lit, relation: EqRelation) -> Self { + Self { + source, + target, + active, + relation, + } + } + + /// Swaps source and target. Useful to convert from outgoing-graph edge and incoming-graph edge. + fn reverse(&self) -> Self { + Edge { + source: self.target, + target: self.source, + ..*self + } + } +} + +/// A backtrackable event affecting the graph. +#[derive(Clone)] +enum Event { + EdgeAdded(Edge), + GroupEdgeAdded(Edge), + GroupEdgeRemoved(Edge), +} + +type DfsScratch = Scratch>; +thread_local! { + /// A reusable bit of memory to be used by graph traversal. + static SCRATCHES: [RefCell; 4] = array::from_fn(|_| Default::default()); +} + +/// Run f with any number of scratches (max determined by SCRATCHES variables) +/// Array destructuring syntax allows you to specify the number and get multiple as mut +pub fn with_scratches(f: F) -> R +where + F: FnOnce([RefMut<'_, DfsScratch>; N]) -> R, +{ + SCRATCHES.with(|cells| { + f(cells[0..N] + .iter() + .map(|cell| cell.borrow_mut()) + .collect_array() + .unwrap()) + }) +} + +/// An adjacency list representation of a directed "equality graph" +/// where each edge has an eq/neq relation and an activity literal. +/// +/// 4 adjacency lists are stored in memory: +/// - Outgoing (forward) +/// - Incoming (reverse/backward) +/// - Grouped outgoing (SCCs of equal nodes are merged into one) +/// - Grouped incoming +/// +/// Notable methods include path_requiring(edge) which is useful for propagation. +/// +/// It is also possible to transform and traverse the graph with +/// `graph.outgoing_grouped.eq_neq().filter(...).traverse(source, Default::default()).find(...)` for example. +#[derive(Clone, Default)] +pub(super) struct DirectedEqualityGraph { + pub node_store: NodeStore, + // These are pub to allow graph traversal API at theory level + pub outgoing: EqAdjList, + pub incoming: EqAdjList, + pub outgoing_grouped: EqAdjList, + pub incoming_grouped: EqAdjList, + trail: Trail, +} + +impl DirectedEqualityGraph { + pub fn new() -> Self { + Default::default() + } + + /// Add node to graph if not present. Returns the id of the Node. + pub fn insert_node(&mut self, node: Node) -> NodeId { + self.node_store + .get_id(&node) + .unwrap_or_else(|| self.node_store.insert_node(node)) + } + + /// Get node from id. + /// + /// # Panics + /// + /// Panics if node with `id` is not in graph. + pub fn get_node(&self, id: NodeId) -> Node { + self.node_store.get_node(id) + } + + pub fn get_id(&self, node: &Node) -> Option { + self.node_store.get_id(node) + } + + pub fn get_group_id(&self, id: NodeId) -> GroupId { + self.node_store.get_group_id(id) + } + + #[allow(unused)] + pub fn get_group(&self, id: GroupId) -> Vec { + self.node_store.get_group(id) + } + + pub fn get_group_nodes(&self, id: GroupId) -> Vec { + self.node_store.get_group_nodes(id) + } + + /// Merge together two nodes when they are determined to belong to the same Eq SCC. + pub fn merge(&mut self, ids: (NodeId, NodeId)) { + let child = self.get_group_id(ids.0); + let parent = self.get_group_id(ids.1); + + // Merge NodeIds + self.node_store.merge(child, parent); + + // For each edge that goes out of the child group + for edge in self.outgoing_grouped.iter_edges(child.into()).cloned().collect_vec() { + self.trail.push(Event::GroupEdgeRemoved(edge)); + + // Remove it from both adjacency lists + self.outgoing_grouped.remove_edge(edge); + self.incoming_grouped.remove_edge(edge.reverse()); + + // Modify it to have the parent group as a source + let new_edge = Edge { + source: parent.into(), + ..edge + }; + // Avoid adding edges from a group into the same group + if new_edge.source == new_edge.target { + continue; + } + + // Possibly insert it back in + let added = self.outgoing_grouped.insert_edge(new_edge); + assert_eq!(added, self.incoming_grouped.insert_edge(new_edge.reverse())); + if added { + self.trail.push(Event::GroupEdgeAdded(new_edge)); + } + } + + // Same for incoming edges + for edge in self.incoming_grouped.iter_edges(child.into()).cloned().collect_vec() { + let edge = edge.reverse(); + self.trail.push(Event::GroupEdgeRemoved(edge)); + self.outgoing_grouped.remove_edge(edge); + self.incoming_grouped.remove_edge(edge.reverse()); + + let new_edge = Edge { + target: parent.into(), + ..edge + }; + if new_edge.source == new_edge.target { + continue; + } + + let added = self.outgoing_grouped.insert_edge(new_edge); + assert_eq!(added, self.incoming_grouped.insert_edge(new_edge.reverse())); + if added { + self.trail.push(Event::GroupEdgeAdded(new_edge)); + } + } + } + + /// Returns an edge from a propagator without adding it to the graph. + /// + /// Adds the nodes to the graph if they are not present. + pub fn create_edge(&mut self, prop: &Constraint) -> Edge { + let source_id = self.insert_node(prop.a); + let target_id = self.insert_node(prop.b); + Edge::new(source_id, target_id, prop.enabler.active, prop.relation) + } + + /// Adds an edge to the graph. + pub fn add_edge(&mut self, edge: Edge) { + self.trail.push(Event::EdgeAdded(edge)); + self.outgoing.insert_edge(edge); + self.incoming.insert_edge(edge.reverse()); + let grouped_edge = Edge { + source: self.get_group_id(edge.source).into(), + target: self.get_group_id(edge.target).into(), + ..edge + }; + self.trail.push(Event::GroupEdgeAdded(grouped_edge)); + self.outgoing_grouped.insert_edge(grouped_edge); + self.incoming_grouped.insert_edge(grouped_edge.reverse()); + } + + pub fn iter_nodes(&self) -> impl Iterator + use<'_> { + self.outgoing.iter_nodes().map(|id| self.node_store.get_node(id)) + } + + /// Get all paths which would require the given edge to exist. + /// + /// For an edge x -==-> y, returns a vec of all pairs (w, z) such that w -=-> z or w -!=-> z in G union x -=-> y, but not in G. + /// + /// For an edge x -!=-> y, returns a vec of all pairs (w, z) such that w -!=> z in G union x -!=-> y, but not in G. + /// propagator nodes must already be added + pub fn paths_requiring(&self, edge: Edge) -> Vec { + // Convert edge to edge between groups + let edge = Edge { + source: self.node_store.get_group_id(edge.source).into(), + target: self.node_store.get_group_id(edge.target).into(), + ..edge + }; + + match edge.relation { + EqRelation::Eq => self.paths_requiring_eq(edge), + EqRelation::Neq => self.paths_requiring_neq(edge), + } + } + + fn paths_requiring_eq(&self, edge: Edge) -> Vec { + debug_assert_eq!(edge.relation, EqRelation::Eq); + + with_scratches(|[mut s1, mut s2, mut s3, mut s4]| { + // Traverse backwards from target to find reachable predecessors + let mut t = self + .incoming_grouped + .eq_neq() + .traverse_dfs(EqNode::new(edge.target), &mut s1); + // If there is already a path from source to target, no paths are created + if t.any(|n| n == EqNode(edge.source, EqRelation::Eq)) { + return Vec::new(); + } + + let reachable_preds = t.visited(); + + // Do the same for reachable successors + let reachable_succs = self + .outgoing_grouped + .eq_neq() + .reachable(EqNode::new(edge.source), &mut s2); + debug_assert!(!reachable_succs.contains(EqNode::new(edge.target))); + + // Traverse backwards from the source excluding nodes which can reach the target + let predecessors = self + .incoming_grouped + .eq_neq() + .filter(|_, e| !reachable_preds.contains(e.target())) + .traverse_dfs(EqNode::new(edge.source), &mut s3); + + // Traverse forward from the target excluding nodes which can be reached by the source + let successors = self + .outgoing_grouped + .eq_neq() + .filter(|_, e| !reachable_succs.contains(e.target())) + .traverse_dfs(EqNode::new(edge.target), &mut s4) + .collect_vec(); + + // A cartesian product between predecessors which cannot reach the target and successors which cannot be reached by source + // is equivalent to the set of paths which require the addition of this edge to exist. + predecessors + .into_iter() + .cartesian_product(successors) + .filter_map(|(source, target)| { + // pred id and succ id are GroupIds since all above graph traversals are on MergedGraphs + source.path_to(&target) + }) + .collect_vec() + }) + } + + fn paths_requiring_neq(&self, edge: Edge) -> Vec { + debug_assert_eq!(edge.relation, EqRelation::Neq); + + // Same principle as Eq, but the logic is a little more complicated + // We want to exclude predecessors reachable with Eq and successors reachable with Neq first + // then the opposite + with_scratches(|[mut s1, mut s2, mut s3, mut s4]| { + // Reachable sets + let mut t = self + .incoming_grouped + .eq_neq() + .traverse_dfs(EqNode::new(edge.target), &mut s1); + if t.any(|n| n == EqNode(edge.source, EqRelation::Neq)) { + return Vec::new(); + } + let reachable_preds = t.visited(); + + let reachable_succs = self + .outgoing_grouped + .eq_neq() + .reachable(EqNode::new(edge.source), &mut s2); + + let neq_filtered_successors = self + .outgoing_grouped + .eq() + .filter(|_, e| !reachable_succs.contains(EqNode(e.target(), EqRelation::Neq))) + .traverse_dfs(edge.target, &mut s3) + .collect_vec(); + + let eq_filtered_successors = self + .outgoing_grouped + .eq() + .filter(|_, e| !reachable_succs.contains(EqNode(e.target(), EqRelation::Eq))) + .traverse_dfs(edge.target, &mut s3) + .collect_vec(); + + let eq_filtered_predecessors = self + .incoming_grouped + .eq() + .filter(|_, e| !reachable_preds.contains(EqNode(e.target(), EqRelation::Eq))) + .traverse_dfs(edge.source, &mut s3); + + let neq_filtered_predecessors = self + .incoming_grouped + .eq() + .filter(|_, e| !reachable_preds.contains(EqNode(e.target(), EqRelation::Neq))) + .traverse_dfs(edge.source, &mut s4); + + let create_path = + |(source, target): (NodeId, NodeId)| -> Path { Path::new(source, target, EqRelation::Neq) }; + + neq_filtered_predecessors + .cartesian_product(eq_filtered_successors) + .map(create_path) + .skip(1) + .chain( + eq_filtered_predecessors + .cartesian_product(neq_filtered_successors) + .map(create_path), + ) + .collect() + }) + } + + #[allow(unused)] + pub fn to_graphviz(&self) -> String { + let mut strings = vec!["Ungrouped: ".to_string(), "digraph {".to_string()]; + for e in self.outgoing.iter_all_edges() { + strings.push(format!( + " {} -> {} [label=\"{} ({:?})\"]", + e.source.to_u32(), + e.target.to_u32(), + e.relation, + e.active + )); + } + strings.push("}".to_string()); + strings.join("\n") + } + + #[allow(unused)] + pub fn to_graphviz_grouped(&self) -> String { + let mut strings = vec!["Grouped: ".to_string(), "digraph {".to_string()]; + for e in self.outgoing_grouped.iter_all_edges() { + strings.push(format!( + " {} -> {} [label=\"{} ({:?})\"]", + e.source.to_u32(), + e.target.to_u32(), + e.relation, + e.active + )); + } + strings.push("}".to_string()); + strings.join("\n") + } + + #[allow(unused)] + pub fn print_merge_statistics(&self) { + println!("Total nodes: {}", self.node_store.len()); + println!("Total groups: {}", self.node_store.count_groups()); + println!("Outgoing edges: {}", self.outgoing.iter_all_edges().count()); + println!( + "Outgoing_grouped edges: {}", + self.outgoing_grouped.iter_all_edges().count() + ); + } + + /// Check that nodes that are not group representatives are not group reps + #[allow(unused)] + pub fn verify_grouping(&self) { + let groups = self.node_store.groups().into_iter().collect::>(); + for node in self.node_store.nodes() { + if groups.contains(&GroupId::from(node)) { + continue; + } + assert!(self.outgoing_grouped.iter_edges(node).all(|_| false)); + assert!(self.incoming_grouped.iter_edges(node).all(|_| false)); + } + } +} + +impl Backtrack for DirectedEqualityGraph { + fn save_state(&mut self) -> DecLvl { + self.node_store.save_state(); + self.trail.save_state() + } + + fn num_saved(&self) -> u32 { + self.trail.num_saved() + } + + fn restore_last(&mut self) { + self.node_store.restore_last(); + self.trail.restore_last_with(|event| match event { + Event::EdgeAdded(edge) => { + self.outgoing.remove_edge(edge); + self.incoming.remove_edge(edge.reverse()); + } + Event::GroupEdgeAdded(edge) => { + self.outgoing_grouped.remove_edge(edge); + self.incoming_grouped.remove_edge(edge.reverse()); + } + Event::GroupEdgeRemoved(edge) => { + self.outgoing_grouped.insert_edge(edge); + self.incoming_grouped.insert_edge(edge.reverse()); + } + }); + } +} + +/// Directed pair of nodes with a == or != relation +#[derive(PartialEq, Eq, Hash, Clone)] +pub struct Path { + pub source_id: GroupId, + pub target_id: GroupId, + pub relation: EqRelation, +} + +impl Debug for Path { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let &Path { + source_id, + target_id, + relation, + } = self; + write!(f, "{source_id:?} --{relation}--> {target_id:?}") + } +} + +impl Path { + pub fn new(source: impl Into, target: impl Into, relation: EqRelation) -> Self { + Self { + source_id: source.into(), + target_id: target.into(), + relation, + } + } + + /// Returns true if the path is source -==-> source + pub fn redundant(&self) -> bool { + self.source_id == self.target_id && self.relation == EqRelation::Eq + } + + /// Returns true if the path is source -!=-> source + pub fn contradictory(&self) -> bool { + self.source_id == self.target_id && self.relation == EqRelation::Neq + } +} + +#[cfg(test)] +mod tests { + use EqRelation::*; + + use crate::{collections::set::IterableRefSet, core::IntCst}; + + use super::{traversal::PathStore, *}; + + macro_rules! assert_eq_unordered_unique { + ($left:expr, $right:expr $(,)?) => {{ + use std::collections::HashSet; + let left = $left.into_iter().collect_vec(); + let right = $right.into_iter().collect_vec(); + assert!( + left.clone().into_iter().all_unique(), + "{:?} is duplicated in left", + left.clone().into_iter().duplicates().collect_vec() + ); + assert!( + right.clone().into_iter().all_unique(), + "{:?} is duplicated in right", + right.clone().into_iter().duplicates().collect_vec() + ); + let l_set: HashSet<_> = left.into_iter().collect(); + let r_set: HashSet<_> = right.into_iter().collect(); + + let lr_diff: HashSet<_> = l_set.difference(&r_set).cloned().collect(); + let rl_diff: HashSet<_> = r_set.difference(&l_set).cloned().collect(); + + assert!(lr_diff.is_empty(), "Found in left but not in right: {:?}", lr_diff); + assert!(rl_diff.is_empty(), "Found in right but not in left: {:?}", rl_diff); + }}; + } + + fn prop(src: IntCst, tgt: IntCst, relation: EqRelation) -> Constraint { + Constraint::new(Node::Val(src), Node::Val(tgt), relation, Lit::TRUE, Lit::TRUE) + } + + fn id(g: &DirectedEqualityGraph, node: IntCst) -> NodeId { + g.get_id(&Node::Val(node)).unwrap() + } + + fn eqn(g: &DirectedEqualityGraph, node: IntCst, r: EqRelation) -> EqNode { + EqNode(id(g, node), r) + } + + fn edge(g: &DirectedEqualityGraph, src: IntCst, tgt: IntCst, relation: EqRelation) -> Edge { + Edge::new( + g.get_id(&Node::Val(src)).unwrap(), + g.get_id(&Node::Val(tgt)).unwrap(), + Lit::TRUE, + relation, + ) + } + + fn path(g: &DirectedEqualityGraph, src: IntCst, tgt: IntCst, relation: EqRelation) -> Path { + Path::new( + g.get_id(&Node::Val(src)).unwrap(), + g.get_id(&Node::Val(tgt)).unwrap(), + relation, + ) + } + + /* Copy this into https://magjac.com/graphviz-visual-editor/ + digraph { + 0 -> 1 [label=" ="] + 1 -> 2 [label=" !="] + 1 -> 3 [label=" ="] + 2 -> 4 [label=" !="] + 1 -> 5 [label=" ="] + 5 -> 6 [label=" ="] + 6 -> 0 [label=" !="] + 5 -> 0 [label=" ="] + } + */ + fn instance1() -> DirectedEqualityGraph { + let mut g = DirectedEqualityGraph::new(); + for prop in [ + prop(0, 1, Eq), + prop(1, 2, Neq), + prop(1, 3, Eq), + prop(2, 4, Neq), + prop(1, 5, Eq), + prop(5, 6, Eq), + prop(6, 0, Neq), + prop(5, 0, Eq), + ] { + let edge = g.create_edge(&prop); + g.add_edge(edge); + } + g + } + + /* Instance focused on merging + digraph { + 0 -> 1 [label=" ="] + 1 -> 0 [label=" ="] + 1 -> 2 [label=" ="] + 2 -> 0 [label=" ="] + 2 -> 3 [label=" !="] + 3 -> 4 [label=" ="] + 4 -> 5 [label=" ="] + 5 -> 3 [label=" ="] + 0 -> 5 [label=" !="] + 4 -> 1 [label=" ="] + } + */ + fn instance2() -> DirectedEqualityGraph { + let mut g = DirectedEqualityGraph::new(); + for prop in [ + prop(0, 1, Eq), + prop(1, 0, Eq), + prop(1, 2, Eq), + prop(2, 0, Eq), + prop(2, 3, Neq), + prop(3, 4, Eq), + prop(4, 5, Eq), + prop(5, 3, Eq), + prop(0, 5, Neq), + prop(4, 1, Eq), + ] { + let edge = g.create_edge(&prop); + g.add_edge(edge); + } + g + } + + #[test] + fn test_traversal() { + let g = instance1(); + + with_scratches(|[mut s]| { + let traversal = g.outgoing.eq().traverse_dfs(id(&g, 0), &mut s); + assert_eq_unordered_unique!( + traversal, + vec![id(&g, 0,), id(&g, 1,), id(&g, 3,), id(&g, 5,), id(&g, 6,)], + ); + }); + + with_scratches(|[mut s]| { + let traversal = g.outgoing.eq().traverse_dfs(id(&g, 6), &mut s); + assert_eq_unordered_unique!(traversal, vec![id(&g, 6)]); + }); + + with_scratches(|[mut s]| { + let traversal = g.incoming.eq_neq().traverse_dfs(eqn(&g, 0, Eq), &mut s); + assert_eq_unordered_unique!( + traversal, + vec![ + eqn(&g, 0, Eq), + eqn(&g, 6, Neq), + eqn(&g, 5, Eq), + eqn(&g, 5, Neq), + eqn(&g, 1, Eq), + eqn(&g, 1, Neq), + eqn(&g, 0, Neq), + ], + ); + }); + } + + #[test] + fn test_merging() { + let mut g = instance1(); + g.merge((id(&g, 0), id(&g, 1))); + g.merge((id(&g, 5), id(&g, 1))); + let rep = g.get_group_id(id(&g, 0)); + let Node::Val(rep) = g.get_node(rep.into()) else { + panic!() + }; + assert_eq_unordered_unique!( + g.outgoing_grouped.iter_edges(id(&g, rep)).cloned(), + vec![edge(&g, rep, 6, Eq), edge(&g, rep, 3, Eq), edge(&g, rep, 2, Neq)] + ); + assert_eq_unordered_unique!( + g.incoming_grouped.iter_edges(id(&g, rep)).cloned(), + vec![edge(&g, rep, 6, Neq)] + ); + } + + #[test] + fn test_reduced_path() { + let g = instance2(); + let mut path_store = PathStore::new(); + let target = with_scratches(|[mut scratch]| { + g.outgoing + .eq_neq() + .traverse_dfs(eqn(&g, 0, Eq), &mut scratch) + .record_paths(&mut path_store) + .find(|&EqNode(n, r)| n == id(&g, 4) && r == Neq) + .expect("Path exists") + }); + + with_scratches(|[mut s]| { + g.outgoing + .eq_neq() + .traverse_dfs(eqn(&g, 0, Eq), &mut s) + .record_paths(&mut path_store) + .find(|&EqNode(n, r)| n == id(&g, 4) && r == Neq) + .expect("Path exists"); + }); + + let path1 = vec![edge(&g, 3, 4, Eq), edge(&g, 5, 3, Eq), edge(&g, 0, 5, Neq)]; + let path2 = vec![ + edge(&g, 3, 4, Eq), + edge(&g, 2, 3, Neq), + edge(&g, 1, 2, Eq), + edge(&g, 0, 1, Eq), + ]; + let mut set = IterableRefSet::new(); + let out_path1 = path_store.get_path(target).map(|e| e.0).collect_vec(); + if out_path1 == path1 { + set.insert(eqn(&g, 5, Neq)); + + let mut path_store_2 = PathStore::new(); + + with_scratches(|[mut s]| { + let target = g + .outgoing + .eq_neq() + .filter(|_, e| !set.contains(e.target())) + .traverse_dfs(eqn(&g, 0, Eq), &mut s) + .record_paths(&mut path_store_2) + .find(|&EqNode(n, r)| n == id(&g, 4) && r == Neq) + .expect("Path exists"); + assert_eq!(path_store_2.get_path(target).map(|e| e.0).collect_vec(), path2); + }); + } else if out_path1 == path2 { + set.insert(eqn(&g, 1, Eq)); + + let mut path_store_2 = PathStore::new(); + with_scratches(|[mut s]| { + let target = g + .outgoing + .eq_neq() + .filter(|_, e| !set.contains(e.target())) + .traverse_dfs(eqn(&g, 0, Eq), &mut s) + .record_paths(&mut path_store_2) + .find(|&EqNode(n, r)| n == id(&g, 4) && r == Neq) + .expect("Path exists"); + assert_eq!(path_store_2.get_path(target).map(|e| e.0).collect_vec(), path1); + }); + } + } + + #[test] + fn test_paths_requiring_cycles() { + let mut g = DirectedEqualityGraph::new(); + for i in -3..=3 { + g.insert_node(Node::Val(i)); + } + + g.add_edge(edge(&g, -3, -2, Eq)); + g.add_edge(edge(&g, -2, -1, Eq)); + assert_eq_unordered_unique!( + g.paths_requiring(edge(&g, -1, -3, Eq)), + [ + path(&g, -2, -2, Eq), + path(&g, -1, -3, Eq), + path(&g, -1, -2, Eq), + path(&g, -2, -3, Eq) + ] + ); + g.add_edge(edge(&g, -1, -3, Eq)); + g.merge((id(&g, -1), id(&g, -3))); + g.merge((id(&g, -2), id(&g, -3))); + assert_eq_unordered_unique!(g.paths_requiring(edge(&g, -1, -3, Eq)), []); + assert_eq_unordered_unique!(g.paths_requiring(edge(&g, -3, -3, Eq)), []); + + g.add_edge(edge(&g, 0, 1, Eq)); + assert_eq_unordered_unique!(g.paths_requiring(edge(&g, 1, 0, Eq)), [path(&g, 1, 0, Eq)]); + + assert_eq_unordered_unique!( + g.paths_requiring(edge(&g, 1, 0, Neq)), + [path(&g, 1, 0, Neq), path(&g, 0, 0, Neq), path(&g, 1, 1, Neq)] + ); + + g.add_edge(edge(&g, 2, 3, Neq)); + assert_eq_unordered_unique!( + g.paths_requiring(edge(&g, 3, 2, Eq)), + [path(&g, 3, 2, Eq), path(&g, 2, 2, Neq), path(&g, 3, 3, Neq)] + ); + } + + #[test] + fn test_paths_requiring() { + let mut g = instance1(); + assert_eq_unordered_unique!(g.paths_requiring(edge(&g, 0, 1, Eq)), []); + assert_eq_unordered_unique!(g.paths_requiring(edge(&g, 0, 1, Neq)), []); + assert_eq_unordered_unique!( + g.paths_requiring(edge(&g, 1, 2, Eq)), + [ + path(&g, 1, 2, Eq), + path(&g, 0, 2, Eq), + path(&g, 0, 4, Neq), + path(&g, 1, 4, Neq), + path(&g, 5, 2, Eq), + path(&g, 5, 4, Neq), + path(&g, 6, 2, Neq) + ] + ); + assert_eq_unordered_unique!( + g.paths_requiring(edge(&g, 2, 1, Eq)), + [ + path(&g, 2, 1, Eq), + path(&g, 2, 2, Neq), + path(&g, 2, 5, Eq), + path(&g, 2, 6, Eq), + path(&g, 2, 0, Eq), + path(&g, 2, 0, Neq), + path(&g, 2, 3, Eq), + path(&g, 2, 1, Neq), + path(&g, 2, 3, Neq), + path(&g, 2, 5, Neq), + path(&g, 2, 6, Neq), + ] + ); + g.insert_node(Node::Val(7)); + g.add_edge(edge(&g, 4, 7, Eq)); + assert_eq_unordered_unique!( + g.paths_requiring(edge(&g, 7, 4, Neq)), + [path(&g, 7, 4, Neq), path(&g, 7, 7, Neq), path(&g, 4, 4, Neq)] + ); + } +} diff --git a/solver/src/reasoners/eq_alt/graph/node_store.rs b/solver/src/reasoners/eq_alt/graph/node_store.rs new file mode 100644 index 000000000..b9349c915 --- /dev/null +++ b/solver/src/reasoners/eq_alt/graph/node_store.rs @@ -0,0 +1,362 @@ +use hashbrown::HashMap; + +use crate::{ + backtrack::{Backtrack, Trail}, + collections::ref_store::RefVec, + create_ref_type, + reasoners::eq_alt::node::Node, + transitive_conversion, +}; +use std::{cell::RefCell, fmt::Debug}; + +use super::NodeId; + +create_ref_type!(GroupId); +// Commenting these lines allows us to check where nodes are treated like groups and vice versa +transitive_conversion!(NodeId, u32, GroupId); +transitive_conversion!(GroupId, u32, NodeId); + +impl Debug for NodeId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Node {}", self.to_u32()) + } +} + +impl Debug for GroupId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Group {}", self.to_u32()) + } +} + +#[derive(Clone, Debug)] +struct Relations { + parent: Option, + next_sibling: Option, + previous_sibling: Option, + first_child: Option, +} + +impl Relations { + const DETACHED: Relations = Relations { + parent: None, + next_sibling: None, + previous_sibling: None, + first_child: None, + }; +} + +impl Default for Relations { + fn default() -> Self { + Self::DETACHED + } +} + +/// NodeStore is a backtrackable Id => Node map with Union-Find and path-flattening +#[derive(Clone, Default)] +pub struct NodeStore { + /// Maps NodeId to Node, doesn't support arbitrary removal! + nodes: RefVec, + /// Maps Node to NodeId, for interfacing with the graph + rev_nodes: HashMap, + /// Relations between elements of a group of nodes + group_relations: RefCell>, + trail: RefCell>, + path: RefCell>, +} + +#[allow(unused)] +impl NodeStore { + pub fn new() -> NodeStore { + Default::default() + } + + pub fn insert_node(&mut self, node: Node) -> NodeId { + debug_assert!(!self.rev_nodes.contains_key(&node)); + self.trail.borrow_mut().push(Event::Added); + let id = self.nodes.push(node); + self.rev_nodes.insert(node, id); + self.group_relations.borrow_mut().push(Default::default()); + id + } + + pub fn get_id(&self, node: &Node) -> Option { + self.rev_nodes.get(node).copied() + } + + pub fn get_node(&self, id: NodeId) -> Node { + self.nodes[id] + } + + pub fn merge_nodes(&mut self, child: NodeId, parent: NodeId) { + let child = self.get_group_id(child); + let parent = self.get_group_id(parent); + self.merge(child, parent); + } + + pub fn merge(&mut self, child: GroupId, parent: GroupId) { + debug_assert_eq!(child, self.get_group_id(child.into())); + debug_assert_eq!(parent, self.get_group_id(parent.into())); + if child != parent { + self.set_new_parent(child.into(), parent.into()); + } + } + + fn set_new_parent(&mut self, id: NodeId, parent_id: NodeId) { + debug_assert_ne!(id, parent_id); + // Ensure child has no relations or no parent + debug_assert!(self.group_relations.borrow()[id].parent.is_none()); + self.reparent(id, parent_id); + } + + fn reparent(&self, id: NodeId, parent_id: NodeId) { + debug_assert_ne!(id, parent_id); + // Get info about node's old status + let old_relations = { self.group_relations.borrow()[id].clone() }; + self.trail.borrow_mut().push(Event::ParentChanged { + id, + old_parent_id: old_relations.parent, + old_previous_sibling_id: old_relations.previous_sibling, + old_next_sibling_id: old_relations.next_sibling, + }); + + let mut group_relations_mut = self.group_relations.borrow_mut(); + + // If first child, set next sibling as first child + if let Some(old_parent) = old_relations.parent { + if old_relations.previous_sibling.is_none() { + group_relations_mut[old_parent].first_child = old_relations.next_sibling; + } + } + + // Join siblings together + if let Some(old_previous_sibling) = old_relations.previous_sibling { + group_relations_mut[old_previous_sibling].next_sibling = old_relations.next_sibling; + } + if let Some(old_next_sibling) = old_relations.next_sibling { + group_relations_mut[old_next_sibling].previous_sibling = old_relations.previous_sibling; + } + + // Set node as first child of new parent + let parent_relations = &mut group_relations_mut[parent_id]; + let first_sibling = parent_relations.first_child; + parent_relations.first_child = Some(id); + + // Setup node + let new_relations = &mut group_relations_mut[id]; + new_relations.previous_sibling = None; + new_relations.next_sibling = first_sibling; + new_relations.parent = Some(parent_id); + + if let Some(new_next_sibling) = first_sibling { + group_relations_mut[new_next_sibling].previous_sibling = Some(id); + } + } + + pub fn get_group_id(&self, mut id: NodeId) -> GroupId { + // Get the path from id to rep (inclusive) + let mut path = self.path.borrow_mut(); + path.clear(); + path.push(id); + while let Some(parent_id) = self.group_relations.borrow()[id].parent { + id = parent_id; + path.push(id); + } + // The rep is the last element + let rep_id = path.pop().unwrap(); + + // The last element doesn't need reparenting + path.pop(); + + for child_id in path.iter() { + self.reparent(*child_id, rep_id); + } + rep_id.into() + } + + pub fn get_group(&self, id: GroupId) -> Vec { + debug_assert_eq!(id, self.get_group_id(id.into())); + + let mut res = vec![]; + + // Depth first traversal using first_child and next_sibling + let mut stack = vec![id.into()]; + while let Some(n) = stack.pop() { + // Visit element in stack + res.push(n); + // Starting from first child + let Some(first_child) = self.group_relations.borrow()[n].first_child else { + continue; + }; + stack.push(first_child); + let gr = self.group_relations.borrow(); + let mut current_relations = &gr[first_child]; + while let Some(next_child) = current_relations.next_sibling { + stack.push(next_child); + current_relations = &gr[next_child]; + } + } + res + } + + pub fn get_group_nodes(&self, id: GroupId) -> Vec { + self.get_group(id).into_iter().map(|id| self.get_node(id)).collect() + } + + pub fn groups(&self) -> Vec { + let relations = self.group_relations.borrow(); + (0..relations.len()) + .filter_map(|i| (relations[i.into()].parent.is_none()).then_some(i.into())) + .collect() + } + + pub fn count_groups(&self) -> usize { + self.groups().len() + } + + pub fn len(&self) -> usize { + self.nodes.len() + } + + pub fn nodes(&self) -> Vec { + let relations = self.group_relations.borrow(); + (0..relations.len()).map(|i| i.into()).collect() + } +} + +// impl Default for NodeStore { +// fn default() -> Self { +// Self::new() +// } +// } + +#[derive(Clone)] +enum Event { + Added, + ParentChanged { + id: NodeId, + old_parent_id: Option, + old_previous_sibling_id: Option, + old_next_sibling_id: Option, + }, +} + +impl Backtrack for NodeStore { + fn save_state(&mut self) -> crate::backtrack::DecLvl { + self.trail.borrow_mut().save_state() + } + + fn num_saved(&self) -> u32 { + self.trail.borrow_mut().num_saved() + } + + fn restore_last(&mut self) { + use Event::*; + self.trail.borrow_mut().restore_last_with(|e| match e { + Added => { + let node = self.nodes.pop().unwrap(); + self.rev_nodes.remove(&node); + self.group_relations.borrow_mut().pop().unwrap(); + } + ParentChanged { + id, + old_parent_id, + old_previous_sibling_id, + old_next_sibling_id, + } => { + // NOTE: In this block, "new" refers to the state after the event happened, old before. + + // INVARIANT: Child is first child of it's current parent + let new_relations = { self.group_relations.borrow()[id].clone() }; + debug_assert_eq!(new_relations.previous_sibling, None); + + let mut group_relations_mut = self.group_relations.borrow_mut(); + + if let Some(new_next_sibling) = new_relations.next_sibling { + group_relations_mut[new_next_sibling].previous_sibling = None; + } + + // Set new parent's first child to new next sibling + group_relations_mut[new_relations.parent.unwrap()].first_child = new_relations.next_sibling; + + let mut_relations = &mut group_relations_mut[id]; + mut_relations.parent = old_parent_id; + mut_relations.previous_sibling = old_previous_sibling_id; + mut_relations.next_sibling = old_next_sibling_id; + + if let Some(old_previous_sibling) = old_previous_sibling_id { + group_relations_mut[old_previous_sibling].next_sibling = Some(id); + } else if let Some(old_parent) = old_parent_id { + group_relations_mut[old_parent].first_child = Some(id); + } + if let Some(old_next_sibling) = old_next_sibling_id { + group_relations_mut[old_next_sibling].previous_sibling = Some(id); + } + } + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_node_store() { + use std::collections::HashSet; + use Node::*; + + let mut ns = NodeStore::new(); + ns.save_state(); + + // Insert three distinct nodes + let n0 = ns.insert_node(Val(0)); + let n1 = ns.insert_node(Val(1)); + let n2 = ns.insert_node(Val(2)); + + assert_ne!(ns.get_group_id(n0), ns.get_group_id(n1)); + assert_ne!(ns.get_group_id(n1), ns.get_group_id(n2)); + + // Merge n0 and n1, then n1 and n2 => all should be in one group + ns.merge_nodes(n0, n1); + ns.merge_nodes(n1, n2); + let rep = ns.get_group_id(n0); + assert_eq!(rep, ns.get_group_id(n2)); + assert_eq!( + ns.get_group(ns.get_group_id(n1)).into_iter().collect::>(), + [n0, n1, n2].into() + ); + + // Merge same nodes again to check idempotency + ns.merge_nodes(n0, n2); + assert_eq!(ns.get_group_id(n0), rep); + + // Add a new node and ensure it's separate + let n3 = ns.insert_node(Val(3)); + assert_ne!(ns.get_group_id(n3), rep); + + ns.save_state(); + + // Merge into existing group + ns.merge_nodes(n2, n3); + assert_eq!( + ns.get_group(ns.get_group_id(n3)).into_iter().collect::>(), + [n0, n1, n2, n3].into() + ); + + // Restore to state before n3 was merged + ns.restore_last(); + assert_ne!(ns.get_group_id(n3), rep); + assert_eq!( + ns.get_group(ns.get_group_id(n2)).into_iter().collect::>(), + [n0, n1, n2].into() + ); + + // Restore to initial state + ns.restore_last(); + assert!(ns.get_id(&Val(0)).is_none()); + assert!(ns.get_id(&Val(1)).is_none()); + + // Attempt to query a non-existent node + assert!(ns.get_id(&Val(99)).is_none()); + } +} diff --git a/solver/src/reasoners/eq_alt/graph/transforms.rs b/solver/src/reasoners/eq_alt/graph/transforms.rs new file mode 100644 index 000000000..cdc20b070 --- /dev/null +++ b/solver/src/reasoners/eq_alt/graph/transforms.rs @@ -0,0 +1,186 @@ +use crate::reasoners::eq_alt::relation::EqRelation; + +use super::{ + traversal::{self}, + Edge, EqAdjList, NodeId, Path, +}; + +// Implementations of generic edge for concrete edge type +impl traversal::Edge for Edge { + fn target(&self) -> NodeId { + self.target + } + + fn source(&self) -> NodeId { + self.source + } +} + +// Implementation of generic graph for concrete graph +impl traversal::Graph for &EqAdjList { + fn outgoing(&self, node: NodeId) -> impl Iterator { + self.iter_edges(node).cloned() + } +} + +// Node with associated relation type +#[derive(Debug, Copy, Clone, PartialEq, Hash, Eq)] +pub struct EqNode(pub NodeId, pub EqRelation); +impl EqNode { + /// Returns EqNode with relation = + pub fn new(source: NodeId) -> Self { + Self(source, EqRelation::Eq) + } + + pub fn negate(self) -> Self { + Self(self.0, !self.1) + } + + pub fn path_to(&self, other: &EqNode) -> Option { + Some(Path::new(self.0, other.0, (self.1 + other.1)?)) + } +} + +// Node trait implementation for Eq Node +// Relation gets first bit, N is shifted to the left by one + +impl From for EqNode { + fn from(value: usize) -> Self { + let r = if value & 1 != 0 { + EqRelation::Eq + } else { + EqRelation::Neq + }; + Self((value >> 1).into(), r) + } +} + +impl From for usize { + fn from(value: EqNode) -> Self { + let shift = 1; + let v = match value.1 { + EqRelation::Eq => 1_usize, + EqRelation::Neq => 0_usize, + }; + v | usize::from(value.0) << shift + } +} + +/// EqEdge type that goes with EqNode for graph traversal. +/// +/// Second field is the relation of the target node +/// (Hence the - in source) +#[derive(Debug, Clone)] +pub struct EqEdge(pub Edge, EqRelation); + +impl traversal::Edge for EqEdge { + fn target(&self) -> EqNode { + EqNode(self.0.target, self.1) + } + + fn source(&self) -> EqNode { + EqNode(self.0.source, (self.1 - self.0.relation).unwrap()) + } +} + +/// Filters the graph to only include edges with equality relation. +/// +/// Commonly used when looking for nodes which are equal to the source +pub struct EqFilter>(G); + +impl> traversal::Graph for EqFilter { + fn outgoing(&self, node: NodeId) -> impl Iterator { + self.0.outgoing(node).filter(|e| e.relation == EqRelation::Eq) + } +} + +/// Extension trait used to add the eq method to implementations of Graph +pub trait EqExt> { + /// Filters the graph to only include edges with equality relation. + /// + /// Commonly used when looking for nodes which are equal to the source + fn eq(self) -> EqFilter; +} +impl EqExt for G +where + G: traversal::Graph, +{ + fn eq(self) -> EqFilter { + EqFilter(self) + } +} + +/// Transform the graph in order to traverse it following equality's transitivity laws. +/// +/// Modifies the graph so that each node has two copies: One with Eq relation, and one with Neq relation. +/// +/// Adapts edges so that a -=> b && b -!=-> c, a -!=-> c and so on. +pub struct EqNeqFilter>(G); + +impl> traversal::Graph for EqNeqFilter { + fn outgoing(&self, node: EqNode) -> impl Iterator { + self.0.outgoing(node.0).filter_map(move |e| { + let r = (e.relation + node.1)?; + Some(EqEdge(e, r)) + }) + } +} + +pub trait EqNeqExt> { + /// Transform the graph in order to traverse it following equality's transitivity laws. + /// + /// Modifies the graph so that each node has two copies: One with Eq relation, and one with Neq relation. + /// + /// Adapts edges so that a -=> b && b -!=-> c, a -!=-> c and so on. + fn eq_neq(self) -> EqNeqFilter; +} +impl EqNeqExt for G +where + G: traversal::Graph, +{ + fn eq_neq(self) -> EqNeqFilter { + EqNeqFilter(self) + } +} + +/// Filter the graph according to a closure. +pub struct FilteredGraph(G, F, std::marker::PhantomData<(N, E)>) +where + N: traversal::Node, + E: traversal::Edge, + G: traversal::Graph, + F: Fn(N, &E) -> bool; + +impl traversal::Graph for FilteredGraph +where + N: traversal::Node, + E: traversal::Edge, + G: traversal::Graph, + F: Fn(N, &E) -> bool, +{ + fn outgoing(&self, node: N) -> impl Iterator { + self.0.outgoing(node).filter(move |e| self.1(node, e)) + } +} + +pub trait FilterExt +where + N: traversal::Node, + E: traversal::Edge, + G: traversal::Graph, + F: Fn(N, &E) -> bool, +{ + /// Filter the graph according to a closure. + fn filter(self, f: F) -> FilteredGraph; +} +impl FilterExt for G +where + N: traversal::Node, + E: traversal::Edge, + G: traversal::Graph, + F: Fn(N, &E) -> bool, +{ + fn filter(self, f: F) -> FilteredGraph { + FilteredGraph(self, f, std::marker::PhantomData {}) + } +} diff --git a/solver/src/reasoners/eq_alt/graph/traversal.rs b/solver/src/reasoners/eq_alt/graph/traversal.rs new file mode 100644 index 000000000..12f8c553c --- /dev/null +++ b/solver/src/reasoners/eq_alt/graph/traversal.rs @@ -0,0 +1,278 @@ +use std::collections::VecDeque; + +use crate::collections::{ + ref_store::{IterableRefMap, Ref}, + set::IterableRefSet, +}; + +pub trait Node: Ref {} +impl Node for T {} + +/// A trait representing a generic directed edge with a source and target. +pub trait Edge: Clone { + fn target(&self) -> N; + fn source(&self) -> N; +} + +/// A trait representing a generic directed Graph. +pub trait Graph> { + /// Get outgoing edges from the node. + fn outgoing(&self, node: N) -> impl Iterator; + + /// Traverse the graph (depth first) from a given `source`. This method return a GraphTraversal object which implements Iterator. + /// + /// Scratch contains the large data structures used by the graph traversal algorithm. Useful to reuse memory. + /// `&mut Default::default()` can be used if performance is not critical. + fn traverse_dfs<'a>( + self, + source: N, + scratch: &'a mut Scratch>, + ) -> GraphTraversal<'a, N, E, Self, Vec> + where + Self: Sized, + { + GraphTraversal::new(self, source, scratch) + } + + /// Traverse the graph (breadth first) from a given source. This method return a GraphTraversal object which implements Iterator. + /// + /// Scratch contains the large data structures used by the graph traversal algorithm. Useful to reuse memory. + /// `&mut default::default()` can used if performance is not critical. + fn traverse_bfs<'a>( + self, + source: N, + scratch: &'a mut Scratch>, + ) -> GraphTraversal<'a, N, E, Self, VecDeque> + where + Self: Sized, + { + GraphTraversal::new(self, source, scratch) + } + + /// Get the set of nodes which can be reached from the source. + /// + /// See traverse for details about scratch. + fn reachable<'a>(self, source: N, scratch: &'a mut Scratch>) -> Visited<'a, N> + where + Self: Sized + 'a, + N: 'a, + E: 'a, + { + let mut t = GraphTraversal::new(self, source, scratch); + for _ in t.by_ref() {} + scratch.visited() + } +} + +pub trait Frontier { + fn push(&mut self, value: N); + + fn pop(&mut self) -> Option; + + fn extend(&mut self, values: impl IntoIterator); + + fn clear(&mut self); +} + +impl Frontier for Vec { + fn push(&mut self, value: N) { + self.push(value); + } + + fn pop(&mut self) -> Option { + self.pop() + } + + fn extend(&mut self, values: impl IntoIterator) { + Extend::extend(self, values) + } + + fn clear(&mut self) { + self.clear() + } +} + +impl Frontier for VecDeque { + fn push(&mut self, value: N) { + self.push_back(value); + } + + fn pop(&mut self) -> Option { + self.pop_front() + } + + fn extend(&mut self, values: impl IntoIterator) { + Extend::extend(self, values); + } + + fn clear(&mut self) { + self.clear() + } +} + +/// A data structure that can be passed to GraphTraversal in order to record parents of visited nodes. +/// This allows for path queries after traversal. +/// +/// Call record_paths on GraphTraversal with this struct. +pub struct PathStore>(IterableRefMap); + +impl> PathStore { + pub fn new() -> Self { + Self(Default::default()) + } + + pub fn get_path(&self, mut target: N) -> impl Iterator + use<'_, N, E> { + std::iter::from_fn(move || { + self.0.get(target).map(|e| { + target = e.source(); + e.clone() + }) + }) + } +} + +/// Scratch contains the large data structures used by the graph traversal algorithm. Useful to reuse memory. +/// +/// In order to avoid having to deal with generics when reusing an instance, we use usize instead of N: Into\ + From\. +/// We therefore need structs to access these data structures with N. +#[derive(Default)] +pub struct Scratch> { + frontier: F, + visited: IterableRefSet, +} + +/// Used to access Scratch.stack as if it were `Vec` +struct FrontierMut<'a, N: Into + From, F: Frontier>(&'a mut F, std::marker::PhantomData); + +impl<'a, N: Into + From, F: Frontier> FrontierMut<'a, N, F> { + fn new(s: &'a mut F) -> Self { + Self(s, std::marker::PhantomData {}) + } + + fn push(&mut self, n: N) { + self.0.push(n.into()) + } + + fn pop(&mut self) -> Option { + self.0.pop().map(Into::into) + } + + fn extend(&mut self, iter: impl IntoIterator) { + self.0.extend(iter.into_iter().map(Into::into)) + } +} + +/// Used to access Scratch.visited as if it were `IterableRefSet` +pub struct VisitedMut<'a, N: Into + From>(&'a mut IterableRefSet, std::marker::PhantomData); + +impl<'a, N: Into + From> VisitedMut<'a, N> { + fn new(v: &'a mut IterableRefSet) -> Self { + Self(v, std::marker::PhantomData {}) + } + + pub fn contains(&mut self, n: N) -> bool { + self.0.contains(n.into()) + } + + pub fn insert(&mut self, n: N) { + self.0.insert(n.into()) + } +} + +/// Used to access Scratch.visited as if it were `IterableRefSet` +pub struct Visited<'a, N: Into + From>(&'a IterableRefSet, std::marker::PhantomData); +impl<'a, N: Into + From> Visited<'a, N> { + fn new(v: &'a IterableRefSet) -> Self { + Self(v, std::marker::PhantomData {}) + } + + pub fn contains(&self, n: N) -> bool { + self.0.contains(n.into()) + } +} + +impl> Scratch { + fn stack<'a, N: Into + From>(&'a mut self) -> FrontierMut<'a, N, F> { + FrontierMut::new(&mut self.frontier) + } + + fn visited_mut<'a, N: Into + From>(&'a mut self) -> VisitedMut<'a, N> { + VisitedMut::new(&mut self.visited) + } + + fn visited<'a, N: Into + From>(&'a self) -> Visited<'a, N> { + Visited::new(&self.visited) + } + + fn clear(&mut self) { + self.frontier.clear(); + self.visited.clear(); + } +} + +/// Struct for traversing a Graph with DFS or BFS. +/// Implements iterator. +pub struct GraphTraversal<'a, N: Node, E: Edge, G: Graph, F: Frontier> { + graph: G, + scratch: &'a mut Scratch, + parents: Option<&'a mut PathStore>, +} + +impl<'a, N: Node, E: Edge, G: Graph, F: Frontier> GraphTraversal<'a, N, E, G, F> { + fn new(graph: G, source: N, scratch: &'a mut Scratch) -> Self { + scratch.clear(); + scratch.stack().push(source); + GraphTraversal { + graph, + scratch, + parents: None, + } + } + + /// Record paths taken during traversal to PathStore, allowing for path from source to visited node queries. + pub fn record_paths(mut self, path_store: &'a mut PathStore) -> Self { + // TODO: We should make this safe by introducing a new type for iteration + debug_assert!(self.parents.is_none()); + debug_assert!(self.scratch.visited.is_empty()); + self.parents = Some(path_store); + self + } + + pub fn visited(&self) -> Visited<'_, N> { + self.scratch.visited() + } +} + +impl, G: Graph, F: Frontier> Iterator for GraphTraversal<'_, N, E, G, F> { + type Item = N; + + fn next(&mut self) -> Option { + // Get the next unvisited node + let mut node = self.scratch.stack().pop()?; + while self.scratch.visited_mut().contains(node) { + node = self.scratch.stack().pop()?; + } + + // Mark as visited + self.scratch.visited_mut().insert(node); + + let mut stack = FrontierMut::new(&mut self.scratch.frontier); + let visited = Visited::new(&self.scratch.visited); + + // Get all (unvisited) nodes that can be reached through an outgoing edge + let new_nodes = self.graph.outgoing(node).filter_map(|e| { + let target = e.target(); + if !visited.contains(target) { + if let Some(parents) = self.parents.as_mut() { + parents.0.insert(target, e); + } + Some(target) + } else { + None + } + }); + stack.extend(new_nodes); + + Some(node) + } +} diff --git a/solver/src/reasoners/eq_alt/mod.rs b/solver/src/reasoners/eq_alt/mod.rs new file mode 100644 index 000000000..e784e6b5a --- /dev/null +++ b/solver/src/reasoners/eq_alt/mod.rs @@ -0,0 +1,210 @@ +//! This module exports an alternate propagator for equality logic. +//! +//! Since DenseEqTheory has O(n^2) space complexity it tends to have performance issues on larger problems. +//! This alternative has much lower memory use on sparse problems, and can make stronger inferences than just the STN +//! +//! Currently, this propagator is intended to be used in conjunction with the StnTheory. +//! Each l => x = y constraint should be posted as l => x >= y and l => x <= y, +//! and each l => x != y constraint should be posted as l => x > y or l => x < y in the STN. +//! This is because AltEqTheory does not do bound propagation yet +//! (When a integer variable's bounds are updated, no propagation occurs). +//! Stn is therefore ideally used in "bounds" propagation mode ("edges" is redundant) with this propagator. + +// TODO: Implement bound propagation for this theory. + +mod constraints; +mod graph; +mod node; +mod relation; +mod theory; + +pub use theory::AltEqTheory; + +#[cfg(test)] +mod tests { + use std::fmt::Display; + + use itertools::Itertools; + use rand::{rngs::SmallRng, seq::IteratorRandom, Rng, SeedableRng}; + + use crate::{ + core::{ + state::{Cause, Domains}, + IntCst, Lit, VarRef, + }, + model::{ + lang::{ + expr::{and, eq, geq, gt, leq, lt, neq, or}, + IVar, + }, + Model, + }, + solver::{search::random::RandomChoice, Solver}, + }; + + use super::relation::EqRelation; + + struct Problem { + domains: Domains, + constraints: Vec<(VarRef, VarRef, EqRelation, bool, bool)>, + } + + const VARS_PER_PROBLEM: usize = 20; + + fn generate_problem(rng: &mut SmallRng) -> Problem { + // Calibrated for approximately equal number of solvable and unsolvable problems + let sparsity = 0.5; + let neq_probability = 0.5; + let full_reif_probability = 0.5; + let enforce_probability = 0.5; + let max_scopes = 5; + + let mut domains = Domains::new(); + + let num_scopes = rng.gen_range(1..max_scopes); + + let mut scopes = vec![Lit::TRUE]; + for i in 1..num_scopes { + scopes.push(domains.new_presence_literal(scopes[i - 1])); + } + + // Lit::TRUE, Lit::FALSE, and scopes other than TRUE + let var_offset = num_scopes - 1 + 2; + + for i in var_offset..VARS_PER_PROBLEM + var_offset { + assert_eq!( + VarRef::from(i), + domains.new_optional_var(0, VARS_PER_PROBLEM as IntCst - 1, *scopes.iter().choose(rng).unwrap()) + ); + } + + #[allow(clippy::filter_map_bool_then)] // Avoids double borrowing rng + let constraints = (var_offset..VARS_PER_PROBLEM + var_offset) + .tuple_combinations() + .filter_map(|(a, b)| { + rng.gen_bool(sparsity).then(|| { + ( + a.into(), + b.into(), + if rng.gen_bool(neq_probability) { + EqRelation::Neq + } else { + EqRelation::Eq + }, + rng.gen_bool(full_reif_probability), + rng.gen_bool(enforce_probability), + ) + }) + }) + .collect_vec(); + Problem { domains, constraints } + } + + #[derive(Debug, Hash, PartialEq, Eq, Clone)] + enum Label { + ReifLiteral(VarRef, VarRef), + } + + impl Display for Label { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "{:?}", self) + } + } + + fn model_with_eq(problem: &Problem) -> Model