22
33import warnings
44from queue import Queue
5- from typing import List , Set , Tuple , Dict
5+ from typing import List , Set , Tuple , Dict , Generator
66from numpy import ndarray
77
88from causallearn .graph .Edge import Edge
@@ -44,7 +44,7 @@ def traverseCircle(node: Node, edge: Edge) -> Node | None:
4444 if edge .get_endpoint1 () == Endpoint .CIRCLE and edge .get_endpoint2 () == Endpoint .CIRCLE :
4545 return edge .get_node2 ()
4646 elif node == edge .get_node2 ():
47- if edge .get_endpoint1 () == Endpoint .CIRCLE or edge .get_endpoint2 () == Endpoint .CIRCLE :
47+ if edge .get_endpoint1 () == Endpoint .CIRCLE and edge .get_endpoint2 () == Endpoint .CIRCLE :
4848 return edge .get_node1 ()
4949 return None
5050
@@ -82,7 +82,7 @@ def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool: ##
8282
8383 return False
8484
85- def GetUncoveredCirclePath (node_from : Node , node_to : Node , G : Graph ) -> List [Node ] | None :
85+ def GetUncoveredCirclePath (node_from : Node , node_to : Node , G : Graph ) -> Generator [Node ] | None :
8686 Q = Queue ()
8787 V = set ()
8888
@@ -102,7 +102,7 @@ def GetUncoveredCirclePath(node_from: Node, node_to: Node, G: Graph) -> List[Nod
102102 while not Q .empty ():
103103 node_t , path = Q .get_nowait ()
104104 if node_t == node_to and is_uncovered_path (path , G ):
105- return path
105+ yield path
106106
107107 for node_u in G .get_adjacent_nodes (node_t ):
108108 edge = G .get_edge (node_t , node_u )
@@ -467,35 +467,31 @@ def ruleR5(graph: Graph, changeFlag: bool,
467467 for node_D in b_circle_adj_nodes_set :
468468 if graph .is_adjacent_to (node_A , node_D ):
469469 continue
470- path = GetUncoveredCirclePath (node_from = node_C , node_to = node_D , G = graph )
471- if path is not None :
472- # pdb.set_trace()
470+ paths = GetUncoveredCirclePath (node_from = node_C , node_to = node_D , G = graph )
471+ if paths is not None :
472+ # Mark the change if find at least one path
473473 changeFlag = True
474- # orient A - C, D - B
475- edge = graph .get_edge (node_A , path [0 ])
476- graph .remove_edge (edge )
477- graph .add_edge (Edge (node_A , path [0 ], Endpoint .TAIL , Endpoint .TAIL ))
478-
479- edge = graph .get_edge (node_B , path [- 1 ])
480- graph .remove_edge (edge )
481- graph .add_edge (Edge (node_B , path [- 1 ], Endpoint .TAIL , Endpoint .TAIL ))
482- if verbose :
483- print ("Orienting edge (Double tail): " + graph .get_edge (node_A , path [0 ]).__str__ ())
484- print ("Orienting edge (Double tail): " + graph .get_edge (node_B , path [- 1 ]).__str__ ())
485-
486- # orient everything on the path to both tails
487- for i in range (len (path ) - 1 ):
488- edge = graph .get_edge (path [i ], path [i + 1 ])
474+ for path in paths :
475+ # orient A - C, D - B
476+ edge = graph .get_edge (node_A , path [0 ])
489477 graph .remove_edge (edge )
490- graph .add_edge (Edge (path [i ], path [i + 1 ], Endpoint .TAIL , Endpoint .TAIL ))
491- if verbose :
492- print ("Orienting edge (Double tail): " + graph .get_edge (path [i ], path [i + 1 ]).__str__ ())
493-
494- # if not is_arrow_point_allowed(node_A, path[0], graph, bk):
495- # break
496- # if not is_arrow_point_allowed(node_B, path[-1], graph, bk):
497- # break
478+ graph .add_edge (Edge (node_A , path [0 ], Endpoint .TAIL , Endpoint .TAIL ))
498479
480+ edge = graph .get_edge (node_B , path [- 1 ])
481+ graph .remove_edge (edge )
482+ graph .add_edge (Edge (node_B , path [- 1 ], Endpoint .TAIL , Endpoint .TAIL ))
483+ if verbose :
484+ print ("Orienting edge (Double tail): " + graph .get_edge (node_A , path [0 ]).__str__ ())
485+ print ("Orienting edge (Double tail): " + graph .get_edge (node_B , path [- 1 ]).__str__ ())
486+
487+ # orient everything on the path to both tails
488+ for i in range (len (path ) - 1 ):
489+ edge = graph .get_edge (path [i ], path [i + 1 ])
490+ graph .remove_edge (edge )
491+ graph .add_edge (Edge (path [i ], path [i + 1 ], Endpoint .TAIL , Endpoint .TAIL ))
492+ if verbose :
493+ print ("Orienting edge (Double tail): " + graph .get_edge (path [i ], path [i + 1 ]).__str__ ())
494+
499495 continue
500496
501497 return changeFlag
0 commit comments