|
4 | 4 | import unittest |
5 | 5 | import hashlib |
6 | 6 | import numpy as np |
7 | | -np.random.seed(42) |
8 | 7 | from causallearn.search.ConstraintBased.PC import pc |
9 | 8 | from causallearn.utils.cit import chisq, fisherz, gsq, kci, mv_fisherz |
10 | 9 | from causallearn.graph.SHD import SHD |
|
84 | 83 | "./TestData/bnlearn_discrete_10000/benchmark_returned_results/win95pts_pc_chisq_0.05_stable_0_-1.txt": "1168e7c6795df8063298fc2f727566be", |
85 | 84 | } |
86 | 85 | INCONSISTENT_RESULT_GRAPH_ERRMSG = "Returned graph is inconsistent with the benchmark. Please check your code with the commit 94d1536." |
87 | | - |
| 86 | +UNROBUST_RESULT_GRAPH_ERRMSG = "Returned graph is much too different from the benchmark. Please check the randomness in your algorithm." |
88 | 87 | # verify files integrity first |
89 | 88 | for file_path, expected_MD5 in BENCHMARK_TXTFILE_TO_MD5.items(): |
90 | 89 | with open(file_path, 'rb') as fin: |
@@ -239,10 +238,22 @@ def test_pc_load_linear_missing_10_with_mv_fisher_z(self): |
239 | 238 | truth_cpdag = dag2cpdag(truth_dag) |
240 | 239 | num_edges_in_truth = truth_dag.get_num_edges() |
241 | 240 |
|
242 | | - cg = pc(data, 0.05, mv_fisherz, True, 0, 4, mvpc=True) |
| 241 | + # since there is randomness in mvpc (np.random.shuffle in get_predictor_ws of utils/PCUtils/Helper.py), |
| 242 | + # we need to get two results respectively: |
| 243 | + # - one with randomness to ensure that randomness is not a big problem for robustness of the algorithm end-to-end |
| 244 | + # - one with no randomness (deterministic) to ensure that logic of the algorithm is consistent after any further changes |
| 245 | + # (i.e., to ensure that the little difference in the results is caused by randomness, not by the logic change). |
| 246 | + cg_with_randomness = pc(data, 0.05, mv_fisherz, True, 0, 4, mvpc=True) |
| 247 | + state = np.random.get_state() # save the current random state |
| 248 | + np.random.seed(42) # set the random state to 42 temporarily, just for the following line |
| 249 | + cg_without_randomness = pc(data, 0.05, mv_fisherz, True, 0, 4, mvpc=True) |
| 250 | + np.random.set_state(state) # restore the random state |
| 251 | + |
243 | 252 | benchmark_returned_graph = np.loadtxt("./TestData/benchmark_returned_results/linear_missing_10_mvpc_fisherz_0.05_stable_0_4.txt") |
244 | | - assert np.all(cg.G.graph == benchmark_returned_graph), INCONSISTENT_RESULT_GRAPH_ERRMSG |
245 | | - shd = SHD(truth_cpdag, cg.G) |
| 253 | + assert np.all(cg_without_randomness.G.graph == benchmark_returned_graph), INCONSISTENT_RESULT_GRAPH_ERRMSG |
| 254 | + assert np.all(cg_with_randomness.G.graph != benchmark_returned_graph) / benchmark_returned_graph.size < 0.02,\ |
| 255 | + UNROBUST_RESULT_GRAPH_ERRMSG # 0.05 is an empiric value we find here |
| 256 | + shd = SHD(truth_cpdag, cg_with_randomness.G) |
246 | 257 | print(f" pc(data, 0.05, mv_fisherz, True, 0, 4, mvpc=True)\tSHD: {shd.get_shd()} of {num_edges_in_truth}") |
247 | 258 |
|
248 | 259 | print('test_pc_load_linear_missing_10_with_mv_fisher_z passed!\n') |
|
0 commit comments