Skip to content
190 changes: 189 additions & 1 deletion causallearn/search/ConstraintBased/FCI.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
from itertools import combinations

def is_uncovered_path(nodes: List[Node], G: Graph) -> bool:
"""
Determines whether the given path is an uncovered path in this graph.

A path is an uncovered path if no two nonconsecutive nodes (Vi-1 and Vi+1) in the path are
adjacent.
"""
for i in range(len(nodes) - 2):
if G.is_adjacent_to(nodes[i], nodes[i + 2]):
return False
return True


def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
if node == edge.get_node1():
if edge.get_endpoint1() == Endpoint.TAIL or edge.get_endpoint1() == Endpoint.CIRCLE:
Expand All @@ -26,8 +39,17 @@ def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
return edge.get_node1()
return None

def traverseCircle(node: Node, edge: Edge) -> Node | None:
if node == edge.get_node1():
if edge.get_endpoint1() == Endpoint.CIRCLE and edge.get_endpoint2() == Endpoint.CIRCLE:
return edge.get_node2()
elif node == edge.get_node2():
if edge.get_endpoint1() == Endpoint.CIRCLE or edge.get_endpoint2() == Endpoint.CIRCLE:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the if statement condition here be an AND condition?

return edge.get_node1()
return None


def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool:
def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool: ## TODO: Now it does not detect whether the path is an uncovered path
Q = Queue()
V = set()

Expand Down Expand Up @@ -60,6 +82,41 @@ def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool:

return False

def GetUncoveredCirclePath(node_from: Node, node_to: Node, G: Graph) -> List[Node] | None:
Q = Queue()
V = set()

path = [node_from]

for node_u in G.get_adjacent_nodes(node_from):
edge = G.get_edge(node_from, node_u)
node_c = traverseCircle(node_from, edge)

if node_c is None:
continue

if not V.__contains__(node_c):
V.add(node_c)
Q.put((node_c, path + [node_c]))

while not Q.empty():
node_t, path = Q.get_nowait()
if node_t == node_to and is_uncovered_path(path, G):
return path

for node_u in G.get_adjacent_nodes(node_t):
edge = G.get_edge(node_t, node_u)
node_c = traverseCircle(node_t, edge)

if node_c is None:
continue

if not V.__contains__(node_c):
V.add(node_c)
Q.put((node_c, path + [node_c]))

return None


def existOnePathWithPossibleParents(previous, node_w: Node, node_x: Node, node_b: Node, graph: Graph) -> bool:
if node_w == node_x:
Expand Down Expand Up @@ -371,6 +428,126 @@ def ruleR3(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], bk: Backgrou
changeFlag = True
return changeFlag

def ruleR5(graph: Graph, changeFlag: bool,
verbose: bool = False) -> bool:
"""
Rule R5 of the FCI algorithm.
by Jiji Zhang, 2008, "On the completeness of orientation rules for causal discovery in the presence of latent confounders and selection bias"]

This function orients any edge that is part of an uncovered circle path between two nodes A and B,
if such a path exists. The path must start and end with a circle edge and must be uncovered, i.e. the
nodes on the path must not be adjacent to A or B. The orientation of the edges on the path is set to
double tail.
"""
nodes = graph.get_nodes()
for node_B in nodes:
intoBCircles = graph.get_nodes_into(node_B, Endpoint.CIRCLE)

for node_A in intoBCircles:
if graph.get_endpoint(node_B, node_A) != Endpoint.CIRCLE:
continue
else:
# Check if there is an uncovered circle path between A and B (A o-o C .. D o-o B)
# s.t. A is not adjacent to D and B is not adjacent to C
a_node_idx = graph.node_map[node_A]
b_node_idx = graph.node_map[node_B]
a_adj_nodes = graph.get_adjacent_nodes(node_A)
b_adj_nodes = graph.get_adjacent_nodes(node_B)

# get the adjacent nodes with circle edges of A and B
a_circle_adj_nodes_set = [node for node in a_adj_nodes if graph.node_map[node] != a_node_idx and graph.node_map[node]!= b_node_idx
and graph.get_endpoint(node, node_A) == Endpoint.CIRCLE and graph.get_endpoint(node_A, node) == Endpoint.CIRCLE]
b_circle_adj_nodes_set = [node for node in b_adj_nodes if graph.node_map[node] != a_node_idx and graph.node_map[node]!= b_node_idx
and graph.get_endpoint(node, node_B) == Endpoint.CIRCLE and graph.get_endpoint(node_B, node) == Endpoint.CIRCLE]

# get the adjacent nodes with circle edges of A and B that is non adjacent to B and A, respectively
for node_C in a_circle_adj_nodes_set:
if graph.is_adjacent_to(node_B, node_C):
continue
for node_D in b_circle_adj_nodes_set:
if graph.is_adjacent_to(node_A, node_D):
continue
path = GetUncoveredCirclePath(node_from=node_C, node_to=node_D, G=graph)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will there be multiple uncovered circle paths? Perhaps we can change return to yield in line 105 to make it a generator.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the suggestion. This part is modified accordingly.

if path is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The edge between alpha and beta is not modified inside the if statement. Perhaps the edge between alpha and beta can be modified based on change_flag outside the if statement.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the edge between alpha and beta will only be modified "if there is an uncovered circle path p s.t. xxx", so then change_flag should still be kept inside the if statement?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I mean the edge between alpha and beta hasn't been modified. If change_flag is set to True, then there is an uncovered circle path between alpha and beta, and the edge between alpha and beta needs to be modified. Perhaps a change_flag can be set outside the for loop on line 443 to indicate that rule5 modifies graph G. And a local_change_flag can be set inside the for loop on line 446 to indicate that there is an uncovered circle path between alpha and beta, and the edge between alpha and beta can be modified according to the local_change_flag inside the for loop.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added the part to allow orienting multiple uncovered circle paths between alpha and beta, and moved the change_flag accordingly. Now it is changed only when double tails are oriented.

# pdb.set_trace()
changeFlag = True
# orient A - C, D - B
edge = graph.get_edge(node_A, path[0])
graph.remove_edge(edge)
graph.add_edge(Edge(node_A, path[0], Endpoint.TAIL, Endpoint.TAIL))

edge = graph.get_edge(node_B, path[-1])
graph.remove_edge(edge)
graph.add_edge(Edge(node_B, path[-1], Endpoint.TAIL, Endpoint.TAIL))
if verbose:
print("Orienting edge (Double tail): " + graph.get_edge(node_A, path[0]).__str__())
print("Orienting edge (Double tail): " + graph.get_edge(node_B, path[-1]).__str__())

# orient everything on the path to both tails
for i in range(len(path) - 1):
edge = graph.get_edge(path[i], path[i + 1])
graph.remove_edge(edge)
graph.add_edge(Edge(path[i], path[i + 1], Endpoint.TAIL, Endpoint.TAIL))
if verbose:
print("Orienting edge (Double tail): " + graph.get_edge(path[i], path[i + 1]).__str__())

# if not is_arrow_point_allowed(node_A, path[0], graph, bk):
# break
# if not is_arrow_point_allowed(node_B, path[-1], graph, bk):
# break

continue

return changeFlag

def ruleR6(graph: Graph, changeFlag: bool,
verbose: bool = False) -> bool:
nodes = graph.get_nodes()

for node_B in nodes:
# Find A - B
intoBTails = graph.get_nodes_into(node_B, Endpoint.TAIL)
exist = False
for node_A in intoBTails:
if graph.get_endpoint(node_B, node_A) == Endpoint.TAIL:
exist = True
if not exist:
continue
# Find B o-*C
intoBCircles = graph.get_nodes_into(node_B, Endpoint.CIRCLE)
for node_C in intoBCircles:
changeFlag = True
edge = graph.get_edge(node_B, node_C)
graph.remove_edge(edge)
graph.add_edge(Edge(node_B, node_C, Endpoint.TAIL, edge.get_proximal_endpoint(node_C)))
if verbose:
print("Orienting edge by rule 6): " + graph.get_edge(node_B, node_C).__str__())

return changeFlag


def ruleR7(graph: Graph, changeFlag: bool,
verbose: bool = False) -> bool:
nodes = graph.get_nodes()

for node_B in nodes:
# Find A -o B
intoBCircles = graph.get_nodes_into(node_B, Endpoint.CIRCLE)
node_A_list = [node for node in intoBCircles if graph.get_endpoint(node_B, node) == Endpoint.TAIL]

# Find B o-*C
for node_C in intoBCircles:
# pdb.set_trace()
for node_A in node_A_list:
# pdb.set_trace()
if not graph.is_adjacent_to(node_A, node_C):
changeFlag = True
edge = graph.get_edge(node_B, node_C)
graph.remove_edge(edge)
graph.add_edge(Edge(node_B, node_C, Endpoint.TAIL, edge.get_proximal_endpoint(node_C)))
if verbose:
print("Orienting edge by rule 7): " + graph.get_edge(node_B, node_C).__str__())
return changeFlag

def getPath(node_c: Node, previous) -> List[Node]:
l = []
Expand Down Expand Up @@ -895,6 +1072,7 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
graph, sep_sets, test_results = fas(dataset, nodes, independence_test_method=independence_test_method, alpha=alpha,
knowledge=background_knowledge, depth=depth, verbose=verbose, show_progress=show_progress)

# pdb.set_trace()
reorientAllWith(graph, Endpoint.CIRCLE)

rule0(graph, nodes, sep_sets, background_knowledge, verbose)
Expand Down Expand Up @@ -925,8 +1103,18 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
if verbose:
print("Epoch")

# rule 5
change_flag = ruleR5(graph, change_flag, verbose)

# rule 6
change_flag = ruleR6(graph, change_flag, verbose)

# rule 7
change_flag = ruleR7(graph, change_flag, verbose)

# rule 8
change_flag = rule8(graph,nodes)

# rule 9
change_flag = rule9(graph, nodes)
# rule 10
Expand Down