|
1 | 1 | import hashlib |
2 | 2 | import os |
| 3 | +import random |
3 | 4 | import sys |
4 | 5 | import time |
5 | 6 | import unittest |
6 | 7 |
|
7 | | -from networkx import DiGraph |
| 8 | +from networkx import DiGraph, erdos_renyi_graph, is_directed_acyclic_graph |
8 | 9 | import numpy as np |
9 | 10 | import pandas as pd |
10 | 11 |
|
|
19 | 20 | ######################################### Test Notes ########################################### |
20 | 21 | # All the benchmark results of loaded files (e.g. "./TestData/benchmark_returned_results/") # |
21 | 22 | # are obtained from the code of causal-learn as of commit # |
22 | | -# https://github.com/cmu-phil/causal-learn/commit/fb092d1 (08-04-2022). # |
| 23 | +# https://github.com/cmu-phil/causal-learn/pull/68/commits/999df2e (10-08-2022). # |
23 | 24 | # # |
24 | 25 | # We are not sure if the results are completely "correct" (reflect ground truth graph) or not. # |
25 | 26 | # So if you find your tests failed, it means that your modified code is logically inconsistent # |
26 | | -# with the code as of fb092d1, but not necessarily means that your code is "wrong". # |
27 | | -# If you are sure that your modification is "correct" (e.g. fixed some bugs in fb092d1), # |
| 27 | +# with the code as of 999df2e, but not necessarily means that your code is "wrong". # |
| 28 | +# If you are sure that your modification is "correct" (e.g. fixed some bugs in 999df2e), # |
28 | 29 | # please report it to us. We will then modify these benchmark results accordingly. Thanks :) # |
29 | 30 | ######################################### Test Notes ########################################### |
30 | 31 |
|
31 | 32 | BENCHMARK_TXTFILE_TO_MD5 = { |
32 | 33 | "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_asia_fci_chisq_0.05.txt": "65f54932a9d8224459e56c40129e6d8b", |
33 | 34 | "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_cancer_fci_chisq_0.05.txt": "0312381641cb3b4818e0c8539f74e802", |
34 | 35 | "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_earthquake_fci_chisq_0.05.txt": "a1160b92ce15a700858552f08e43b7de", |
35 | | - "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_sachs_fci_chisq_0.05.txt": "c4a0d5eaf793838d6ad2b58632ba0ded", |
| 36 | + "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_sachs_fci_chisq_0.05.txt": "dced4a202fc32eceb75f53159fc81f3b", |
36 | 37 | "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_survey_fci_chisq_0.05.txt": "b1a28eee1e0c6ea8a64ac1624585c3f4", |
37 | 38 | "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_alarm_fci_chisq_0.05.txt": "c3bbc2b8aba456a4258dd071a42085bc", |
38 | | - "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_barley_fci_chisq_0.05.txt": "ccbd38b245ced2ac7e632415b93cbef1", |
| 39 | + "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_barley_fci_chisq_0.05.txt": "4a5000e7a582083859ee6aef15073676", |
39 | 40 | "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_child_fci_chisq_0.05.txt": "6b7858589e12f04b0f489ba4589a1254", |
40 | 41 | "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_insurance_fci_chisq_0.05.txt": "9975942b936aa2b1fc90c09318ca2d08", |
41 | 42 | "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_water_fci_chisq_0.05.txt": "48eee804d59526187b7ecd0519556ee5", |
42 | 43 | "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_hailfinder_fci_chisq_0.05.txt": "6b9a6b95b6474f8530e85c022f4e749c", |
43 | | - "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_hepar2_fci_chisq_0.05.txt": "e9de65c752011d72c36589b68590011d", |
| 44 | + "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_hepar2_fci_chisq_0.05.txt": "4aae21ff3d9aa2435515ed2ee402294c", |
44 | 45 | "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_win95pts_fci_chisq_0.05.txt": "648fdf271e1440c06ca2b31b55ef1f3f", |
45 | | - "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_andes_fci_chisq_0.05.txt": "939375acd3c623334b20c709bf8f347c", |
| 46 | + "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_andes_fci_chisq_0.05.txt": "04092ae93e54c727579f08bf5dc34c77", |
46 | 47 | "tests/TestData/benchmark_returned_results/linear_10_fci_fisherz_0.05.txt": "289c86f9c665bf82bbcc4c9e1dcec3e7" |
47 | 48 | } |
48 | | - |
| 49 | +# |
49 | 50 | INCONSISTENT_RESULT_GRAPH_ERRMSG = "Returned graph is inconsistent with the benchmark. Please check your code with the commit fb092d1." |
50 | 51 | INCONSISTENT_RESULT_GRAPH_WITH_PAG_ERRMSG = "Returned graph is inconsistent with the truth PAG." |
51 | 52 |
|
@@ -182,3 +183,31 @@ def test_continuous_dataset(self): |
182 | 183 | benchmark_returned_graph = np.loadtxt( |
183 | 184 | f'tests/TestData/benchmark_returned_results/linear_10_fci_fisherz_0.05.txt') |
184 | 185 | assert np.all(G.graph == benchmark_returned_graph), INCONSISTENT_RESULT_GRAPH_ERRMSG |
| 186 | + |
| 187 | + def test_er_graph(self): |
| 188 | + random.seed(42) |
| 189 | + np.random.seed(42) |
| 190 | + p = 0.1 |
| 191 | + for _ in range(5): |
| 192 | + data = np.empty(shape=(0, 10)) |
| 193 | + true_dag = erdos_renyi_graph(15, p, directed=True) # The last 5 variables are latent variables |
| 194 | + while not is_directed_acyclic_graph(true_dag): |
| 195 | + true_dag = erdos_renyi_graph(15, p, directed=True) |
| 196 | + ground_truth_edges = list(true_dag.edges) |
| 197 | + print(ground_truth_edges) |
| 198 | + G, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag) |
| 199 | + |
| 200 | + ground_truth_nodes = [] |
| 201 | + for i in range(15): |
| 202 | + ground_truth_nodes.append(GraphNode(f'X{i + 1}')) |
| 203 | + ground_truth_dag = Dag(ground_truth_nodes) |
| 204 | + for u, v in ground_truth_edges: |
| 205 | + ground_truth_dag.add_directed_edge(ground_truth_nodes[u], ground_truth_nodes[v]) |
| 206 | + print(ground_truth_dag) |
| 207 | + pag = dag2pag(ground_truth_dag, ground_truth_nodes[10:]) |
| 208 | + print('pag:') |
| 209 | + print(pag) |
| 210 | + print('fci graph:') |
| 211 | + print(G) |
| 212 | + print(f'fci(data, d_separation, 0.05):') |
| 213 | + self.run_simulate_data_test(pag, G) |
0 commit comments