Skip to content

Commit d9e87b0

Browse files
committed
random seed: two tests w/ or wo/ randomness
1 parent 5edd04a commit d9e87b0

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

tests/TestPC.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import unittest
55
import hashlib
66
import numpy as np
7-
np.random.seed(42)
87
from causallearn.search.ConstraintBased.PC import pc
98
from causallearn.utils.cit import chisq, fisherz, gsq, kci, mv_fisherz
109
from causallearn.graph.SHD import SHD
@@ -84,7 +83,7 @@
8483
"./TestData/bnlearn_discrete_10000/benchmark_returned_results/win95pts_pc_chisq_0.05_stable_0_-1.txt": "1168e7c6795df8063298fc2f727566be",
8584
}
8685
INCONSISTENT_RESULT_GRAPH_ERRMSG = "Returned graph is inconsistent with the benchmark. Please check your code with the commit 94d1536."
87-
86+
UNROBUST_RESULT_GRAPH_ERRMSG = "Returned graph is much too different from the benchmark. Please check the randomness in your algorithm."
8887
# verify files integrity first
8988
for file_path, expected_MD5 in BENCHMARK_TXTFILE_TO_MD5.items():
9089
with open(file_path, 'rb') as fin:
@@ -239,10 +238,22 @@ def test_pc_load_linear_missing_10_with_mv_fisher_z(self):
239238
truth_cpdag = dag2cpdag(truth_dag)
240239
num_edges_in_truth = truth_dag.get_num_edges()
241240

242-
cg = pc(data, 0.05, mv_fisherz, True, 0, 4, mvpc=True)
241+
# since there is randomness in mvpc (np.random.shuffle in get_predictor_ws of utils/PCUtils/Helper.py),
242+
# we compute two results:
243+
# - one with randomness to ensure that randomness is not a big problem for robustness of the algorithm end-to-end
244+
# - one with no randomness (deterministic) to ensure that logic of the algorithm is consistent after any further changes
245+
# (i.e., to ensure that any small difference in the results is caused by randomness, not by a logic change).
246+
cg_with_randomness = pc(data, 0.05, mv_fisherz, True, 0, 4, mvpc=True)
247+
state = np.random.get_state() # save the current random state
248+
np.random.seed(42) # temporarily seed the global RNG with 42, just for the following line
249+
cg_without_randomness = pc(data, 0.05, mv_fisherz, True, 0, 4, mvpc=True)
250+
np.random.set_state(state) # restore the random state
251+
243252
benchmark_returned_graph = np.loadtxt("./TestData/benchmark_returned_results/linear_missing_10_mvpc_fisherz_0.05_stable_0_4.txt")
244-
assert np.all(cg.G.graph == benchmark_returned_graph), INCONSISTENT_RESULT_GRAPH_ERRMSG
245-
shd = SHD(truth_cpdag, cg.G)
253+
assert np.all(cg_without_randomness.G.graph == benchmark_returned_graph), INCONSISTENT_RESULT_GRAPH_ERRMSG
254+
assert np.all(cg_with_randomness.G.graph != benchmark_returned_graph) / benchmark_returned_graph.size < 0.02,\
255+
UNROBUST_RESULT_GRAPH_ERRMSG # 0.02 is an empirical threshold we found here
256+
shd = SHD(truth_cpdag, cg_with_randomness.G)
246257
print(f" pc(data, 0.05, mv_fisherz, True, 0, 4, mvpc=True)\tSHD: {shd.get_shd()} of {num_edges_in_truth}")
247258

248259
print('test_pc_load_linear_missing_10_with_mv_fisher_z passed!\n')

0 commit comments

Comments
 (0)