Skip to content

Commit 3e07554

Browse files
Add cache for score function
1 parent d7ebfcf commit 3e07554

File tree

3 files changed

+36
-15
lines changed

3 files changed

+36
-15
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import math
2+
from functools import lru_cache
3+
from typing import List, Dict, Any, Callable
4+
5+
import pandas as pd
6+
from numpy import ndarray
7+
from causallearn.score.LocalScoreFunction import local_score_BDeu, local_score_BIC, local_score_cv_multi, local_score_marginal_multi, local_score_marginal_general, local_score_cv_general
8+
9+
from causallearn.utils.ScoreUtils import *
10+
11+
12+
class LocalScoreClass(object):
13+
14+
def __init__(self, data: ndarray, local_score_fun: Callable[[ndarray, int, List[int], Any], float], parameters=None):
15+
self.data = data
16+
self.local_score_fun = local_score_fun
17+
self.parameters = parameters
18+
self.score_cache = {}
19+
20+
def score(self, i: int, PAi: List[int]) -> float:
21+
hash_key = f'i_{str(i)}_PAi_{str(PAi)}'
22+
if self.score_cache.__contains__(hash_key):
23+
return self.score_cache[hash_key]
24+
else:
25+
res = self.local_score_fun(self.data, i, PAi, self.parameters)
26+
self.score_cache[hash_key] = res
27+
return res

causallearn/search/ScoreBased/GES.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import Optional
2-
2+
from causallearn.score.LocalScoreFunctionClass import LocalScoreClass
33
from causallearn.graph.GeneralGraph import GeneralGraph
44
from causallearn.graph.GraphNode import GraphNode
55
from causallearn.utils.DAG2CPDAG import dag2cpdag
@@ -46,12 +46,14 @@ def ges(X: ndarray, score_func: str = 'local_score_BIC', maxP: Optional[float] =
4646
if maxP is None:
4747
maxP = X.shape[1] / 2 # maximum number of parents
4848
N = X.shape[1] # number of variables
49+
localScoreClass = LocalScoreClass(data=X, local_score_fun=local_score_cv_general, parameters=parameters)
4950

5051
elif score_func == 'local_score_marginal_general': # negative marginal likelihood based on regression in RKHS
5152
parameters = {}
5253
if maxP is None:
5354
maxP = X.shape[1] / 2 # maximum number of parents
5455
N = X.shape[1] # number of variables
56+
localScoreClass = LocalScoreClass(data=X, local_score_fun=local_score_marginal_general, parameters=parameters)
5557

5658
elif score_func == 'local_score_CV_multi': # k-fold negative cross validated likelihood based on regression in RKHS
5759
# for data with multi-variate dimensions
@@ -62,6 +64,7 @@ def ges(X: ndarray, score_func: str = 'local_score_BIC', maxP: Optional[float] =
6264
if maxP is None:
6365
maxP = len(parameters['dlabel']) / 2
6466
N = len(parameters['dlabel'])
67+
localScoreClass = LocalScoreClass(data=X, local_score_fun=local_score_cv_multi, parameters=parameters)
6568

6669
elif score_func == 'local_score_marginal_multi': # negative marginal likelihood based on regression in RKHS
6770
# for data with multi-variate dimensions
@@ -72,19 +75,23 @@ def ges(X: ndarray, score_func: str = 'local_score_BIC', maxP: Optional[float] =
7275
if maxP is None:
7376
maxP = len(parameters['dlabel']) / 2
7477
N = len(parameters['dlabel'])
78+
localScoreClass = LocalScoreClass(data=X, local_score_fun=local_score_marginal_multi, parameters=parameters)
7579

7680
elif score_func == 'local_score_BIC': # Greedy equivalence search with BIC score
7781
if maxP is None:
7882
maxP = X.shape[1] / 2
7983
N = X.shape[1] # number of variables
84+
localScoreClass = LocalScoreClass(data=X, local_score_fun=local_score_BIC, parameters=None)
8085

8186
elif score_func == 'local_score_BDeu': # Greedy equivalence search with BDeu score
8287
if maxP is None:
8388
maxP = X.shape[1] / 2
8489
N = X.shape[1] # number of variables
90+
localScoreClass = LocalScoreClass(data=X, local_score_fun=local_score_BDeu, parameters=None)
8591

8692
else:
8793
raise Exception('Unknown function!')
94+
score_func = localScoreClass
8895

8996
node_names = [("x%d" % i) for i in range(N)]
9097
nodes = []

causallearn/utils/GESUtils.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,7 @@
1010

1111

1212
def feval(parameters: list):
13-
if parameters[0] == 'local_score_CV_general':
14-
return local_score_cv_general(parameters[1], parameters[2], parameters[3], parameters[4])
15-
elif parameters[0] == 'local_score_marginal_general':
16-
return local_score_marginal_general(parameters[1], parameters[2], parameters[3], parameters[4])
17-
elif parameters[0] == 'local_score_CV_multi':
18-
return local_score_cv_multi(parameters[1], parameters[2], parameters[3], parameters[4])
19-
elif parameters[0] == 'local_score_marginal_multi':
20-
return local_score_marginal_multi(parameters[1], parameters[2], parameters[3], parameters[4])
21-
elif parameters[0] == 'local_score_BIC':
22-
return local_score_BIC(parameters[1], parameters[2], parameters[3], parameters[4])
23-
elif parameters[0] == 'local_score_BDeu':
24-
return local_score_BDeu(parameters[1], parameters[2], parameters[3], parameters[4])
25-
else:
26-
raise Exception('Undefined function')
13+
return parameters[0].score(parameters[2], parameters[3])
2714

2815

2916
def kernel(x, xKern, theta):

0 commit comments

Comments
 (0)