Skip to content

Commit 45bb0a7

Browse files
authored
Core impl [Graph Role + base identification] (#27)
* Graph Role core impl * base identification class impl
1 parent a9d08e8 commit 45bb0a7

File tree

8 files changed

+592
-3
lines changed

8 files changed

+592
-3
lines changed

rust_core/src/dag.rs

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use petgraph::Direction;
22
use rustworkx_core::petgraph::graph::{DiGraph, NodeIndex};
33
use std::collections::{HashMap, HashSet, VecDeque};
4+
use crate::graph_role::{GraphError, GraphRoles};
5+
use std::hash::{Hash, Hasher};
6+
use crate::graph::Graph;
47

58
/// Directed Acyclic Graph (DAG) with optional latent variables.
69
///
@@ -24,6 +27,89 @@ pub struct RustDAG {
2427
pub node_map: HashMap<String, NodeIndex>,
2528
pub reverse_node_map: HashMap<NodeIndex, String>,
2629
pub latents: HashSet<String>,
30+
pub roles: HashMap<String, HashSet<String>>, // New: role -> set of nodes
31+
}
32+
33+
impl PartialEq for RustDAG {
34+
fn eq(&self, other: &Self) -> bool {
35+
// Compare nodes
36+
let self_nodes: HashSet<&String> = self.node_map.keys().collect();
37+
let other_nodes: HashSet<&String> = other.node_map.keys().collect();
38+
if self_nodes != other_nodes {
39+
return false;
40+
}
41+
42+
// Compare edges
43+
let self_edges: HashSet<(String, String)> = self.edges().into_iter().collect();
44+
let other_edges: HashSet<(String, String)> = other.edges().into_iter().collect();
45+
if self_edges != other_edges {
46+
return false;
47+
}
48+
49+
// Compare latents
50+
if self.latents != other.latents {
51+
return false;
52+
}
53+
54+
// Compare roles
55+
let mut self_roles: Vec<(String, Vec<String>)> = self
56+
.get_roles()
57+
.into_iter()
58+
.map(|role| {
59+
let mut nodes = self.get_role(&role);
60+
nodes.sort();
61+
(role, nodes)
62+
})
63+
.collect();
64+
self_roles.sort_by(|a, b| a.0.cmp(&b.0));
65+
66+
let mut other_roles: Vec<(String, Vec<String>)> = other
67+
.get_roles()
68+
.into_iter()
69+
.map(|role| {
70+
let mut nodes = other.get_role(&role);
71+
nodes.sort();
72+
(role, nodes)
73+
})
74+
.collect();
75+
other_roles.sort_by(|a, b| a.0.cmp(&b.0));
76+
77+
self_roles == other_roles
78+
}
79+
}
80+
81+
/// Marks `RustDAG` equality as a full equivalence relation; the `PartialEq`
/// impl only compares sets of strings, so `a == a` always holds.
impl Eq for RustDAG {}
82+
83+
impl Hash for RustDAG {
84+
fn hash<H: Hasher>(&self, state: &mut H) {
85+
// Hash nodes
86+
let mut nodes: Vec<&String> = self.node_map.keys().collect();
87+
nodes.sort();
88+
nodes.hash(state);
89+
90+
// Hash edges
91+
let mut edges: Vec<(String, String)> = self.edges();
92+
edges.sort();
93+
edges.hash(state);
94+
95+
// Hash latents
96+
let mut latents: Vec<&String> = self.latents.iter().collect();
97+
latents.sort();
98+
latents.hash(state);
99+
100+
// Hash roles
101+
let mut roles: Vec<(String, Vec<String>)> = self
102+
.get_roles()
103+
.into_iter()
104+
.map(|role| {
105+
let mut nodes: Vec<String> = self.get_role(&role);
106+
nodes.sort();
107+
(role, nodes)
108+
})
109+
.collect();
110+
roles.sort_by(|a, b| a.0.cmp(&b.0));
111+
roles.hash(state);
112+
}
27113
}
28114

29115
impl RustDAG {
@@ -37,6 +123,7 @@ impl RustDAG {
37123
node_map: HashMap::new(),
38124
reverse_node_map: HashMap::new(),
39125
latents: HashSet::new(),
126+
roles: HashMap::new(),
40127
}
41128
}
42129

@@ -111,7 +198,6 @@ impl RustDAG {
111198
Ok(())
112199
}
113200

114-
115201
/// Add multiple directed edges.
116202
///
117203
/// # Parameters
@@ -593,3 +679,34 @@ impl RustDAG {
593679
self.graph.edge_count()
594680
}
595681
}
682+
683+
impl Graph for RustDAG {
684+
fn nodes(&self) -> Vec<String> {
685+
self.node_map.keys().cloned().collect()
686+
}
687+
688+
fn parents(&self, node: &str) -> Result<Vec<String>, GraphError> {
689+
self.get_parents(node)
690+
.map_err(|e| GraphError::NodeNotFound(e))
691+
}
692+
693+
fn ancestors(&self, nodes: Vec<String>) -> Result<HashSet<String>, GraphError> {
694+
self.get_ancestors_of(nodes)
695+
.map_err(|e| GraphError::NodeNotFound(e))
696+
}
697+
}
698+
699+
impl GraphRoles for RustDAG {
700+
fn has_node(&self, node: &str) -> bool {
701+
self.node_map.contains_key(node)
702+
}
703+
704+
fn get_roles_map(&self) -> &HashMap<String, HashSet<String>> {
705+
&self.roles
706+
}
707+
708+
fn get_roles_map_mut(&mut self) -> &mut HashMap<String, HashSet<String>> {
709+
&mut self.roles
710+
}
711+
}
712+

rust_core/src/graph.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
use crate::graph_role::GraphError;
2+
use std::collections::HashSet;
3+
4+
/// Trait for core graph operations required by causal graphs.
5+
pub trait Graph {
6+
/// Get all nodes in the graph.
7+
fn nodes(&self) -> Vec<String>;
8+
9+
/// Get the parents of a node.
10+
fn parents(&self, node: &str) -> Result<Vec<String>, GraphError>;
11+
12+
/// Get the ancestors of a set of nodes (including the nodes themselves).
13+
fn ancestors(&self, nodes: Vec<String>) -> Result<HashSet<String>, GraphError>;
14+
}

rust_core/src/graph_role.rs

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
use std::collections::{HashMap, HashSet};
2+
3+
/// Error type shared by the graph traits.
///
/// Derives `Clone`, `PartialEq` and `Eq` in addition to `Debug` so callers and
/// tests can compare and duplicate errors directly instead of matching on
/// formatted strings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GraphError {
    /// A referenced node does not exist; the payload is the node name.
    NodeNotFound(String),
    /// The requested operation is not valid; the payload is a description.
    InvalidOperation(String),
}
9+
10+
impl std::fmt::Display for GraphError {
11+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
12+
match self {
13+
GraphError::NodeNotFound(node) => write!(f, "Node '{}' not found in the graph", node),
14+
GraphError::InvalidOperation(msg) => write!(f, "Invalid operation: {}", msg),
15+
}
16+
}
17+
}
18+
19+
/// Lets `GraphError` be used wherever a `std::error::Error` is expected
/// (e.g. boxed as `Box<dyn std::error::Error>`); no methods are overridden.
impl std::error::Error for GraphError {}
20+
21+
/// Trait for handling roles in graphs (similar to Python mixin).
22+
pub trait GraphRoles: Clone {
23+
/// Check if a node exists in the graph.
24+
fn has_node(&self, node: &str) -> bool;
25+
26+
/// Get immutable reference to the roles map.
27+
fn get_roles_map(&self) -> &HashMap<String, HashSet<String>>;
28+
29+
/// Get mutable reference to the roles map.
30+
fn get_roles_map_mut(&mut self) -> &mut HashMap<String, HashSet<String>>;
31+
32+
/// Get nodes with a specific role.
33+
fn get_role(&self, role: &str) -> Vec<String> {
34+
self.get_roles_map()
35+
.get(role)
36+
.cloned()
37+
.unwrap_or_default()
38+
.into_iter()
39+
.collect()
40+
}
41+
42+
/// Get list of all roles.
43+
fn get_roles(&self) -> Vec<String> {
44+
self.get_roles_map().keys().cloned().collect()
45+
}
46+
47+
/// Get dict of roles to nodes.
48+
fn get_role_dict(&self) -> HashMap<String, Vec<String>> {
49+
self.get_roles_map()
50+
.iter()
51+
.map(|(k, v)| (k.clone(), v.iter().cloned().collect()))
52+
.collect()
53+
}
54+
55+
/// Check if a role exists and has nodes.
56+
fn has_role(&self, role: &str) -> bool {
57+
self.get_roles_map()
58+
.get(role)
59+
.map(|set| !set.is_empty())
60+
.unwrap_or(false)
61+
}
62+
63+
/// Assign role to variables. Modifies in place if `inplace=true`, otherwise returns a new graph.
64+
fn with_role(&mut self, role: String, variables: Vec<String>, inplace: bool) -> Result<Self, GraphError> {
65+
if inplace {
66+
// Modify self directly
67+
for var in &variables {
68+
if !self.has_node(var) {
69+
return Err(GraphError::NodeNotFound(var.clone()));
70+
}
71+
}
72+
let roles_map = self.get_roles_map_mut();
73+
let entry = roles_map.entry(role).or_insert(HashSet::new());
74+
for var in variables {
75+
entry.insert(var);
76+
}
77+
Ok(self.clone()) // Return self.clone() for consistency, but self is modified
78+
} else {
79+
// Create and modify a new graph
80+
let mut new_graph = self.clone();
81+
for var in &variables {
82+
if !new_graph.has_node(var) {
83+
return Err(GraphError::NodeNotFound(var.clone()));
84+
}
85+
}
86+
let roles_map = new_graph.get_roles_map_mut();
87+
let entry = roles_map.entry(role).or_insert(HashSet::new());
88+
for var in variables {
89+
entry.insert(var);
90+
}
91+
Ok(new_graph)
92+
}
93+
}
94+
95+
/// Remove role from variables (or all if None). Modifies in place if `inplace=true`, otherwise returns a new graph.
96+
fn without_role(&mut self, role: &str, variables: Option<Vec<String>>, inplace: bool) -> Self {
97+
if inplace {
98+
if let Some(set) = self.get_roles_map_mut().get_mut(role) {
99+
if let Some(vars) = variables {
100+
for var in vars {
101+
set.remove(&var);
102+
}
103+
} else {
104+
set.clear();
105+
}
106+
}
107+
self.clone() // Return self.clone() for consistency
108+
} else {
109+
let mut new_graph = self.clone();
110+
if let Some(set) = new_graph.get_roles_map_mut().get_mut(role) {
111+
if let Some(vars) = variables {
112+
for var in vars {
113+
set.remove(&var);
114+
}
115+
} else {
116+
set.clear();
117+
}
118+
}
119+
new_graph
120+
}
121+
}
122+
123+
/// Validate causal structure (has exposure and outcome).
124+
fn is_valid_causal_structure(&self) -> Result<bool, GraphError> {
125+
let has_exposure = self.has_role("exposure");
126+
let has_outcome = self.has_role("outcome");
127+
if !has_exposure || !has_outcome {
128+
let mut problems = Vec::new();
129+
if !has_exposure {
130+
problems.push("no 'exposure' role was defined");
131+
}
132+
if !has_outcome {
133+
problems.push("no 'outcome' role was defined");
134+
}
135+
return Err(GraphError::InvalidOperation(problems.join(", and ")));
136+
}
137+
Ok(true)
138+
}
139+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
use crate::graph::Graph;
2+
use crate::graph_role::{GraphError, GraphRoles};
3+
4+
/// Trait for causal identification algorithms, mirroring Python's BaseIdentification.
5+
pub trait BaseIdentification {
6+
/// Internal identification method to be implemented by specific algorithms.
7+
fn _identify<T: Graph + GraphRoles>(
8+
&self,
9+
causal_graph: &T,
10+
) -> Result<(T, bool), GraphError>;
11+
12+
/// Run the identification algorithm on a causal graph.
13+
fn identify<T: Graph + GraphRoles>(
14+
&self,
15+
causal_graph: &T,
16+
) -> Result<(T, bool), GraphError> {
17+
causal_graph.is_valid_causal_structure()?;
18+
self._identify(causal_graph)
19+
}
20+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
pub mod base;
2+
3+
pub use base::BaseIdentification;

rust_core/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
// Re-export modules/structs from your core logic
22
pub mod dag;
33
pub mod independencies;
4+
pub mod identification;
45
pub mod pdag; // Add PDAG.rs later if needed
6+
pub mod graph_role;
7+
pub mod graph;
58

69
pub use dag::RustDAG;
710
pub use pdag::RustPDAG;

0 commit comments

Comments
 (0)