@@ -15,7 +15,7 @@ class InterventionMixin:
1515 nodes : NodeView
1616
1717 @abstractmethod
18- def add_edge (self , u , v , edge_type : Optional [ str ] ):
18+ def add_edge (self , u_of_edge , v_of_edge , edge_type = "all" , ** attr ):
1919 pass
2020
2121 @abstractmethod
@@ -82,6 +82,15 @@ def add_f_nodes_from(self, intervention_sets: List[Set[Node]]):
8282 for intervention_set in intervention_sets :
8383 self .add_f_node (intervention_set )
8484
85+ def set_f_node (self , f_node , targets : Optional [Set ] = None ):
86+ if f_node not in self .nodes :
87+ raise RuntimeError (f"{ f_node } is not a node in the existing graph." )
88+
89+ if targets is not None and not all (target in self .nodes for target in targets ):
90+ raise RuntimeError (f"Not all targets { targets } are in the existing graph." )
91+
92+ self .graph ["F-nodes" ][f_node ] = targets
93+
8594 @property
8695 def f_nodes (self ) -> Set [Node ]:
8796 """Return set of F-nodes."""
@@ -139,6 +148,7 @@ class AugmentedGraph(ADMG, InterventionMixin):
139148 networkx.DiGraph
140149 networkx.Graph
141150 ADMG
151+ pywhy_graphs.networkx.MixedEdgeGraph
142152
143153 Notes
144154 -----
@@ -190,20 +200,6 @@ def remove_node(self, n):
190200 del self .graph ["F-nodes" ][n ]
191201 return super ().remove_node (n )
192202
193- def add_edge (self , u_of_edge , v_of_edge , edge_type = "all" , ** attr ):
194- if u_of_edge in self .f_nodes or v_of_edge in self .f_nodes :
195- raise RuntimeError ("Adding edges to F-nodes is not allowed." )
196- return super ().add_edge (u_of_edge , v_of_edge , edge_type , ** attr )
197-
198- def remove_edge (self , u , v , edge_type = "all" ):
199- if u in self .f_nodes or v in self .f_nodes :
200- raise RuntimeError (
201- "Removing edges from F-nodes is not allowed. "
202- "Please just call `remove_node` to remove the F-node "
203- "and its corresponding edges"
204- )
205- return super ().remove_edge (u , v , edge_type )
206-
207203
208204class IPAG (PAG , InterventionMixin ):
209205 """A I-PAG Markov equivalence class of causal graphs.
@@ -305,20 +301,6 @@ def remove_node(self, n):
305301 del self .graph ["F-nodes" ][n ]
306302 return super ().remove_node (n )
307303
308- def add_edge (self , u_of_edge , v_of_edge , edge_type = "all" , ** attr ):
309- if u_of_edge in self .f_nodes or v_of_edge in self .f_nodes :
310- raise RuntimeError ("Adding edges to F-nodes is not allowed." )
311- return super ().add_edge (u_of_edge , v_of_edge , edge_type , ** attr )
312-
313- def remove_edge (self , u , v , edge_type = "all" ):
314- if u in self .f_nodes or v in self .f_nodes :
315- raise RuntimeError (
316- "Removing edges from F-nodes is not allowed. "
317- "Please just call `remove_node` to remove the F-node "
318- "and its corresponding edges"
319- )
320- return super ().remove_edge (u , v , edge_type )
321-
322304
323305class PsiPAG (PAG , InterventionMixin ):
324306 """A Psi-PAG Markov equivalence class of causal graphs.
@@ -416,17 +398,3 @@ def remove_node(self, n):
416398 if n in self .f_nodes :
417399 del self .graph ["F-nodes" ][n ]
418400 return super ().remove_node (n )
419-
420- def add_edge (self , u_of_edge , v_of_edge , edge_type = "all" , ** attr ):
421- if u_of_edge in self .f_nodes or v_of_edge in self .f_nodes :
422- raise RuntimeError ("Adding edges to F-nodes is not allowed." )
423- return super ().add_edge (u_of_edge , v_of_edge , edge_type , ** attr )
424-
425- def remove_edge (self , u , v , edge_type = "all" ):
426- if u in self .f_nodes or v in self .f_nodes :
427- raise RuntimeError (
428- "Removing edges from F-nodes is not allowed. "
429- "Please just call `remove_node` to remove the F-node "
430- "and its corresponding edges"
431- )
432- return super ().remove_edge (u , v , edge_type )
0 commit comments