1+ use std:: collections:: { HashMap , HashSet } ;
2+
3+ /// Custom error type for graph operations.
4+ #[ derive( Debug ) ]
5+ pub enum GraphError {
6+ NodeNotFound ( String ) ,
7+ InvalidOperation ( String ) ,
8+ }
9+
10+ impl std:: fmt:: Display for GraphError {
11+ fn fmt ( & self , f : & mut std:: fmt:: Formatter ) -> std:: fmt:: Result {
12+ match self {
13+ GraphError :: NodeNotFound ( node) => write ! ( f, "Node '{}' not found in the graph" , node) ,
14+ GraphError :: InvalidOperation ( msg) => write ! ( f, "Invalid operation: {}" , msg) ,
15+ }
16+ }
17+ }
18+
19+ impl std:: error:: Error for GraphError { }
20+
21+ /// Trait for handling roles in graphs (similar to Python mixin).
22+ pub trait GraphRoles : Clone {
23+ /// Check if a node exists in the graph.
24+ fn has_node ( & self , node : & str ) -> bool ;
25+
26+ /// Get immutable reference to the roles map.
27+ fn get_roles_map ( & self ) -> & HashMap < String , HashSet < String > > ;
28+
29+ /// Get mutable reference to the roles map.
30+ fn get_roles_map_mut ( & mut self ) -> & mut HashMap < String , HashSet < String > > ;
31+
32+ /// Get nodes with a specific role.
33+ fn get_role ( & self , role : & str ) -> Vec < String > {
34+ self . get_roles_map ( )
35+ . get ( role)
36+ . cloned ( )
37+ . unwrap_or_default ( )
38+ . into_iter ( )
39+ . collect ( )
40+ }
41+
42+ /// Get list of all roles.
43+ fn get_roles ( & self ) -> Vec < String > {
44+ self . get_roles_map ( ) . keys ( ) . cloned ( ) . collect ( )
45+ }
46+
47+ /// Get dict of roles to nodes.
48+ fn get_role_dict ( & self ) -> HashMap < String , Vec < String > > {
49+ self . get_roles_map ( )
50+ . iter ( )
51+ . map ( |( k, v) | ( k. clone ( ) , v. iter ( ) . cloned ( ) . collect ( ) ) )
52+ . collect ( )
53+ }
54+
55+ /// Check if a role exists and has nodes.
56+ fn has_role ( & self , role : & str ) -> bool {
57+ self . get_roles_map ( )
58+ . get ( role)
59+ . map ( |set| !set. is_empty ( ) )
60+ . unwrap_or ( false )
61+ }
62+
63+ /// Assign role to variables. Modifies in place if `inplace=true`, otherwise returns a new graph.
64+ fn with_role ( & mut self , role : String , variables : Vec < String > , inplace : bool ) -> Result < Self , GraphError > {
65+ if inplace {
66+ // Modify self directly
67+ for var in & variables {
68+ if !self . has_node ( var) {
69+ return Err ( GraphError :: NodeNotFound ( var. clone ( ) ) ) ;
70+ }
71+ }
72+ let roles_map = self . get_roles_map_mut ( ) ;
73+ let entry = roles_map. entry ( role) . or_insert ( HashSet :: new ( ) ) ;
74+ for var in variables {
75+ entry. insert ( var) ;
76+ }
77+ Ok ( self . clone ( ) ) // Return self.clone() for consistency, but self is modified
78+ } else {
79+ // Create and modify a new graph
80+ let mut new_graph = self . clone ( ) ;
81+ for var in & variables {
82+ if !new_graph. has_node ( var) {
83+ return Err ( GraphError :: NodeNotFound ( var. clone ( ) ) ) ;
84+ }
85+ }
86+ let roles_map = new_graph. get_roles_map_mut ( ) ;
87+ let entry = roles_map. entry ( role) . or_insert ( HashSet :: new ( ) ) ;
88+ for var in variables {
89+ entry. insert ( var) ;
90+ }
91+ Ok ( new_graph)
92+ }
93+ }
94+
95+ /// Remove role from variables (or all if None). Modifies in place if `inplace=true`, otherwise returns a new graph.
96+ fn without_role ( & mut self , role : & str , variables : Option < Vec < String > > , inplace : bool ) -> Self {
97+ if inplace {
98+ if let Some ( set) = self . get_roles_map_mut ( ) . get_mut ( role) {
99+ if let Some ( vars) = variables {
100+ for var in vars {
101+ set. remove ( & var) ;
102+ }
103+ } else {
104+ set. clear ( ) ;
105+ }
106+ }
107+ self . clone ( ) // Return self.clone() for consistency
108+ } else {
109+ let mut new_graph = self . clone ( ) ;
110+ if let Some ( set) = new_graph. get_roles_map_mut ( ) . get_mut ( role) {
111+ if let Some ( vars) = variables {
112+ for var in vars {
113+ set. remove ( & var) ;
114+ }
115+ } else {
116+ set. clear ( ) ;
117+ }
118+ }
119+ new_graph
120+ }
121+ }
122+
123+ /// Validate causal structure (has exposure and outcome).
124+ fn is_valid_causal_structure ( & self ) -> Result < bool , GraphError > {
125+ let has_exposure = self . has_role ( "exposure" ) ;
126+ let has_outcome = self . has_role ( "outcome" ) ;
127+ if !has_exposure || !has_outcome {
128+ let mut problems = Vec :: new ( ) ;
129+ if !has_exposure {
130+ problems. push ( "no 'exposure' role was defined" ) ;
131+ }
132+ if !has_outcome {
133+ problems. push ( "no 'outcome' role was defined" ) ;
134+ }
135+ return Err ( GraphError :: InvalidOperation ( problems. join ( ", and " ) ) ) ;
136+ }
137+ Ok ( true )
138+ }
139+ }
0 commit comments