Skip to content

Commit 816baae

Browse files
committed
Revert "Add node_names"
This reverts commit c2210d0.
1 parent f521c3e commit 816baae

File tree

1 file changed

+149
-9
lines changed
  • causallearn/search/ConstraintBased

1 file changed

+149
-9
lines changed

causallearn/search/ConstraintBased/FCI.py

Lines changed: 149 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from causallearn.utils.cit import *
1616
from causallearn.utils.FAS import fas
1717
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
18-
18+
from itertools import combinations
1919

2020
def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
2121
if node == edge.get_node1():
@@ -320,8 +320,9 @@ def rulesR1R2cycle(graph: Graph, bk: BackgroundKnowledge | None, changeFlag: boo
320320

321321
def isNoncollider(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], node_i: Node, node_j: Node,
322322
node_k: Node) -> bool:
323-
sep_set = sep_sets[(graph.get_node_map()[node_i], graph.get_node_map()[node_k])]
324-
return sep_set is not None and sep_set.__contains__(graph.get_node_map()[node_j])
323+
node_map = graph.get_node_map()
324+
sep_set = sep_sets.get((node_map[node_i], node_map[node_k]))
325+
return sep_set is not None and sep_set.__contains__(node_map[node_j])
325326

326327

327328
def ruleR3(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], bk: BackgroundKnowledge | None, changeFlag: bool,
@@ -542,6 +543,142 @@ def ruleR4B(graph: Graph, maxPathLength: int, data: ndarray, independence_test_m
542543
return change_flag
543544

544545

546+
547+
def rule8(graph: Graph, nodes: List[Node]):
548+
nodes = graph.get_nodes()
549+
changeFlag = False
550+
for node_B in nodes:
551+
adj = graph.get_adjacent_nodes(node_B)
552+
if len(adj) < 2:
553+
continue
554+
555+
cg = ChoiceGenerator(len(adj), 2)
556+
combination = cg.next()
557+
558+
while combination is not None:
559+
node_A = adj[combination[0]]
560+
node_C = adj[combination[1]]
561+
combination = cg.next()
562+
563+
if(graph.get_endpoint(node_A, node_B) == Endpoint.ARROW and graph.get_endpoint(node_B, node_A) == Endpoint.TAIL and \
564+
graph.get_endpoint(node_B, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_B) == Endpoint.TAIL and \
565+
graph.is_adjacent_to(node_A, node_C) and \
566+
graph.get_endpoint(node_A, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_A)== Endpoint.CIRCLE) or \
567+
(graph.get_endpoint(node_A, node_B) == Endpoint.CIRCLE and graph.get_endpoint(node_B, node_A) == Endpoint.TAIL and \
568+
graph.get_endpoint(node_B, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_B) == Endpoint.TAIL and \
569+
graph.is_adjacent_to(node_A, node_C) and \
570+
graph.get_endpoint(node_A, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_A)== Endpoint.CIRCLE):
571+
edge1 = graph.get_edge(node_A, node_C)
572+
graph.remove_edge(edge1)
573+
graph.add_edge(Edge(node_A, node_C,Endpoint.TAIL, Endpoint.ARROW))
574+
changeFlag = True
575+
576+
return changeFlag
577+
578+
579+
580+
def is_possible_parent(graph: Graph, potential_parent_node, child_node):
581+
if graph.node_map[potential_parent_node] == graph.node_map[child_node]:
582+
return False
583+
if not graph.is_adjacent_to(potential_parent_node, child_node):
584+
return False
585+
586+
if graph.get_endpoint(child_node, potential_parent_node) == Endpoint.ARROW or \
587+
graph.get_endpoint(potential_parent_node, child_node) == Endpoint.TAIL:
588+
return False
589+
else:
590+
return True
591+
592+
593+
def find_possible_children(graph: Graph, parent_node, en_nodes=None):
594+
if en_nodes is None:
595+
nodes = graph.get_nodes()
596+
en_nodes = [node for node in nodes if graph.node_map[node] != graph.node_map[parent_node]]
597+
598+
potential_child_nodes = set()
599+
for potential_node in en_nodes:
600+
if is_possible_parent(graph, potential_parent_node=parent_node, child_node=potential_node):
601+
potential_child_nodes.add(potential_node)
602+
603+
return potential_child_nodes
604+
605+
def rule9(graph: Graph, nodes: List[Node]):
606+
changeFlag = False
607+
nodes = graph.get_nodes()
608+
for node_C in nodes:
609+
intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
610+
for node_A in intoCArrows:
611+
# we want A o--> C
612+
if not graph.get_endpoint(node_C, node_A) == Endpoint.CIRCLE:
613+
continue
614+
615+
# look for a possibly directed uncovered path s.t. B and C are not connected (for the given A o--> C
616+
a_node_idx = graph.node_map[node_A]
617+
c_node_idx = graph.node_map[node_C]
618+
a_adj_nodes = graph.get_adjacent_nodes(node_A)
619+
nodes_set = [node for node in a_adj_nodes if graph.node_map[node] != a_node_idx and graph.node_map[node]!= c_node_idx]
620+
possible_children = find_possible_children(graph, node_A, nodes_set)
621+
for node_B in possible_children:
622+
if graph.is_adjacent_to(node_B, node_C):
623+
continue
624+
if existsSemiDirectedPath(node_from=node_B, node_to=node_C, G=graph):
625+
edge1 = graph.get_edge(node_A, node_C)
626+
graph.remove_edge(edge1)
627+
graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
628+
changeFlag = True
629+
break #once we found it, break out since we have already oriented Ao->C to A->C, we want to find the next A
630+
return changeFlag
631+
632+
633+
def rule10(graph: Graph):
634+
changeFlag = False
635+
nodes = graph.get_nodes()
636+
for node_C in nodes:
637+
intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
638+
if len(intoCArrows) < 2:
639+
continue
640+
# get all A where A o-> C
641+
Anodes = [node_A for node_A in intoCArrows if graph.get_endpoint(node_C, node_A) == Endpoint.CIRCLE]
642+
if len(Anodes) == 0:
643+
continue
644+
645+
for node_A in Anodes:
646+
A_adj_nodes = graph.get_adjacent_nodes(node_A)
647+
en_nodes = [i for i in A_adj_nodes if i is not node_C]
648+
A_possible_children = find_possible_children(graph, parent_node=node_A, en_nodes=en_nodes)
649+
if len(A_possible_children) < 2:
650+
continue
651+
652+
gen = ChoiceGenerator(len(intoCArrows), 2)
653+
choice = gen.next()
654+
while choice is not None:
655+
node_B = intoCArrows[choice[0]]
656+
node_D = intoCArrows[choice[1]]
657+
658+
choice = gen.next()
659+
# we want B->C<-D
660+
if graph.get_endpoint(node_C, node_B) != Endpoint.TAIL:
661+
continue
662+
663+
if graph.get_endpoint(node_C, node_D) != Endpoint.TAIL:
664+
continue
665+
666+
for children in combinations(A_possible_children, 2):
667+
child_one, child_two = children
668+
if not existsSemiDirectedPath(node_from=child_one, node_to=node_B, G=graph) or \
669+
not existsSemiDirectedPath(node_from=child_two, node_to=node_D, G=graph):
670+
continue
671+
672+
if not graph.is_adjacent_to(child_one, child_two):
673+
edge1 = graph.get_edge(node_A, node_C)
674+
graph.remove_edge(edge1)
675+
graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
676+
changeFlag = True
677+
break #once we found it, break out since we have already oriented Ao->C to A->C, we want to find the next A
678+
679+
return changeFlag
680+
681+
545682
def visibleEdgeHelperVisit(graph: Graph, node_c: Node, node_a: Node, node_b: Node, path: List[Node]) -> bool:
546683
if path.__contains__(node_a):
547684
return False
@@ -691,10 +828,8 @@ def _contains_all(set_a: Set[Node], set_b: Set[Node]):
691828
break
692829

693830

694-
695831
def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = 0.05, depth: int = -1,
696-
max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None,
697-
show_progress: bool = True, node_names = None,
832+
max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None, show_progress: bool = True,
698833
**kwargs) -> Tuple[Graph, List[Edge]]:
699834
"""
700835
Perform Fast Causal Inference (FCI) algorithm for causal discovery
@@ -749,10 +884,8 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
749884

750885

751886
nodes = []
752-
if node_names is None:
753-
node_names = [f"X{i + 1}" for i in range(dataset.shape[1])]
754887
for i in range(dataset.shape[1]):
755-
node = GraphNode(node_names[i])
888+
node = GraphNode(f"X{i + 1}")
756889
node.add_attribute("id", i)
757890
nodes.append(node)
758891

@@ -790,6 +923,13 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
790923
if verbose:
791924
print("Epoch")
792925

926+
# rule 8
927+
change_flag = rule8(graph,nodes)
928+
# rule 9
929+
change_flag = rule9(graph, nodes)
930+
# rule 10
931+
change_flag = rule10(graph)
932+
793933
graph.set_pag(True)
794934

795935
edges = get_color_edges(graph)

0 commit comments

Comments
 (0)