@@ -64,6 +64,7 @@ def adjust_dpath(self, i: int, j: int):
6464 self .dpath = dpath
6565
6666 def reconstitute_dpath (self , edges : List [Edge ]):
67+ self .dpath = np .zeros ((self .num_vars , self .num_vars ), np .dtype (int ))
6768 for i in range (self .num_vars ):
6869 self .adjust_dpath (i , i )
6970
@@ -73,7 +74,11 @@ def reconstitute_dpath(self, edges: List[Edge]):
7374 node2 = edge .get_node2 ()
7475 i = self .node_map [node1 ]
7576 j = self .node_map [node2 ]
76- self .adjust_dpath (i , j )
77+ if self .is_parent_of (node1 , node2 ):
78+ self .adjust_dpath (i , j )
79+ elif self .is_parent_of (node2 , node1 ):
80+ self .adjust_dpath (j , i )
81+
7782
7883 def collect_ancestors (self , node : Node , ancestors : List [Node ]):
7984 if node in ancestors :
@@ -503,13 +508,13 @@ def is_ancestor_of(self, node1: Node, node2: Node) -> bool:
503508 def is_child_of (self , node1 : Node , node2 : Node ) -> bool :
504509 i = self .node_map [node1 ]
505510 j = self .node_map [node2 ]
506- return self .graph [i , j ] == 1 or self .graph [i , j ] == Endpoint .ARROW_AND_ARROW .value
511+ return self .graph [i , j ] == Endpoint . TAIL . value or self .graph [i , j ] == Endpoint .ARROW_AND_ARROW .value
507512
508513 # Returns true iff node1 is a parent of node2.
509514 def is_parent_of (self , node1 : Node , node2 : Node ) -> bool :
510515 i = self .node_map [node1 ]
511516 j = self .node_map [node2 ]
512- return self .graph [j , i ] == 1 or self .graph [j , i ] == Endpoint .ARROW_AND_ARROW .value
517+ return self .graph [j , i ] == Endpoint . ARROW . value and self .graph [i , j ] == Endpoint .TAIL .value
513518
514519 # Returns true iff node1 is a proper ancestor of node2.
515520 def is_proper_ancestor_of (self , node1 : Node , node2 : Node ) -> bool :
@@ -521,9 +526,7 @@ def is_proper_descendant_of(self, node1: Node, node2: Node) -> bool:
521526
522527 # Returns true iff node1 is a descendant of node2.
523528 def is_descendant_of (self , node1 : Node , node2 : Node ) -> bool :
524- i = self .node_map [node1 ]
525- j = self .node_map [node2 ]
526- return self .dpath [i , j ] == 1
529+ return self .is_ancestor_of (node2 , node1 )
527530
528531 # Returns the edge connecting node1 and node2, provided a unique such edge exists.
529532 def get_edge (self , node1 : Node , node2 : Node ) -> Edge | None :
@@ -763,6 +766,8 @@ def remove_edge(self, edge: Edge):
763766 end1 = edge .get_numerical_endpoint1 ()
764767 end2 = edge .get_numerical_endpoint2 ()
765768
769+ is_fully_directed = self .is_parent_of (node1 , node2 ) or self .is_parent_of (node2 , node1 )
770+
766771 if out_of == Endpoint .TAIL_AND_ARROW .value and in_to == Endpoint .TAIL_AND_ARROW .value :
767772 if end1 == Endpoint .ARROW .value :
768773 self .graph [j , i ] = - 1
@@ -794,7 +799,8 @@ def remove_edge(self, edge: Edge):
794799 self .graph [j , i ] = 0
795800 self .graph [i , j ] = 0
796801
797- self .reconstitute_dpath (self .get_graph_edges ())
802+ if is_fully_directed :
803+ self .reconstitute_dpath (self .get_graph_edges ())
798804
799805 # Removes the edge connecting the given two nodes, provided there is exactly one such edge.
800806 def remove_connecting_edge (self , node1 : Node , node2 : Node ):
0 commit comments