-
Notifications
You must be signed in to change notification settings - Fork 231
Add orientation rules 567 for Augmented FCI #194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
285c6c2
491560e
f52fa6b
5764c6f
a944e26
53ddcc0
65f6ad3
20b653b
0e076fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,19 @@ | |
| from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge | ||
| from itertools import combinations | ||
|
|
||
| def is_uncovered_path(nodes: List[Node], G: Graph) -> bool: | ||
| """ | ||
| Determines whether the given path is an uncovered path in this graph. | ||
|
|
||
| A path is an uncovered path if no two nonconsecutive nodes (Vi-1 and Vi+1) in the path are | ||
| adjacent. | ||
| """ | ||
| for i in range(len(nodes) - 2): | ||
| if G.is_adjacent_to(nodes[i], nodes[i + 2]): | ||
| return False | ||
| return True | ||
|
|
||
|
|
||
| def traverseSemiDirected(node: Node, edge: Edge) -> Node | None: | ||
| if node == edge.get_node1(): | ||
| if edge.get_endpoint1() == Endpoint.TAIL or edge.get_endpoint1() == Endpoint.CIRCLE: | ||
|
|
@@ -26,8 +39,17 @@ def traverseSemiDirected(node: Node, edge: Edge) -> Node | None: | |
| return edge.get_node1() | ||
| return None | ||
|
|
||
| def traverseCircle(node: Node, edge: Edge) -> Node | None: | ||
| if node == edge.get_node1(): | ||
| if edge.get_endpoint1() == Endpoint.CIRCLE and edge.get_endpoint2() == Endpoint.CIRCLE: | ||
| return edge.get_node2() | ||
| elif node == edge.get_node2(): | ||
| if edge.get_endpoint1() == Endpoint.CIRCLE or edge.get_endpoint2() == Endpoint.CIRCLE: | ||
| return edge.get_node1() | ||
| return None | ||
|
|
||
|
|
||
| def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool: | ||
| def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool: ## TODO: Now it does not detect whether the path is an uncovered path | ||
| Q = Queue() | ||
| V = set() | ||
|
|
||
|
|
@@ -60,6 +82,41 @@ def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool: | |
|
|
||
| return False | ||
|
|
||
| def GetUncoveredCirclePath(node_from: Node, node_to: Node, G: Graph) -> List[Node] | None: | ||
| Q = Queue() | ||
| V = set() | ||
|
|
||
| path = [node_from] | ||
|
|
||
| for node_u in G.get_adjacent_nodes(node_from): | ||
| edge = G.get_edge(node_from, node_u) | ||
| node_c = traverseCircle(node_from, edge) | ||
|
|
||
| if node_c is None: | ||
| continue | ||
|
|
||
| if not V.__contains__(node_c): | ||
| V.add(node_c) | ||
| Q.put((node_c, path + [node_c])) | ||
|
|
||
| while not Q.empty(): | ||
| node_t, path = Q.get_nowait() | ||
| if node_t == node_to and is_uncovered_path(path, G): | ||
| return path | ||
|
|
||
| for node_u in G.get_adjacent_nodes(node_t): | ||
| edge = G.get_edge(node_t, node_u) | ||
| node_c = traverseCircle(node_t, edge) | ||
|
|
||
| if node_c is None: | ||
| continue | ||
|
|
||
| if not V.__contains__(node_c): | ||
| V.add(node_c) | ||
| Q.put((node_c, path + [node_c])) | ||
|
|
||
| return None | ||
|
|
||
|
|
||
| def existOnePathWithPossibleParents(previous, node_w: Node, node_x: Node, node_b: Node, graph: Graph) -> bool: | ||
| if node_w == node_x: | ||
|
|
@@ -371,6 +428,126 @@ def ruleR3(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], bk: Backgrou | |
| changeFlag = True | ||
| return changeFlag | ||
|
|
||
| def ruleR5(graph: Graph, changeFlag: bool, | ||
| verbose: bool = False) -> bool: | ||
| """ | ||
| Rule R5 of the FCI algorithm. | ||
| by Jiji Zhang, 2008, "On the completeness of orientation rules for causal discovery in the presence of latent confounders and selection bias"] | ||
|
|
||
| This function orients any edge that is part of an uncovered circle path between two nodes A and B, | ||
| if such a path exists. The path must start and end with a circle edge and must be uncovered, i.e. the | ||
| nodes on the path must not be adjacent to A or B. The orientation of the edges on the path is set to | ||
| double tail. | ||
| """ | ||
| nodes = graph.get_nodes() | ||
| for node_B in nodes: | ||
| intoBCircles = graph.get_nodes_into(node_B, Endpoint.CIRCLE) | ||
|
|
||
| for node_A in intoBCircles: | ||
| if graph.get_endpoint(node_B, node_A) != Endpoint.CIRCLE: | ||
| continue | ||
| else: | ||
| # Check if there is an uncovered circle path between A and B (A o-o C .. D o-o B) | ||
| # s.t. A is not adjacent to D and B is not adjacent to C | ||
| a_node_idx = graph.node_map[node_A] | ||
| b_node_idx = graph.node_map[node_B] | ||
| a_adj_nodes = graph.get_adjacent_nodes(node_A) | ||
| b_adj_nodes = graph.get_adjacent_nodes(node_B) | ||
|
|
||
| # get the adjacent nodes with circle edges of A and B | ||
| a_circle_adj_nodes_set = [node for node in a_adj_nodes if graph.node_map[node] != a_node_idx and graph.node_map[node]!= b_node_idx | ||
| and graph.get_endpoint(node, node_A) == Endpoint.CIRCLE and graph.get_endpoint(node_A, node) == Endpoint.CIRCLE] | ||
| b_circle_adj_nodes_set = [node for node in b_adj_nodes if graph.node_map[node] != a_node_idx and graph.node_map[node]!= b_node_idx | ||
| and graph.get_endpoint(node, node_B) == Endpoint.CIRCLE and graph.get_endpoint(node_B, node) == Endpoint.CIRCLE] | ||
|
|
||
| # get the adjacent nodes with circle edges of A and B that is non adjacent to B and A, respectively | ||
| for node_C in a_circle_adj_nodes_set: | ||
| if graph.is_adjacent_to(node_B, node_C): | ||
| continue | ||
| for node_D in b_circle_adj_nodes_set: | ||
| if graph.is_adjacent_to(node_A, node_D): | ||
| continue | ||
| path = GetUncoveredCirclePath(node_from=node_C, node_to=node_D, G=graph) | ||
|
||
| if path is not None: | ||
|
||
| # pdb.set_trace() | ||
| changeFlag = True | ||
| # orient A - C, D - B | ||
| edge = graph.get_edge(node_A, path[0]) | ||
| graph.remove_edge(edge) | ||
| graph.add_edge(Edge(node_A, path[0], Endpoint.TAIL, Endpoint.TAIL)) | ||
|
|
||
| edge = graph.get_edge(node_B, path[-1]) | ||
| graph.remove_edge(edge) | ||
| graph.add_edge(Edge(node_B, path[-1], Endpoint.TAIL, Endpoint.TAIL)) | ||
| if verbose: | ||
| print("Orienting edge (Double tail): " + graph.get_edge(node_A, path[0]).__str__()) | ||
| print("Orienting edge (Double tail): " + graph.get_edge(node_B, path[-1]).__str__()) | ||
|
|
||
| # orient everything on the path to both tails | ||
| for i in range(len(path) - 1): | ||
| edge = graph.get_edge(path[i], path[i + 1]) | ||
| graph.remove_edge(edge) | ||
| graph.add_edge(Edge(path[i], path[i + 1], Endpoint.TAIL, Endpoint.TAIL)) | ||
| if verbose: | ||
| print("Orienting edge (Double tail): " + graph.get_edge(path[i], path[i + 1]).__str__()) | ||
|
|
||
| # if not is_arrow_point_allowed(node_A, path[0], graph, bk): | ||
| # break | ||
| # if not is_arrow_point_allowed(node_B, path[-1], graph, bk): | ||
| # break | ||
|
|
||
| continue | ||
|
|
||
| return changeFlag | ||
|
|
||
| def ruleR6(graph: Graph, changeFlag: bool, | ||
| verbose: bool = False) -> bool: | ||
| nodes = graph.get_nodes() | ||
|
|
||
| for node_B in nodes: | ||
| # Find A - B | ||
| intoBTails = graph.get_nodes_into(node_B, Endpoint.TAIL) | ||
| exist = False | ||
| for node_A in intoBTails: | ||
| if graph.get_endpoint(node_B, node_A) == Endpoint.TAIL: | ||
| exist = True | ||
| if not exist: | ||
| continue | ||
| # Find B o-*C | ||
| intoBCircles = graph.get_nodes_into(node_B, Endpoint.CIRCLE) | ||
| for node_C in intoBCircles: | ||
| changeFlag = True | ||
| edge = graph.get_edge(node_B, node_C) | ||
| graph.remove_edge(edge) | ||
| graph.add_edge(Edge(node_B, node_C, Endpoint.TAIL, edge.get_proximal_endpoint(node_C))) | ||
| if verbose: | ||
| print("Orienting edge by rule 6): " + graph.get_edge(node_B, node_C).__str__()) | ||
|
|
||
| return changeFlag | ||
|
|
||
|
|
||
| def ruleR7(graph: Graph, changeFlag: bool, | ||
| verbose: bool = False) -> bool: | ||
| nodes = graph.get_nodes() | ||
|
|
||
| for node_B in nodes: | ||
| # Find A -o B | ||
| intoBCircles = graph.get_nodes_into(node_B, Endpoint.CIRCLE) | ||
| node_A_list = [node for node in intoBCircles if graph.get_endpoint(node_B, node) == Endpoint.TAIL] | ||
|
|
||
| # Find B o-*C | ||
| for node_C in intoBCircles: | ||
| # pdb.set_trace() | ||
| for node_A in node_A_list: | ||
| # pdb.set_trace() | ||
| if not graph.is_adjacent_to(node_A, node_C): | ||
| changeFlag = True | ||
| edge = graph.get_edge(node_B, node_C) | ||
| graph.remove_edge(edge) | ||
| graph.add_edge(Edge(node_B, node_C, Endpoint.TAIL, edge.get_proximal_endpoint(node_C))) | ||
| if verbose: | ||
| print("Orienting edge by rule 7): " + graph.get_edge(node_B, node_C).__str__()) | ||
| return changeFlag | ||
|
|
||
| def getPath(node_c: Node, previous) -> List[Node]: | ||
| l = [] | ||
|
|
@@ -895,6 +1072,7 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = | |
| graph, sep_sets, test_results = fas(dataset, nodes, independence_test_method=independence_test_method, alpha=alpha, | ||
| knowledge=background_knowledge, depth=depth, verbose=verbose, show_progress=show_progress) | ||
|
|
||
| # pdb.set_trace() | ||
| reorientAllWith(graph, Endpoint.CIRCLE) | ||
|
|
||
| rule0(graph, nodes, sep_sets, background_knowledge, verbose) | ||
|
|
@@ -925,8 +1103,18 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = | |
| if verbose: | ||
| print("Epoch") | ||
|
|
||
| # rule 5 | ||
| change_flag = ruleR5(graph, change_flag, verbose) | ||
|
|
||
| # rule 6 | ||
| change_flag = ruleR6(graph, change_flag, verbose) | ||
|
|
||
| # rule 7 | ||
| change_flag = ruleR7(graph, change_flag, verbose) | ||
|
|
||
| # rule 8 | ||
| change_flag = rule8(graph,nodes) | ||
|
|
||
| # rule 9 | ||
| change_flag = rule9(graph, nodes) | ||
| # rule 10 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should the if statement condition here be an AND condition?