Skip to content

Commit a944e26

Browse files
committed
modify change_flag
1 parent 5764c6f commit a944e26

File tree

1 file changed

+10
-11
lines changed
  • causallearn/search/ConstraintBased

1 file changed

+10
-11
lines changed

causallearn/search/ConstraintBased/FCI.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -724,9 +724,8 @@ def ruleR4B(graph: Graph, maxPathLength: int, data: ndarray, independence_test_m
724724

725725

726726

727-
def rule8(graph: Graph, nodes: List[Node]):
728-
nodes = graph.get_nodes()
729-
changeFlag = False
727+
def rule8(graph: Graph, nodes: List[Node], changeFlag):
728+
nodes = graph.get_nodes() if nodes is None else nodes
730729
for node_B in nodes:
731730
adj = graph.get_adjacent_nodes(node_B)
732731
if len(adj) < 2:
@@ -781,9 +780,9 @@ def find_possible_children(graph: Graph, parent_node, en_nodes=None):
781780

782781
return potential_child_nodes
783782

784-
def rule9(graph: Graph, nodes: List[Node]):
785-
changeFlag = False
786-
nodes = graph.get_nodes()
783+
def rule9(graph: Graph, nodes: List[Node], changeFlag):
784+
# changeFlag = False
785+
nodes = graph.get_nodes() if nodes is None else nodes
787786
for node_C in nodes:
788787
intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
789788
for node_A in intoCArrows:
@@ -809,8 +808,8 @@ def rule9(graph: Graph, nodes: List[Node]):
809808
return changeFlag
810809

811810

812-
def rule10(graph: Graph):
813-
changeFlag = False
811+
def rule10(graph: Graph, changeFlag):
812+
# changeFlag = False
814813
nodes = graph.get_nodes()
815814
for node_C in nodes:
816815
intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
@@ -1116,12 +1115,12 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
11161115
change_flag = ruleR7(graph, change_flag, verbose)
11171116

11181117
# rule 8
1119-
change_flag = rule8(graph,nodes)
1118+
change_flag = rule8(graph,nodes, change_flag)
11201119

11211120
# rule 9
1122-
change_flag = rule9(graph, nodes)
1121+
change_flag = rule9(graph, nodes, change_flag)
11231122
# rule 10
1124-
change_flag = rule10(graph)
1123+
change_flag = rule10(graph, change_flag)
11251124

11261125
graph.set_pag(True)
11271126

0 commit comments

Comments
 (0)